|
|
@@ -0,0 +1,246 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""SQLite 存储模块 - 完整版"""
|
|
|
+import sqlite3
|
|
|
+import json
|
|
|
+import os
|
|
|
+import datetime
|
|
|
+import tiktoken
|
|
|
+
|
|
|
+SCHEMA_VERSION = 5
|
|
|
+CLIS = ('codex', 'kiro', 'gemini', 'claude')
|
|
|
+
|
|
|
+_encoder = tiktoken.get_encoding("cl100k_base")
|
|
|
+
|
|
|
+def count_tokens(text: str) -> int:
|
|
|
+ return len(_encoder.encode(text)) if text else 0
|
|
|
+
|
|
|
+class ChatStorage:
|
|
|
+ def __init__(self, db_path: str):
|
|
|
+ self.db_path = db_path
|
|
|
+ os.makedirs(os.path.dirname(db_path) or '.', exist_ok=True)
|
|
|
+ self._init_db()
|
|
|
+
|
|
|
+ def _conn(self):
|
|
|
+ return sqlite3.connect(self.db_path)
|
|
|
+
|
|
|
+ def _init_db(self):
|
|
|
+ with self._conn() as conn:
|
|
|
+ conn.execute('''CREATE TABLE IF NOT EXISTS meta (key TEXT PRIMARY KEY, value TEXT)''')
|
|
|
+ for cli in CLIS:
|
|
|
+ conn.execute(f'''CREATE TABLE IF NOT EXISTS meta_{cli} (key TEXT PRIMARY KEY, value TEXT)''')
|
|
|
+ conn.execute('''
|
|
|
+ CREATE TABLE IF NOT EXISTS sessions (
|
|
|
+ file_path TEXT PRIMARY KEY,
|
|
|
+ session_id TEXT,
|
|
|
+ source TEXT NOT NULL,
|
|
|
+ cwd TEXT,
|
|
|
+ messages TEXT,
|
|
|
+ file_mtime INTEGER,
|
|
|
+ start_time TEXT,
|
|
|
+ token_count INTEGER DEFAULT 0
|
|
|
+ )
|
|
|
+ ''')
|
|
|
+ conn.execute('CREATE INDEX IF NOT EXISTS idx_source ON sessions(source)')
|
|
|
+ conn.execute('CREATE INDEX IF NOT EXISTS idx_session_id ON sessions(session_id)')
|
|
|
+ conn.execute('CREATE INDEX IF NOT EXISTS idx_start_time ON sessions(start_time)')
|
|
|
+ self._set_meta('meta', 'schema_version', str(SCHEMA_VERSION))
|
|
|
+
|
|
|
+ def _set_meta(self, table: str, key: str, value: str):
|
|
|
+ with self._conn() as conn:
|
|
|
+ conn.execute(f'INSERT OR REPLACE INTO {table} (key, value) VALUES (?, ?)', (key, value))
|
|
|
+
|
|
|
+ def _get_meta(self, table: str, key: str) -> str:
|
|
|
+ with self._conn() as conn:
|
|
|
+ row = conn.execute(f'SELECT value FROM {table} WHERE key = ?', (key,)).fetchone()
|
|
|
+ return row[0] if row else None
|
|
|
+
|
|
|
+ def update_cli_meta(self, cli: str, path: str, sessions: int, messages: int, tokens: int = None):
|
|
|
+ table = f'meta_{cli}'
|
|
|
+ now = datetime.datetime.now().isoformat()
|
|
|
+ # 顺序: path, sessions, messages, total_tokens, last_sync
|
|
|
+ self._set_meta(table, 'path', path)
|
|
|
+ self._set_meta(table, 'sessions', str(sessions))
|
|
|
+ self._set_meta(table, 'messages', str(messages))
|
|
|
+ self._set_meta(table, 'total_tokens', str(tokens or 0))
|
|
|
+ self._set_meta(table, 'last_sync', now)
|
|
|
+
|
|
|
+ def update_total_meta(self, sessions: int, messages: int, tokens: int = None):
|
|
|
+ now = datetime.datetime.now().isoformat()
|
|
|
+ self._set_meta('meta', 'total_sessions', str(sessions))
|
|
|
+ self._set_meta('meta', 'total_messages', str(messages))
|
|
|
+ if tokens is not None:
|
|
|
+ self._set_meta('meta', 'total_tokens', str(tokens))
|
|
|
+ self._set_meta('meta', 'last_sync', now)
|
|
|
+
|
|
|
+ def get_total_meta(self) -> dict:
|
|
|
+ return {
|
|
|
+ 'schema_version': int(self._get_meta('meta', 'schema_version') or 0),
|
|
|
+ 'total_sessions': int(self._get_meta('meta', 'total_sessions') or 0),
|
|
|
+ 'total_messages': int(self._get_meta('meta', 'total_messages') or 0),
|
|
|
+ 'last_sync': self._get_meta('meta', 'last_sync'),
|
|
|
+ }
|
|
|
+
|
|
|
+ def get_file_mtime(self, file_path: str) -> int:
|
|
|
+ with self._conn() as conn:
|
|
|
+ row = conn.execute('SELECT file_mtime FROM sessions WHERE file_path = ?', (file_path,)).fetchone()
|
|
|
+ return row[0] if row else 0
|
|
|
+
|
|
|
+ def upsert_session(self, session_id: str, source: str, file_path: str,
|
|
|
+ cwd: str, messages: list, file_mtime: int, start_time: str = None):
|
|
|
+ if file_path and not file_path.startswith('claude:') and not os.path.isabs(file_path):
|
|
|
+ file_path = os.path.abspath(file_path)
|
|
|
+
|
|
|
+ total_tokens = sum(count_tokens(msg.get('content', '')) for msg in messages)
|
|
|
+ if not start_time and messages:
|
|
|
+ start_time = messages[0].get('time')
|
|
|
+
|
|
|
+ messages_json = json.dumps(messages, ensure_ascii=False)
|
|
|
+ with self._conn() as conn:
|
|
|
+ conn.execute('''
|
|
|
+ INSERT INTO sessions (file_path, session_id, source, cwd, messages, file_mtime, start_time, token_count)
|
|
|
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
+ ON CONFLICT(file_path) DO UPDATE SET
|
|
|
+ session_id=excluded.session_id, messages=excluded.messages,
|
|
|
+ file_mtime=excluded.file_mtime, start_time=excluded.start_time, token_count=excluded.token_count
|
|
|
+ ''', (file_path, session_id, source, cwd, messages_json, file_mtime, start_time, total_tokens))
|
|
|
+
|
|
|
+ def get_cli_stats(self, cli: str) -> dict:
|
|
|
+ with self._conn() as conn:
|
|
|
+ sessions = conn.execute('SELECT COUNT(*) FROM sessions WHERE source = ?', (cli,)).fetchone()[0]
|
|
|
+ row = conn.execute('SELECT SUM(json_array_length(messages)) FROM sessions WHERE source = ?', (cli,)).fetchone()
|
|
|
+ messages = row[0] or 0
|
|
|
+ tokens = conn.execute('SELECT SUM(token_count) FROM sessions WHERE source = ?', (cli,)).fetchone()[0] or 0
|
|
|
+ return {'sessions': sessions, 'messages': messages, 'tokens': tokens}
|
|
|
+
|
|
|
+ def get_total_stats(self) -> dict:
|
|
|
+ with self._conn() as conn:
|
|
|
+ sessions = conn.execute('SELECT COUNT(*) FROM sessions').fetchone()[0]
|
|
|
+ row = conn.execute('SELECT SUM(json_array_length(messages)) FROM sessions').fetchone()
|
|
|
+ messages = row[0] or 0
|
|
|
+ tokens = conn.execute('SELECT SUM(token_count) FROM sessions').fetchone()[0] or 0
|
|
|
+ return {'sessions': sessions, 'messages': messages, 'tokens': tokens}
|
|
|
+
|
|
|
+ def get_token_stats(self) -> dict:
|
|
|
+ with self._conn() as conn:
|
|
|
+ rows = conn.execute('SELECT source, SUM(token_count) FROM sessions GROUP BY source').fetchall()
|
|
|
+ return {r[0]: r[1] or 0 for r in rows}
|
|
|
+
|
|
|
+ # === 清理孤立记录 ===
|
|
|
+ def prune(self) -> dict:
|
|
|
+ """删除源文件已不存在的记录"""
|
|
|
+ removed = {'codex': 0, 'kiro': 0, 'gemini': 0, 'claude': 0}
|
|
|
+ with self._conn() as conn:
|
|
|
+ rows = conn.execute('SELECT file_path, source FROM sessions').fetchall()
|
|
|
+ for fp, source in rows:
|
|
|
+ if fp.startswith('claude:'):
|
|
|
+ continue # Claude 使用虚拟路径
|
|
|
+ if not os.path.exists(fp):
|
|
|
+ conn.execute('DELETE FROM sessions WHERE file_path = ?', (fp,))
|
|
|
+ removed[source] = removed.get(source, 0) + 1
|
|
|
+ return removed
|
|
|
+
|
|
|
+ # === 查询 ===
|
|
|
+ def search(self, keyword: str, source: str = None, limit: int = 50) -> list:
|
|
|
+ """搜索消息内容"""
|
|
|
+ sql = "SELECT file_path, session_id, source, cwd, messages, start_time FROM sessions WHERE messages LIKE ?"
|
|
|
+ params = [f'%{keyword}%']
|
|
|
+ if source:
|
|
|
+ sql += " AND source = ?"
|
|
|
+ params.append(source)
|
|
|
+ sql += f" ORDER BY start_time DESC LIMIT {limit}"
|
|
|
+
|
|
|
+ results = []
|
|
|
+ with self._conn() as conn:
|
|
|
+ for row in conn.execute(sql, params):
|
|
|
+ results.append({
|
|
|
+ 'file_path': row[0], 'session_id': row[1], 'source': row[2],
|
|
|
+ 'cwd': row[3], 'messages': json.loads(row[4]), 'start_time': row[5]
|
|
|
+ })
|
|
|
+ return results
|
|
|
+
|
|
|
+ def get_session(self, file_path: str) -> dict:
|
|
|
+ """获取单个会话"""
|
|
|
+ with self._conn() as conn:
|
|
|
+ row = conn.execute(
|
|
|
+ 'SELECT file_path, session_id, source, cwd, messages, start_time, token_count FROM sessions WHERE file_path = ?',
|
|
|
+ (file_path,)
|
|
|
+ ).fetchone()
|
|
|
+ if not row:
|
|
|
+ return None
|
|
|
+ return {
|
|
|
+ 'file_path': row[0], 'session_id': row[1], 'source': row[2], 'cwd': row[3],
|
|
|
+ 'messages': json.loads(row[4]), 'start_time': row[5], 'token_count': row[6]
|
|
|
+ }
|
|
|
+
|
|
|
+ def list_sessions(self, source: str = None, limit: int = 100, offset: int = 0) -> list:
|
|
|
+ """列出会话"""
|
|
|
+ sql = "SELECT file_path, session_id, source, cwd, start_time, token_count FROM sessions"
|
|
|
+ params = []
|
|
|
+ if source:
|
|
|
+ sql += " WHERE source = ?"
|
|
|
+ params.append(source)
|
|
|
+ sql += f" ORDER BY start_time DESC LIMIT {limit} OFFSET {offset}"
|
|
|
+
|
|
|
+ results = []
|
|
|
+ with self._conn() as conn:
|
|
|
+ for row in conn.execute(sql, params):
|
|
|
+ results.append({
|
|
|
+ 'file_path': row[0], 'session_id': row[1], 'source': row[2],
|
|
|
+ 'cwd': row[3], 'start_time': row[4], 'token_count': row[5]
|
|
|
+ })
|
|
|
+ return results
|
|
|
+
|
|
|
+ # === 导出 ===
|
|
|
+ def export_json(self, output_path: str, source: str = None):
|
|
|
+ """导出为 JSON"""
|
|
|
+ sql = "SELECT file_path, session_id, source, cwd, messages, start_time, token_count FROM sessions"
|
|
|
+ params = []
|
|
|
+ if source:
|
|
|
+ sql += " WHERE source = ?"
|
|
|
+ params.append(source)
|
|
|
+ sql += " ORDER BY start_time"
|
|
|
+
|
|
|
+ data = []
|
|
|
+ with self._conn() as conn:
|
|
|
+ for row in conn.execute(sql, params):
|
|
|
+ data.append({
|
|
|
+ 'file_path': row[0], 'session_id': row[1], 'source': row[2], 'cwd': row[3],
|
|
|
+ 'messages': json.loads(row[4]), 'start_time': row[5], 'token_count': row[6]
|
|
|
+ })
|
|
|
+
|
|
|
+ with open(output_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(data, f, ensure_ascii=False, indent=2)
|
|
|
+ return len(data)
|
|
|
+
|
|
|
+ def export_csv(self, output_path: str, source: str = None):
|
|
|
+ """导出为 CSV(扁平化消息)"""
|
|
|
+ import csv
|
|
|
+ sql = "SELECT session_id, source, cwd, messages, start_time FROM sessions"
|
|
|
+ params = []
|
|
|
+ if source:
|
|
|
+ sql += " WHERE source = ?"
|
|
|
+ params.append(source)
|
|
|
+ sql += " ORDER BY start_time"
|
|
|
+
|
|
|
+ count = 0
|
|
|
+ with open(output_path, 'w', encoding='utf-8', newline='') as f:
|
|
|
+ writer = csv.writer(f)
|
|
|
+ writer.writerow(['session_id', 'source', 'cwd', 'time', 'role', 'content'])
|
|
|
+ with self._conn() as conn:
|
|
|
+ for row in conn.execute(sql, params):
|
|
|
+ session_id, src, cwd, msgs_json, _ = row
|
|
|
+ for msg in json.loads(msgs_json):
|
|
|
+ writer.writerow([session_id, src, cwd, msg.get('time', ''), msg.get('role', ''), msg.get('content', '')])
|
|
|
+ count += 1
|
|
|
+ return count
|
|
|
+
|
|
|
+ # === 获取所有文件路径(用于 prune 检查) ===
|
|
|
+ def get_all_file_paths(self, source: str = None) -> set:
|
|
|
+ sql = "SELECT file_path FROM sessions"
|
|
|
+ params = []
|
|
|
+ if source:
|
|
|
+ sql += " WHERE source = ?"
|
|
|
+ params.append(source)
|
|
|
+ with self._conn() as conn:
|
|
|
+ return {row[0] for row in conn.execute(sql, params)}
|