第03章 外部存储:把记忆持久化到数据库

第03章 外部存储:把记忆持久化到数据库

“进程内的记忆,进程死了就消失了。真正的记忆,需要超越进程的生命周期。” —— 分布式系统工程师


上章解决了单次对话内的记忆问题。但对话结束后,所有上下文依然会丢失。本章把记忆持久化到外部存储,让Agent真正"记住"用户。


3.1 持久化存储的架构设计

# 记忆持久化的两层架构

MEMORY_STORAGE_ARCHITECTURE = """
                    [Agent]
                      ↕
            ┌─────────────────┐
            │   记忆管理器     │
            │ (Memory Manager) │
            └────────┬────────┘
                     │
        ┌────────────┴────────────┐
        ↓                         ↓
  [Redis缓存]              [PostgreSQL主存储]
  热点记忆                  持久化对话历史
  最近5次对话                所有历史对话
  毫秒级访问                 全量搜索能力
  TTL=24小时                 永久保留

注:第04章还会增加 [pgvector语义搜索层]
"""

# 数据库Schema设计
CREATE_TABLES_SQL = """
-- 用户会话表
CREATE TABLE agent_sessions (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    user_id TEXT NOT NULL,
    session_id TEXT NOT NULL UNIQUE,
    started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    ended_at TIMESTAMPTZ,
    message_count INTEGER DEFAULT 0,
    metadata JSONB DEFAULT '{}'
);

-- 消息记录表
CREATE TABLE agent_messages (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    session_id TEXT NOT NULL REFERENCES agent_sessions(session_id),
    user_id TEXT NOT NULL,
    role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system', 'tool')),
    content TEXT NOT NULL,
    tokens INTEGER,
    is_pinned BOOLEAN DEFAULT FALSE,
    metadata JSONB DEFAULT '{}',
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- 索引
CREATE INDEX idx_messages_session ON agent_messages(session_id);
CREATE INDEX idx_messages_user ON agent_messages(user_id);
CREATE INDEX idx_messages_created ON agent_messages(created_at DESC);
CREATE INDEX idx_sessions_user ON agent_sessions(user_id);
"""

3.2 PostgreSQL持久化实现

# 使用SQLAlchemy 2 async实现记忆持久化

import asyncio
import uuid
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, Text, Boolean, Integer, JSON, DateTime, select, and_

# ===== 数据模型 =====

class Base(DeclarativeBase):
    pass

class AgentSession(Base):
    __tablename__ = "agent_sessions"
    
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id: Mapped[str] = mapped_column(String, nullable=False, index=True)
    session_id: Mapped[str] = mapped_column(String, unique=True, nullable=False)
    started_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    ended_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    message_count: Mapped[int] = mapped_column(Integer, default=0)
    metadata_: Mapped[dict] = mapped_column("metadata", JSON, default=dict)

class AgentMessage(Base):
    __tablename__ = "agent_messages"
    
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    session_id: Mapped[str] = mapped_column(String, nullable=False, index=True)
    user_id: Mapped[str] = mapped_column(String, nullable=False, index=True)
    role: Mapped[str] = mapped_column(String, nullable=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    tokens: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    is_pinned: Mapped[bool] = mapped_column(Boolean, default=False)
    metadata_: Mapped[dict] = mapped_column("metadata", JSON, default=dict)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), index=True)

# ===== 存储服务 =====

class ConversationStore:
    """对话历史持久化存储"""
    
    def __init__(self, db_url: str):
        self.engine = create_async_engine(db_url, echo=False)
        self.session_factory = async_sessionmaker(self.engine, expire_on_commit=False)
    
    async def create_session(self, user_id: str, session_id: str) -> AgentSession:
        async with self.session_factory() as db:
            session = AgentSession(
                user_id=user_id,
                session_id=session_id,
            )
            db.add(session)
            await db.commit()
            await db.refresh(session)
            return session
    
    async def save_message(
        self,
        session_id: str,
        user_id: str,
        role: str,
        content: str,
        tokens: int = None,
        is_pinned: bool = False,
        metadata: dict = None,
    ) -> AgentMessage:
        """保存单条消息到数据库"""
        async with self.session_factory() as db:
            msg = AgentMessage(
                session_id=session_id,
                user_id=user_id,
                role=role,
                content=content,
                tokens=tokens,
                is_pinned=is_pinned,
                metadata_=metadata or {},
            )
            db.add(msg)
            
            # 更新session的消息计数
            result = await db.execute(
                select(AgentSession).where(AgentSession.session_id == session_id)
            )
            session = result.scalar_one_or_none()
            if session:
                session.message_count += 1
            
            await db.commit()
            await db.refresh(msg)
            return msg
    
    async def load_session_messages(
        self,
        session_id: str,
        limit: int = 50,
    ) -> list[AgentMessage]:
        """加载某次会话的消息历史"""
        async with self.session_factory() as db:
            result = await db.execute(
                select(AgentMessage)
                .where(AgentMessage.session_id == session_id)
                .order_by(AgentMessage.created_at.asc())
                .limit(limit)
            )
            return list(result.scalars().all())
    
    async def load_recent_sessions(
        self,
        user_id: str,
        n_sessions: int = 5,
        messages_per_session: int = 10,
    ) -> list[dict]:
        """加载用户最近N次会话的摘要"""
        async with self.session_factory() as db:
            # 获取最近的会话
            sessions_result = await db.execute(
                select(AgentSession)
                .where(AgentSession.user_id == user_id)
                .order_by(AgentSession.started_at.desc())
                .limit(n_sessions)
            )
            sessions = list(sessions_result.scalars().all())
            
            result = []
            for session in sessions:
                # 每次只获取消息(不并发:同一AsyncSession不能并发)
                msgs_result = await db.execute(
                    select(AgentMessage)
                    .where(AgentMessage.session_id == session.session_id)
                    .order_by(AgentMessage.created_at.asc())
                    .limit(messages_per_session)
                )
                messages = list(msgs_result.scalars().all())
                
                result.append({
                    "session_id": session.session_id,
                    "started_at": session.started_at.isoformat(),
                    "message_count": session.message_count,
                    "messages": [
                        {"role": m.role, "content": m.content[:200]}
                        for m in messages
                    ]
                })
            
            return result

3.3 Redis热缓存:快速加载最近对话

# Redis缓存:把最近的对话缓存到内存,避免每次都查数据库

import redis.asyncio as redis
import json

class RedisMemoryCache:
    """
    Redis热缓存
    - 存储用户最近5次会话(快速加载)
    - TTL 24小时(只需要最近的上下文)
    - 序列化为JSON
    """
    
    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        self.redis = redis.from_url(redis_url, decode_responses=True)
        self.ttl = 24 * 3600  # 24小时
    
    def _session_key(self, session_id: str) -> str:
        return f"agent:session:{session_id}"
    
    def _user_sessions_key(self, user_id: str) -> str:
        return f"agent:user:{user_id}:sessions"
    
    async def cache_message(self, session_id: str, user_id: str, role: str, content: str):
        """缓存一条新消息(追加到session的消息列表)"""
        key = self._session_key(session_id)
        
        message = json.dumps({"role": role, "content": content}, ensure_ascii=False)
        
        # 追加到消息列表
        await self.redis.rpush(key, message)
        # 设置TTL
        await self.redis.expire(key, self.ttl)
        
        # 记录用户的最新会话
        await self.redis.zadd(
            self._user_sessions_key(user_id),
            {session_id: __import__("time").time()}
        )
        await self.redis.expire(self._user_sessions_key(user_id), self.ttl)
    
    async def get_session_messages(self, session_id: str) -> list[dict]:
        """从缓存获取会话消息"""
        key = self._session_key(session_id)
        raw_messages = await self.redis.lrange(key, 0, -1)
        return [json.loads(m) for m in raw_messages]
    
    async def get_recent_sessions(self, user_id: str, n: int = 5) -> list[str]:
        """获取用户最近N个会话ID"""
        sessions = await self.redis.zrevrange(
            self._user_sessions_key(user_id), 0, n - 1
        )
        return sessions

# ===== 统一的记忆管理器:整合Redis和PostgreSQL =====

class MemoryManager:
    """
    统一的记忆管理器
    读:优先Redis(快),miss时从PostgreSQL(全量)
    写:同时写Redis和PostgreSQL
    """
    
    def __init__(self, store: ConversationStore, cache: RedisMemoryCache):
        self.store = store
        self.cache = cache
    
    async def save_message(self, session_id: str, user_id: str, role: str, content: str):
        """写入消息(双写)"""
        # 同时写入缓存和持久化存储
        await self.cache.cache_message(session_id, user_id, role, content)
        await self.store.save_message(session_id, user_id, role, content)
    
    async def load_session(self, session_id: str) -> list[dict]:
        """加载会话(缓存优先)"""
        # 先尝试缓存
        cached = await self.cache.get_session_messages(session_id)
        if cached:
            return cached
        
        # 缓存miss,从数据库加载
        messages = await self.store.load_session_messages(session_id)
        
        # 回填缓存
        for msg in messages:
            await self.cache.cache_message(session_id, msg.user_id, msg.role, msg.content)
        
        return [{"role": m.role, "content": m.content} for m in messages]

本章小结

五个核心认知:

  1. 持久化是跨会话记忆的前提:进程内的记忆不可靠;PostgreSQL给你持久性和SQL查询能力,Redis给你速度;两者组合是最实用的选择

  2. 双写策略:缓存+持久化同时写:每次存消息时同时写Redis和PostgreSQL;读时优先走Redis(毫秒级),miss时从PostgreSQL加载并回填缓存

  3. 数据Schema设计要考虑查询模式:按user_id和session_id建索引;message表的created_at DESC索引让"加载最近消息"非常快;不要过度设计,先用简单的Schema

  4. 同一个AsyncSession不能并发查询:(这是Market Vault架构的关键规则)不要在一个数据库session里用asyncio.gather并发查询;要顺序执行

  5. 缓存的TTL要根据业务设计:对话上下文24小时TTL是合理的起点;用户画像可能需要更长(7天或永久);频繁变化的数据TTL要短

核心行动

# 实施步骤:
# 1. 在你的数据库里创建 agent_sessions 和 agent_messages 表
# 2. 实例化 ConversationStore 和 RedisMemoryCache
# 3. 用 MemoryManager 替换你现有的消息列表
# 4. 在Agent循环中,每次用户发消息后调用 memory_manager.save_message()
# 5. 对话开始时,调用 load_session() 恢复上下文

本章提示词模板

模板一:设计记忆存储Schema

我需要为以下Agent设计记忆存储的数据库Schema:

Agent描述:[你的Agent是什么,做什么任务]
用户规模:[预计用户数量]
记忆类型:[对话历史/用户偏好/任务状态等]
查询模式:[最常见的查询是什么]

请帮我:
1. 设计核心表结构(建表SQL)
2. 建议必要的索引(考虑查询性能)
3. JSONB字段应该存什么?
4. 数据保留策略(多久的数据要保留?)
5. 预估存储需求(每天X用户X消息,一年需要多少存储)

模板二:分析记忆查询模式

我的Agent记忆系统需要支持以下查询:

1. 加载用户最近一次对话
2. 搜索历史中提到"[关键词]"的对话
3. 获取用户30天内的所有交互摘要
4. 查找特定日期范围的对话

请帮我:
1. 为每个查询设计最优的SQL
2. 哪些查询适合走Redis缓存(哪些不适合)?
3. 需要哪些索引?
4. 哪些查询需要考虑分页?
5. 高并发下(1000个用户同时使用)这些查询会有问题吗?

→ 第04章:向量数据库:用语义检索长期记忆