import os
import json
from typing import Any, Dict, List, Optional
import sys
import traceback
import re
import time
import asyncio
import aiofiles
import aiofiles.os
from mindroot.lib.utils.debug import debug_box
class ChatLog:
def __init__(self, log_id=0, agent=None, parent_log_id=None, context_length: int = 4096, user: str = None):
self.log_id = log_id
self.messages = []
self.parent_log_id = parent_log_id
self.agent = agent
        if user is None or user == '' or user == 'None':
            raise ValueError('User must be provided')
        # Accept either a string username or an object with a username
        # attribute (e.g. a user record) and unwrap it here.
        if not isinstance(user, str):
            if hasattr(user, 'username'):
                user = user.username
            else:
                raise ValueError('ChatLog(): user must be a string or have a username field')
        self.user = user
        if agent is None or agent == '':
            raise ValueError('Agent must be provided')
        self.context_length = context_length
        # Logs are stored under CHATLOG_DIR/{user}/{agent}/
        base_dir = os.environ.get('CHATLOG_DIR', 'data/chat')
        self.log_dir = os.path.join(base_dir, self.user, self.agent)
        os.makedirs(self.log_dir, exist_ok=True)
# For backward compatibility, we'll load synchronously in constructor
# but provide async methods for new code
self._load_log_sync()
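    # Example (illustrative sketch, not a prescribed API): creating a log for
    # a hypothetical user/agent pair. 'alice', 'mr_assistant', and 'sess-001'
    # are placeholder values; the log file lands under
    # CHATLOG_DIR/alice/mr_assistant/chatlog_sess-001.json.
    #
    #     log = ChatLog(log_id='sess-001', agent='mr_assistant', user='alice')
    #     log.add_message({'role': 'user', 'content': 'Hello there'})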
    def _get_log_data(self) -> Dict[str, Any]:
return {
'agent': self.agent,
'log_id': self.log_id,
'messages': self.messages,
'parent_log_id': self.parent_log_id
}
    def _calculate_message_length(self, message: Dict[str, str]) -> int:
        # Rough token estimate: about one token per three characters of the
        # JSON-serialized message.
        return len(json.dumps(message)) // 3
def _load_log_sync(self, log_id=None) -> None:
"""Synchronous version for backward compatibility"""
if log_id is None:
log_id = self.log_id
self.log_id = log_id
log_file = os.path.join(self.log_dir, f'chatlog_{log_id}.json')
if os.path.exists(log_file):
with open(log_file, 'r') as f:
log_data = json.load(f)
self.agent = log_data.get('agent')
self.messages = log_data.get('messages', [])
self.parent_log_id = log_data.get('parent_log_id', None)
print("Loaded log file at ", log_file)
print("Message length: ", len(self.messages))
else:
print("Could not find log file at ", log_file)
self.messages = []
def _save_log_sync(self) -> None:
"""Synchronous version for backward compatibility"""
log_file = os.path.join(self.log_dir, f'chatlog_{self.log_id}.json')
with open(log_file, 'w') as f:
json.dump(self._get_log_data(), f, indent=2)
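    # The persisted file is CHATLOG_DIR/{user}/{agent}/chatlog_{log_id}.json.
    # Its shape, with illustrative values:
    #
    #     {
    #       "agent": "mr_assistant",
    #       "log_id": "sess-001",
    #       "messages": [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}],
    #       "parent_log_id": null
    #     }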
    def add_message(self, message: Dict[str, str]) -> None:
        """Synchronous version for backward compatibility"""
        self._add_message_impl(message)
        self._save_log_sync()
    def _add_message_impl(self, message: Dict[str, str]) -> None:
        """Internal implementation shared by sync and async versions.

        Mutates self.messages; callers are responsible for persisting the log.
        """
        if len(self.messages) > 0 and self.messages[-1]['role'] == message['role']:
            # Same role as the previous message: try to merge the two so the
            # history does not contain consecutive messages with one role.
            # Normalize string content to the list-of-parts form first.
            if isinstance(message['content'], str):
                message['content'] = [{'type': 'text', 'text': message['content']}]
            elif isinstance(message['content'], list):
                # Image parts cannot be merged into text; append the message as-is.
                for part in message['content']:
                    if part['type'] == 'image':
                        self.messages.append(message)
                        return
            try:
                # If both messages hold JSON command lists, concatenate them.
                cmd_list = json.loads(self.messages[-1]['content'][0]['text'])
                if not isinstance(cmd_list, list):
                    cmd_list = [cmd_list]
                new_json = json.loads(message['content'][0]['text'])
                if not isinstance(new_json, list):
                    new_json = [new_json]
                self.messages[-1]['content'] = [{'type': 'text', 'text': json.dumps(cmd_list + new_json)}]
            except Exception as e:
                # The previous message was plain text rather than a JSON
                # command list; merge by concatenating the text of the two
                # messages and replacing the previous entry.
                print("Could not combine commands, probably normal if user message and previous system output, assuming string", e)
                if isinstance(self.messages[-1]['content'], str):
                    new_msg_text = self.messages[-1]['content'] + message['content'][0]['text']
                else:
                    new_msg_text = self.messages[-1]['content'][0]['text'] + message['content'][0]['text']
                self.messages[-1] = {'role': message['role'], 'content': [{'type': 'text', 'text': new_msg_text}]}
        else:
            self.messages.append(message)
    async def add_message_async(self, message: Dict[str, str]) -> None:
        """Async version for new code that needs non-blocking operations"""
        self._add_message_impl(message)
        await self.save_log()
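    # Example (illustrative sketch): appending a message from async code.
    # 'sess-001', 'mr_assistant', and 'alice' are placeholder values.
    #
    #     async def main():
    #         log = ChatLog(log_id='sess-001', agent='mr_assistant', user='alice')
    #         await log.add_message_async({'role': 'user', 'content': 'Hello'})
    #
    #     asyncio.run(main())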
def get_history(self) -> List[Dict[str, str]]:
return self.messages
    def get_recent(self, max_tokens: int = 4096) -> List[Dict[str, str]]:
        """Return a deep copy of the full message history.

        Token-based truncation is currently disabled: max_tokens is ignored
        and every message is returned.
        """
        # Round-trip through JSON so callers get a deep copy and cannot
        # mutate the log's internal state.
        return json.loads(json.dumps(self.messages))
async def save_log(self) -> None:
log_file = os.path.join(self.log_dir, f'chatlog_{self.log_id}.json')
async with aiofiles.open(log_file, 'w') as f:
await f.write(json.dumps(self._get_log_data(), indent=2))
    async def load_log(self, log_id=None) -> None:
if log_id is None:
log_id = self.log_id
self.log_id = log_id
log_file = os.path.join(self.log_dir, f'chatlog_{log_id}.json')
if await aiofiles.os.path.exists(log_file):
async with aiofiles.open(log_file, 'r') as f:
content = await f.read()
log_data = json.loads(content)
self.agent = log_data.get('agent')
self.messages = log_data.get('messages', [])
self.parent_log_id = log_data.get('parent_log_id', None)
print("Loaded log file at ", log_file)
print("Message length: ", len(self.messages))
else:
print("Could not find log file at ", log_file)
self.messages = []
def count_tokens(self) -> Dict[str, int]:
"""
Count tokens in the chat log, providing both sequence totals and cumulative request totals.
Returns:
Dict with the following keys:
- input_tokens_sequence: Total tokens in all user messages
- output_tokens_sequence: Total tokens in all assistant messages
- input_tokens_total: Cumulative tokens sent to LLM across all requests
"""
# Initialize counters
input_tokens_sequence = 0 # Total tokens in all user messages
output_tokens_sequence = 0 # Total tokens in all assistant messages
input_tokens_total = 0 # Cumulative tokens sent to LLM across all requests
# Process each message
for i, message in enumerate(self.messages):
# Calculate tokens in this message (rough approximation)
message_tokens = len(json.dumps(message)) // 4
# Add to appropriate sequence counter
if message['role'] == 'assistant':
output_tokens_sequence += message_tokens
else: # user or system
input_tokens_sequence += message_tokens
# For each assistant message, calculate the input tokens for that request
# (which includes all previous messages)
if message['role'] == 'assistant':
request_input_tokens = 0
for j in range(i):
request_input_tokens += len(json.dumps(self.messages[j])) // 4
input_tokens_total += request_input_tokens
return {
'input_tokens_sequence': input_tokens_sequence,
'output_tokens_sequence': output_tokens_sequence,
'input_tokens_total': input_tokens_total
}
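    # Worked example (hypothetical numbers): for a history of
    # [user 400 chars, assistant 200 chars, user 400 chars, assistant 200 chars]
    # the per-message estimates are [100, 50, 100, 50] tokens, so
    # input_tokens_sequence = 200 and output_tokens_sequence = 100, while
    # input_tokens_total = 100 (first request) + 250 (second request) = 350,
    # because each assistant turn re-sends all prior messages as input.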
async def find_chatlog_file(log_id: str) -> Optional[str]:
    """
    Find a chatlog file by its log_id.

    Args:
        log_id: The log ID to search for

    Returns:
        The full path to the chatlog file if found, None otherwise
    """
chat_dir = os.environ.get('CHATLOG_DIR', 'data/chat')
    # Search all subdirectories. os.walk is lazy, so materialize the listing
    # inside the worker thread; otherwise the directory I/O would still block
    # the event loop.
    entries = await asyncio.to_thread(lambda: list(os.walk(chat_dir)))
    for root, dirs, files in entries:
        for file in files:
            if file == f"chatlog_{log_id}.json":
                return os.path.join(root, file)
return None
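# Example (illustrative sketch; 'sess-001' is a placeholder log ID):
#
#     path = asyncio.run(find_chatlog_file('sess-001'))
#     if path:
#         print('found at', path)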
async def find_child_logs_by_parent_id(parent_log_id: str) -> List[str]:
"""
Find all chat logs that have the given parent_log_id.
Args:
parent_log_id: The parent log ID to search for
Returns:
List of log IDs that have this parent_log_id
"""
child_log_ids = []
chat_dir = os.environ.get('CHATLOG_DIR', 'data/chat')
    # Search through all chatlog files (materialize the lazy os.walk in the
    # worker thread so directory I/O does not block the event loop)
    entries = await asyncio.to_thread(lambda: list(os.walk(chat_dir)))
    for root, dirs, files in entries:
for file in files:
if file.startswith("chatlog_") and file.endswith(".json"):
try:
async with aiofiles.open(os.path.join(root, file), 'r') as f:
content = await f.read()
log_data = json.loads(content)
if log_data.get('parent_log_id') == parent_log_id:
# Extract log_id from the data
child_log_ids.append(log_data.get('log_id'))
except (json.JSONDecodeError, IOError):
continue
return child_log_ids
async def get_cache_dir() -> str:
"""
Get the directory for token count cache files.
Creates the directory if it doesn't exist.
"""
cache_dir = os.environ.get('TOKEN_CACHE_DIR', 'data/token_cache')
    if not await aiofiles.os.path.exists(cache_dir):
        await aiofiles.os.makedirs(cache_dir, exist_ok=True)
return cache_dir
async def get_cache_path(log_id: str) -> str:
"""
Get the path to the cache file for a specific log_id.
"""
cache_dir = await get_cache_dir()
return os.path.join(cache_dir, f"tokens_{log_id}.json")
async def get_cached_token_counts(log_id: str, log_path: str) -> Optional[Dict[str, int]]:
    """
    Get cached token counts if available and valid.

    Args:
        log_id: The log ID
        log_path: Path to the actual log file

    Returns:
        Cached token counts if valid, None otherwise
    """
cache_path = await get_cache_path(log_id)
# If cache doesn't exist, return None
if not await aiofiles.os.path.exists(cache_path):
return None
try:
# Get modification times
log_mtime = await aiofiles.os.path.getmtime(log_path)
cache_mtime = await aiofiles.os.path.getmtime(cache_path)
current_time = time.time()
# If log was modified after cache was created, cache is invalid
if log_mtime > cache_mtime:
return None
# Don't recalculate sooner than 3 minutes after last calculation
if current_time - cache_mtime < 180: # 3 minutes in seconds
async with aiofiles.open(cache_path, 'r') as f:
content = await f.read()
return json.loads(content)
# For logs that haven't been modified in over an hour, consider them "finished"
# and use the cache regardless of when it was last calculated
if current_time - log_mtime > 3600: # 1 hour in seconds
async with aiofiles.open(cache_path, 'r') as f:
content = await f.read()
return json.loads(content)
except (json.JSONDecodeError, IOError) as e:
print(f"Error reading token cache: {e}")
return None
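# Cache-validity rules applied above, summarized (times relative to now):
#
#     log modified after cache written  -> recalculate (cache is stale)
#     cache written < 3 minutes ago     -> use cache (rate-limits recalculation)
#     log untouched for > 1 hour        -> use cache (log considered finished)
#     anything else                     -> recalculate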
async def save_token_counts_to_cache(log_id: str, token_counts: Dict[str, int]) -> None:
"""
Save token counts to cache.
"""
cache_path = await get_cache_path(log_id)
async with aiofiles.open(cache_path, 'w') as f:
await f.write(json.dumps(token_counts))
async def build_token_hierarchy(log_id: str, user: Optional[str] = None, visited: Optional[set] = None) -> Optional[Dict]:
    """
    Build a hierarchical token count structure for a log and its children.

    Args:
        log_id: The log ID to build hierarchy for
        user: User for the log; if None, it is derived from the chatlog path
        visited: Set of already visited log IDs to prevent infinite recursion

    Returns:
        Dictionary with hierarchical structure containing:
        - log_id: The log ID
        - individual_counts: Token counts for this log only
        - cumulative_counts: Token counts including all children
        - children: List of child hierarchies
        Returns None if the log cannot be found or was already visited.
    """
    if visited is None:
        visited = set()
    if log_id in visited:
        return None  # Prevent infinite recursion
    visited.add(log_id)
    # Find the chatlog file
    chatlog_path = await find_chatlog_file(log_id)
    if not chatlog_path:
        return None
    # Load the chat log
    async with aiofiles.open(chatlog_path, 'r') as f:
        content = await f.read()
        log_data = json.loads(content)
    # Determine the user from the chatlog path if not provided:
    # data/chat/{user}/{agent}/chatlog_{log_id}.json
    if user is None:
        try:
            path_parts = chatlog_path.split(os.sep)
            if len(path_parts) >= 4 and path_parts[-4] == 'chat':
                user = path_parts[-3]
            else:
                user = "system"
        except Exception:
            user = "system"
    temp_log = ChatLog(log_id=log_id, user=user, agent=log_data.get('agent', 'unknown'))
    temp_log.messages = log_data.get('messages', [])
    # Check if we have cached individual counts for this specific session
    cached_individual = await get_cached_token_counts(log_id, chatlog_path)
    if cached_individual and 'input_tokens_sequence' in cached_individual:
        print(f"Using cached individual token counts for session {log_id}")
        individual_counts = {
            'input_tokens_sequence': cached_individual['input_tokens_sequence'],
            'output_tokens_sequence': cached_individual['output_tokens_sequence'],
            'input_tokens_total': cached_individual['input_tokens_total']
        }
    else:
        # Count tokens for this log only and cache the result
        individual_counts = temp_log.count_tokens()
        await save_token_counts_to_cache(log_id, {
            'input_tokens_sequence': individual_counts['input_tokens_sequence'],
            'output_tokens_sequence': individual_counts['output_tokens_sequence'],
            'input_tokens_total': individual_counts['input_tokens_total']
        })
        print(f"Cached individual token counts for session {log_id}")
    # Find all child log IDs: tasks delegated from within this log's messages
    # plus any logs that reference this one via parent_log_id.
    delegated_log_ids = extract_delegate_task_log_ids(temp_log.messages)
    child_logs_by_parent = await find_child_logs_by_parent_id(log_id)
    all_child_log_ids = list(set(delegated_log_ids) | set(child_logs_by_parent))
# Build child hierarchies
children = []
cumulative_counts = {
'input_tokens_sequence': individual_counts['input_tokens_sequence'],
'output_tokens_sequence': individual_counts['output_tokens_sequence'],
'input_tokens_total': individual_counts['input_tokens_total']
}
for child_id in all_child_log_ids:
child_hierarchy = await build_token_hierarchy(child_id, user, visited.copy())
if child_hierarchy:
children.append(child_hierarchy)
# Add child's cumulative counts to our cumulative counts
cumulative_counts['input_tokens_sequence'] += child_hierarchy['cumulative_counts']['input_tokens_sequence']
cumulative_counts['output_tokens_sequence'] += child_hierarchy['cumulative_counts']['output_tokens_sequence']
cumulative_counts['input_tokens_total'] += child_hierarchy['cumulative_counts']['input_tokens_total']
return {
'log_id': log_id,
'agent': log_data.get('agent', 'unknown'),
'individual_counts': individual_counts,
'cumulative_counts': cumulative_counts,
'children': children
}
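# The returned hierarchy has this shape (illustrative values):
#
#     {
#         'log_id': 'sess-001',
#         'agent': 'mr_assistant',
#         'individual_counts': {'input_tokens_sequence': 200,
#                               'output_tokens_sequence': 100,
#                               'input_tokens_total': 350},
#         'cumulative_counts': {...same keys, including all children...},
#         'children': [...one entry of this same shape per child log...]
#     }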
async def count_tokens_for_log_id(log_id: str, user: Optional[str] = None, hierarchical: bool = False) -> Optional[Dict[str, int]]:
    """
    Count tokens for a chat log identified by log_id, including any delegated tasks.

    Args:
        log_id: The log ID to count tokens for
        user: User for the log; if None, it is derived from the chatlog path
        hierarchical: If True, return a 'hierarchy' key with the tree structure

    Returns:
        Dictionary with token counts, or None if the log is not found.
        If hierarchical=False (the default), returns a flat structure for
        backwards compatibility.
    """
# Find the chatlog file
chatlog_path = await find_chatlog_file(log_id)
if not chatlog_path:
return None
# If hierarchical structure is requested, build and return it
if hierarchical:
# Check cache first for hierarchical data
cached_counts = await get_cached_token_counts(log_id, chatlog_path)
if cached_counts and 'hierarchy' in cached_counts:
print(f"Using cached hierarchical token counts for {log_id}")
return cached_counts
print(f"Calculating hierarchical token counts for {log_id}")
hierarchy = await build_token_hierarchy(log_id, user)
if hierarchy:
result = {'hierarchy': hierarchy}
# Save hierarchical data to cache
await save_token_counts_to_cache(log_id, result)
return result
return None
# Check cache first
cached_counts = await get_cached_token_counts(log_id, chatlog_path)
if cached_counts:
print(f"Using cached token counts for {log_id}")
return cached_counts
print(f"Calculating token counts for {log_id}")
# Load the chat log
async with aiofiles.open(chatlog_path, 'r') as f:
content = await f.read()
log_data = json.loads(content)
# Get parent_log_id if it exists
parent_log_id = log_data.get('parent_log_id')
# Create a temporary ChatLog instance to count tokens
# Use provided user or try to determine from chatlog path or fallback to "system"
if user is None:
# Try to extract user from chatlog path: data/chat/{user}/{agent}/chatlog_{log_id}.json
try:
path_parts = chatlog_path.split(os.sep)
if len(path_parts) >= 4 and path_parts[-4] == 'chat':
extracted_user = path_parts[-3] # User is third from the end
user = extracted_user
print(f"Extracted user '{user}' from chatlog path: {chatlog_path}")
else:
user = "system" # Default fallback
except Exception as e:
print(f"Error extracting user from path {chatlog_path}: {e}")
user = "system" # Default fallback
temp_log = ChatLog(log_id=log_id, user=user, agent=log_data.get('agent', 'unknown'))
temp_log.messages = log_data.get('messages', [])
# Count tokens for this log
parent_counts = temp_log.count_tokens()
# Create combined counts (starting with parent counts)
combined_counts = {
'input_tokens_sequence': parent_counts['input_tokens_sequence'],
'output_tokens_sequence': parent_counts['output_tokens_sequence'],
'input_tokens_total': parent_counts['input_tokens_total']
}
# Find delegated task log IDs
delegated_log_ids = extract_delegate_task_log_ids(temp_log.messages)
# Also find child logs by parent_log_id
child_logs_by_parent = await find_child_logs_by_parent_id(log_id)
# Combine all child log IDs (delegated tasks and parent_log_id children)
all_child_log_ids = set(delegated_log_ids) | set(child_logs_by_parent)
# If this log has a parent_log_id, we should not double-count it
# (it will be counted as part of its parent's cumulative total)
# But we still want to count its own children
# Recursively count tokens for all child tasks
for child_id in all_child_log_ids:
delegated_counts = await count_tokens_for_log_id(child_id, user=user)
if delegated_counts:
combined_counts['input_tokens_sequence'] += delegated_counts['input_tokens_sequence']
combined_counts['output_tokens_sequence'] += delegated_counts['output_tokens_sequence']
combined_counts['input_tokens_total'] += delegated_counts['input_tokens_total']
# Create final result with both parent and combined counts
token_counts = {
# Parent session only counts
'input_tokens_sequence': parent_counts['input_tokens_sequence'],
'output_tokens_sequence': parent_counts['output_tokens_sequence'],
'input_tokens_total': parent_counts['input_tokens_total'],
# Combined counts (parent + all subtasks)
'combined_input_tokens_sequence': combined_counts['input_tokens_sequence'],
'combined_output_tokens_sequence': combined_counts['output_tokens_sequence'],
'combined_input_tokens_total': combined_counts['input_tokens_total']
}
# Save to cache
await save_token_counts_to_cache(log_id, token_counts)
return token_counts
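# Example (illustrative sketch; 'sess-001' is a placeholder log ID):
#
#     counts = asyncio.run(count_tokens_for_log_id('sess-001'))
#     if counts:
#         print(counts['combined_input_tokens_total'])
#
#     tree = asyncio.run(count_tokens_for_log_id('sess-001', hierarchical=True))
#     if tree:
#         print(tree['hierarchy']['cumulative_counts'])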