LangChain 短期记忆
概述
短期记忆系统让 AI Agent 能够记住单个对话(线程)中先前的交互信息。这对于构建能够理解上下文、学习用户偏好并保持连贯对话的智能应用至关重要。
核心概念
- 线程(Thread):把多轮交互组织成一个会话,类似一封电子邮件的往来对话
- 检查点(Checkpointer):负责持久化存储状态,使同一线程的后续调用能恢复之前的状态
- 状态管理:Agent 通过状态来维护对话历史和自定义信息
基础设置
1. 启用短期记忆
from langchain.agents import create_agent
from langchain.tools import tool
from langgraph.checkpoint.memory import InMemorySaver


# Demo tool so the snippet runs on its own; replace it with a real lookup.
@tool
def get_user_info(user_id: str) -> str:
    """Look up basic information about a user."""
    return f"User {user_id}: no additional information on file."


# Create an agent with short-term memory: the checkpointer stores each
# thread's state so later calls on the same thread see earlier messages.
agent = create_agent(
    model="openai:gpt-4o",
    tools=[get_user_info],
    checkpointer=InMemorySaver(),  # in-memory checkpointer
)

# The thread_id selects which conversation (thread) this call belongs to.
result = agent.invoke(
    {"messages": [{"role": "user", "content": "Hi! My name is Bob."}]},
    {"configurable": {"thread_id": "1"}},
)
2. 生产环境配置
from langchain.agents import create_agent
from langgraph.checkpoint.postgres import PostgresSaver

# Requires: pip install langgraph-checkpoint-postgres
DB_URI = "postgresql://postgres:postgres@localhost:5442/postgres?sslmode=disable"

# from_conn_string() is used as a context manager: the connection it opens
# belongs to the `with` block, so create AND use the agent inside it.
with PostgresSaver.from_conn_string(DB_URI) as checkpointer:
    checkpointer.setup()  # create the checkpoint tables
    agent = create_agent(
        model="openai:gpt-4o",
        tools=[get_user_info],  # get_user_info: a tool you define yourself
        checkpointer=checkpointer,  # persist thread state in PostgreSQL
    )
    result = agent.invoke(
        {"messages": [{"role": "user", "content": "Hi! My name is Bob."}]},
        {"configurable": {"thread_id": "1"}},
    )
自定义 Agent 状态
扩展默认状态
from langchain.agents import create_agent, AgentState
from langgraph.checkpoint.memory import InMemorySaver
from typing import Optional, Dict, List


class CustomAgentState(AgentState):
    """Agent state extended with user-specific fields.

    AgentState is a TypedDict: at runtime the state is a plain dict, so a
    field cannot have a default value (``x: T = None`` is never applied to the
    state and type checkers reject it). Fields a caller may omit stay absent
    until written and should be read with ``state.get(...)``.
    """

    user_id: str
    preferences: Dict[str, str]
    conversation_topics: List[str]
    # May be absent until something writes it; read via state.get("last_active").
    last_active: Optional[str]


# Create an agent that uses the custom state schema
agent = create_agent(
    model="openai:gpt-4o",
    tools=[get_user_info],  # get_user_info: a tool you define yourself
    state_schema=CustomAgentState,
    checkpointer=InMemorySaver(),
)

# Pass the custom fields alongside the messages (last_active is left out)
result = agent.invoke(
    {
        "messages": [{"role": "user", "content": "Hello"}],
        "user_id": "user_123",
        "preferences": {"theme": "dark", "language": "zh-CN"},
        "conversation_topics": ["technology", "programming"]
    },
    {"configurable": {"thread_id": "1"}}
)
内存管理策略
1. 消息修剪(Trim Messages)
当对话历史过长时,修剪消息,使其适配模型的上下文窗口。
from typing import Any

from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model
from langchain.messages import RemoveMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime


@before_model
def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Trim the history to its system messages plus the most recent messages.

    Returns None (no change) while the history has at most 4 messages.
    Otherwise the whole history is replaced: RemoveMessage(REMOVE_ALL_MESSAGES)
    clears it and the kept messages are written back in their original order.
    """
    messages = state["messages"]
    keep_last = 3

    if len(messages) <= keep_last + 1:
        return None

    # Start of the kept window. A tool result must stay with the AI message
    # that requested it, so if the window would begin on a tool result
    # (e.g. after parallel tool calls), widen it backwards to that AI message.
    cut = len(messages) - keep_last
    while cut > 0 and messages[cut].type == "tool":
        cut -= 1

    # System messages inside the window are already kept; only collect the
    # ones being cut off so that none is duplicated.
    system_messages = [msg for msg in messages[:cut] if msg.type == "system"]

    return {
        "messages": [
            RemoveMessage(id=REMOVE_ALL_MESSAGES),  # clear the stored history
            *system_messages,
            *messages[cut:],
        ]
    }


# Agent that uses the trimming middleware
agent = create_agent(
    model="openai:gpt-4o",
    tools=[],
    middleware=[trim_messages],
    checkpointer=InMemorySaver(),
)
2. 消息删除(Delete Messages)
从检查点保存的状态中永久删除特定消息,以控制对话历史的长度。
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import after_model
from langchain.messages import RemoveMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.runtime import Runtime


@after_model
def delete_old_messages(state: AgentState, runtime: Runtime) -> dict | None:
    """Delete the oldest messages once the history exceeds five messages.

    Normally the two oldest messages are removed. If that would leave tool
    results at the head of the history without the AI message that requested
    them, those tool results are removed too.
    """
    messages = state["messages"]
    if len(messages) <= 5:
        return None

    cut = 2
    while cut < len(messages) and messages[cut].type == "tool":
        cut += 1
    return {"messages": [RemoveMessage(id=msg.id) for msg in messages[:cut]]}


# Agent that uses the deletion middleware
agent = create_agent(
    model="openai:gpt-4o",
    tools=[],
    middleware=[delete_old_messages],
    checkpointer=InMemorySaver(),
)
3. 消息总结(Summarize Messages)
使用总结中间件自动总结较早的对话历史。
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware
from langgraph.checkpoint.memory import InMemorySaver

checkpointer = InMemorySaver()
agent = create_agent(
    model="openai:gpt-4o",
    tools=[],
    middleware=[
        SummarizationMiddleware(
            model="openai:gpt-4o-mini",  # a cheaper model is enough for summarizing
            max_tokens_before_summary=2000,  # summarize once the history reaches ~2000 tokens
            messages_to_keep=10,  # keep the 10 most recent messages after summarizing
            # The middleware fills the {messages} placeholder with the messages
            # being summarized; a prompt without it would send the summary
            # model only the instruction and none of the conversation.
            summary_prompt="请总结之前的对话,保留关键信息:\n\n{messages}",
        )
    ],
    checkpointer=checkpointer,
)

# Several turns on the same thread
config = {"configurable": {"thread_id": "1"}}
agent.invoke({"messages": "Hi, my name is Bob"}, config)
agent.invoke({"messages": "I'm a software engineer from Beijing"}, config)
agent.invoke({"messages": "I enjoy hiking and reading books"}, config)
agent.invoke({"messages": "My favorite programming language is Python"}, config)

# After several turns the agent still knows what the user said
final_response = agent.invoke({"messages": "Can you remind me what I told you about myself?"}, config)
print(final_response["messages"][-1].content)
访问和操作内存
1. 在工具中访问内存
from langchain.agents import create_agent, AgentState
from langchain.messages import ToolMessage
from langchain.tools import tool, ToolRuntime
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command


class UserState(AgentState):
    """Agent state with a user profile and a counter bumped by update_user_preference."""

    user_profile: dict
    conversation_count: int


# Read-only tool: reads the current state through runtime.state.
@tool
def get_user_profile(runtime: ToolRuntime) -> str:
    """获取用户档案信息"""
    state = runtime.state
    user_profile = state.get("user_profile", {})
    conversation_count = state.get("conversation_count", 0)
    return f"用户档案: {user_profile}, 对话次数: {conversation_count}"


# Writing tool: updates the state by returning a Command instead of a string.
@tool
def update_user_preference(runtime: ToolRuntime, preference: str, value: str) -> Command:
    """更新用户偏好"""
    updated_profile = {**runtime.state.get("user_profile", {}), preference: value}
    return Command(update={
        "user_profile": updated_profile,
        "conversation_count": runtime.state.get("conversation_count", 0) + 1,
        # A Command returned from a tool must also carry the ToolMessage that
        # answers this tool call; otherwise the model's tool call has no result.
        "messages": [
            ToolMessage(
                f"已更新用户偏好: {preference} = {value}",
                tool_call_id=runtime.tool_call_id,
            )
        ],
    })


agent = create_agent(
    model="openai:gpt-4o",
    tools=[get_user_profile, update_user_preference],
    state_schema=UserState,
    checkpointer=InMemorySaver(),
)
2. 使用动态提示
from typing import TypedDict

from langchain.agents import create_agent
from langchain.agents.middleware import dynamic_prompt, ModelRequest


class ConversationContext(TypedDict):
    user_name: str
    user_role: str


@dynamic_prompt
def personalized_system_prompt(request: ModelRequest) -> str:
    """Build the system prompt from the per-invocation user context."""
    # runtime.context is None when invoke() is called without `context=`;
    # fall back to the defaults below instead of failing on None.get().
    context = request.runtime.context or {}
    user_name = context.get("user_name", "用户")
    user_role = context.get("user_role", "访客")
    return f"""
你是一个有帮助的助手,正在与{user_name}对话。
{user_name}的身份是:{user_role}
请根据对话历史提供个性化的回应。
保持友好和专业的态度。
"""


def get_weather(city: str) -> str:
    """获取天气信息"""
    return f"{city}的天气是晴朗的,25°C"


agent = create_agent(
    model="openai:gpt-4o",
    tools=[get_weather],
    middleware=[personalized_system_prompt],
    context_schema=ConversationContext,
)

# Invoke with a per-call context
result = agent.invoke(
    {"messages": [{"role": "user", "content": "今天天气怎么样?"}]},
    context=ConversationContext(user_name="张三", user_role="软件工程师")
)
3. Before Model 中间件
在模型调用前访问和修改状态。
from typing import Any

from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model
from langchain.messages import RemoveMessage, SystemMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime

# Fixed id of the context message, so each model call replaces it instead of
# adding another copy.
USER_CONTEXT_MESSAGE_ID = "user-context"


@before_model
def enhance_with_context(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Put one system message describing the current user at the start of the history.

    Returning ``[context_message] + messages`` would not prepend anything: the
    messages reducer merges existing messages by id, leaving them in place, and
    appends the new system message at the end. That adds a fresh copy on every
    model call, and the checkpointer persists each one. Instead the history is
    rebuilt explicitly, with a single context message first.
    """
    user_id = state.get("user_id", "unknown")
    preferences = state.get("preferences", {})

    context_message = SystemMessage(
        content=f"""
当前用户ID: {user_id}
用户偏好: {preferences}
请根据以上信息提供个性化服务。
""",
        id=USER_CONTEXT_MESSAGE_ID,
    )
    # Drop the context message written by a previous call before re-adding it.
    history = [msg for msg in state["messages"] if msg.id != USER_CONTEXT_MESSAGE_ID]
    return {
        "messages": [
            RemoveMessage(id=REMOVE_ALL_MESSAGES),
            context_message,
            *history,
        ]
    }


agent = create_agent(
    model="openai:gpt-4o",
    tools=[],
    middleware=[enhance_with_context],
    state_schema=CustomAgentState,
    checkpointer=InMemorySaver(),
)
4. After Model 中间件
在模型调用后处理响应和状态。
from datetime import datetime

from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import after_model
from langchain.messages import RemoveMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.runtime import Runtime


class MetricsAgentState(CustomAgentState):
    """CustomAgentState plus the counter maintained by track_conversation_metrics."""

    # CustomAgentState has no such key; declare it so the counter is part of
    # the state schema.
    conversation_count: int


def _message_text(msg) -> str:
    """Return a message's text whether its content is a str or a list of content blocks."""
    content = msg.content
    if isinstance(content, str):
        return content
    return " ".join(
        block.get("text", "") if isinstance(block, dict) else str(block)
        for block in content
    )


@after_model
def track_conversation_metrics(state: AgentState, runtime: Runtime) -> dict | None:
    """Update the counters and remove messages that contain sensitive keywords.

    conversation_count is incremented once per run of this hook, i.e. after
    each model call; last_active records the local time of that call.
    """
    conversation_count = state.get("conversation_count", 0) + 1
    last_active = datetime.now().isoformat()

    sensitive_keywords = ["密码", "secret", "password", "token"]
    messages_to_remove = [
        msg
        for msg in state["messages"]
        if any(keyword in _message_text(msg).lower() for keyword in sensitive_keywords)
    ]

    updates = {
        "conversation_count": conversation_count,
        "last_active": last_active
    }
    if messages_to_remove:
        updates["messages"] = [RemoveMessage(id=msg.id) for msg in messages_to_remove]
    return updates


agent = create_agent(
    model="openai:gpt-4o",
    tools=[],
    middleware=[track_conversation_metrics],
    state_schema=MetricsAgentState,
    checkpointer=InMemorySaver(),
)
实际应用场景
场景1:个性化客户服务
from datetime import datetime
from typing import Dict, List, Optional
class CustomerServiceState(AgentState):
    """Thread state for the customer-service agent.

    AgentState is a TypedDict, so fields cannot declare defaults; last_issue
    and satisfaction_score stay absent until written and are read with
    state.get(...).
    """

    customer_id: str
    ticket_history: List[Dict]
    customer_tier: str  # "standard", "premium", "vip"
    last_issue: Optional[str]
    satisfaction_score: Optional[int]
def create_customer_service_agent():
    """Create the customer-service agent.

    Tools read and update the ticket data kept in the thread state; a
    before_model hook keeps one customer-context system message at the start
    of the history.
    """
    from langchain.messages import RemoveMessage, ToolMessage
    from langgraph.graph.message import REMOVE_ALL_MESSAGES
    from langgraph.types import Command

    # Fixed id of the context message, so each model call replaces it
    # instead of adding another copy.
    context_message_id = "customer-context"

    @tool
    def create_support_ticket(runtime: ToolRuntime, issue: str, priority: str) -> Command:
        """创建支持工单"""
        now = datetime.now()  # one timestamp for both the id and created_at
        ticket = {
            "id": f"TICKET_{now.strftime('%Y%m%d_%H%M%S')}",
            "issue": issue,
            "priority": priority,
            "created_at": now.isoformat(),
            "status": "open"
        }
        return Command(update={
            "ticket_history": runtime.state.get("ticket_history", []) + [ticket],
            "last_issue": issue,
            # A Command returned from a tool must answer the tool call with a ToolMessage.
            "messages": [
                ToolMessage(f"已创建工单 {ticket['id']}", tool_call_id=runtime.tool_call_id)
            ],
        })

    @tool
    def get_customer_history(runtime: ToolRuntime) -> str:
        """获取客户历史记录"""
        state = runtime.state
        ticket_history = state.get("ticket_history", [])
        customer_tier = state.get("customer_tier", "standard")
        if not ticket_history:
            return "这是该客户的第一次联系"
        last_ticket = ticket_history[-1]
        return f"""
客户等级: {customer_tier}
总工单数: {len(ticket_history)}
最近问题: {last_ticket['issue']}
最近工单状态: {last_ticket['status']}
"""

    @before_model
    def add_customer_context(state: CustomerServiceState, runtime: Runtime) -> dict | None:
        """Put a refreshed customer-context system message at the start of the history."""
        customer_tier = state.get("customer_tier", "standard")
        ticket_count = len(state.get("ticket_history", []))
        tier_benefits = {
            "standard": "标准支持(24小时内响应)",
            "premium": "优先支持(4小时内响应)",
            "vip": "专属支持(1小时内响应)"
        }
        context_msg = f"""
当前客户等级: {customer_tier}
支持级别: {tier_benefits.get(customer_tier, '标准支持')}
历史工单数量: {ticket_count}
"""
        if state.get("last_issue"):
            context_msg += f"\n最近报告的问题: {state['last_issue']}"
        # Returning [SystemMessage(...)] + messages would not prepend: the
        # messages reducer merges existing messages by id and appends the new
        # one at the end, adding a copy per model call. Rebuild the history
        # with exactly one context message first instead.
        history = [msg for msg in state["messages"] if msg.id != context_message_id]
        return {
            "messages": [
                RemoveMessage(id=REMOVE_ALL_MESSAGES),
                SystemMessage(content=context_msg, id=context_message_id),
                *history,
            ]
        }

    return create_agent(
        model="openai:gpt-4o",
        tools=[create_support_ticket, get_customer_history],
        middleware=[add_customer_context],
        state_schema=CustomerServiceState,
        checkpointer=InMemorySaver(),
    )
# Usage example
service_agent = create_customer_service_agent()
customer_thread = {"configurable": {"thread_id": "cust_123"}}

# First turn: pass the customer fields together with the user's message.
result1 = service_agent.invoke(
    {
        "messages": [{"role": "user", "content": "我的账户无法登录"}],
        "customer_id": "cust_123",
        "customer_tier": "premium",
        "ticket_history": [],
    },
    customer_thread,
)

# Follow-up turn on the same thread: only the new message is sent; the
# checkpointer restores the fields saved by the first turn.
result2 = service_agent.invoke(
    {"messages": [{"role": "user", "content": "查看我的支持历史"}]},
    customer_thread,
)
场景2:智能学习助手
class LearningAssistantState(AgentState):
    """Thread state for the learning-assistant agent."""

    # Proficiency level: "beginner", "intermediate" or "advanced".
    student_level: str
    learning_topics: list[str]
    completed_lessons: list[dict]
    weak_areas: list[str]
    # Preferred learning style: "visual", "auditory" or "kinesthetic".
    learning_style: str
def create_learning_assistant():
    """Create the learning-assistant agent.

    Tools record lesson scores and derive study advice from the thread state;
    a dynamic prompt adapts the system prompt to the student's level and style.
    """
    from langchain.messages import ToolMessage
    from langgraph.types import Command

    passing_score = 70  # a score below this marks the topic as a weak area

    @tool
    def track_progress(runtime: ToolRuntime, topic: str, score: int) -> Command:
        """跟踪学习进度"""
        lesson = {
            "topic": topic,
            "score": score,
            "completed_at": datetime.now().isoformat()
        }
        completed_lessons = runtime.state.get("completed_lessons", []) + [lesson]
        # Carry the existing weak areas forward so a pass on one topic does not
        # clear the others. A topic is weak while its latest score is below
        # passing_score: add it on a fail, drop it on a pass.
        weak_areas = [area for area in runtime.state.get("weak_areas", []) if area != topic]
        if score < passing_score:
            weak_areas.append(topic)
        return Command(update={
            "completed_lessons": completed_lessons,
            "weak_areas": weak_areas,
            # A Command returned from a tool must answer the tool call with a ToolMessage.
            "messages": [
                ToolMessage(f"已记录学习进度: {topic} 得分 {score}", tool_call_id=runtime.tool_call_id)
            ],
        })

    @tool
    def get_study_recommendations(runtime: ToolRuntime) -> str:
        """获取学习建议"""
        state = runtime.state
        weak_areas = state.get("weak_areas", [])
        learning_style = state.get("learning_style", "visual")
        student_level = state.get("student_level", "beginner")
        recommendations = []
        if weak_areas:
            recommendations.append(f"需要加强的领域: {', '.join(weak_areas)}")
        style_suggestions = {
            "visual": "建议使用图表和视频学习",
            "auditory": "建议收听讲解和参与讨论",
            "kinesthetic": "建议通过实践练习学习"
        }
        recommendations.append(style_suggestions.get(learning_style, "多种方式结合学习"))
        recommendations.append(f"适合{student_level}水平的学习材料")
        return "\n".join(recommendations)

    @dynamic_prompt
    def personalized_learning_prompt(request: ModelRequest) -> str:
        """Build a system prompt from the student's state and the per-call context."""
        state = request.state
        # runtime.context is None when invoke() is called without `context=`.
        context = request.runtime.context or {}
        student_name = context.get("student_name", "同学")
        learning_style = state.get("learning_style", "visual")
        student_level = state.get("student_level", "beginner")
        return f"""
你是一个耐心的学习助手,正在帮助{student_name}学习。
学生信息:
- 学习风格: {learning_style}
- 当前水平: {student_level}
- 已完成课程: {len(state.get('completed_lessons', []))}个
请根据学生的学习风格和水平提供个性化的指导。
对于{learning_style}型学习者,使用适合的教学方法。
"""

    return create_agent(
        model="openai:gpt-4o",
        tools=[track_progress, get_study_recommendations],
        middleware=[personalized_learning_prompt],
        state_schema=LearningAssistantState,
        checkpointer=InMemorySaver(),
    )
# Usage example
learning_agent = create_learning_assistant()

# Seed the learner's profile on the first turn of the thread.
initial_state = {
    "messages": [{"role": "user", "content": "我想学习Python编程"}],
    "student_level": "beginner",
    "learning_style": "visual",
    "learning_topics": ["Python", "编程基础"],
    "completed_lessons": [],
    "weak_areas": [],
}
learning_agent.invoke(
    initial_state,
    {"configurable": {"thread_id": "student_123"}},
    context={"student_name": "小明"},
)
场景3:电商推荐系统
class EcommerceState(AgentState):
    """Thread state for the e-commerce recommendation agent."""

    user_id: str
    # Behaviour records, one dict per entry.
    browse_history: list[dict]
    purchase_history: list[dict]
    interests: list[str]
    budget_range: str
    preferred_categories: list[str]
def create_ecommerce_agent():
    """Create the e-commerce recommendation agent.

    Tools record browsing behaviour into the thread state and build
    recommendations from it; a before_model hook keeps one shopping-context
    system message at the start of the history.
    """
    from langchain.messages import RemoveMessage, ToolMessage
    from langgraph.graph.message import REMOVE_ALL_MESSAGES
    from langgraph.types import Command

    # Fixed id of the context message, so each model call replaces it
    # instead of adding another copy.
    context_message_id = "shopping-context"

    @tool
    def track_browse_behavior(runtime: ToolRuntime, product: str, category: str) -> Command:
        """跟踪浏览行为"""
        browse_record = {
            "product": product,
            "category": category,
            "timestamp": datetime.now().isoformat()
        }
        browse_history = runtime.state.get("browse_history", []) + [browse_record]
        # dict.fromkeys de-duplicates while keeping first-seen order; set() would
        # order the interests arbitrarily, changing which ones interests[:3] picks.
        interests = list(dict.fromkeys(runtime.state.get("interests", []) + [category]))
        return Command(update={
            "browse_history": browse_history,
            "interests": interests,
            # A Command returned from a tool must answer the tool call with a ToolMessage.
            "messages": [
                ToolMessage(f"已记录浏览: {product} ({category})", tool_call_id=runtime.tool_call_id)
            ],
        })

    @tool
    def get_personalized_recommendations(runtime: ToolRuntime) -> str:
        """获取个性化推荐"""
        state = runtime.state
        interests = state.get("interests", [])
        budget_range = state.get("budget_range", "medium")
        recommendations = []
        if interests:
            recommendations.append(f"基于您的兴趣推荐: {', '.join(interests[:3])} 相关商品")
        budget_map = {
            "low": "经济实惠型",
            "medium": "性价比型",
            "high": "高端品质型"
        }
        recommendations.append(f"符合您{budget_map.get(budget_range, '中等')}预算的商品")
        return "\n".join(recommendations)

    @before_model
    def enhance_with_shopping_context(state: EcommerceState, runtime: Runtime) -> dict | None:
        """Put a refreshed shopping-context system message at the start of the history."""
        interests = state.get("interests", [])
        purchase_count = len(state.get("purchase_history", []))
        browse_count = len(state.get("browse_history", []))
        context_msg = f"""
购物助手上下文:
- 用户兴趣: {', '.join(interests) if interests else '尚未确定'}
- 浏览历史: {browse_count} 次
- 购买记录: {purchase_count} 次
- 预算范围: {state.get('budget_range', '未设置')}
"""
        # Returning [SystemMessage(...)] + messages would append a new copy at
        # the end on every model call (the messages reducer merges by id), so
        # rebuild the history with exactly one context message first.
        history = [msg for msg in state["messages"] if msg.id != context_message_id]
        return {
            "messages": [
                RemoveMessage(id=REMOVE_ALL_MESSAGES),
                SystemMessage(content=context_msg, id=context_message_id),
                *history,
            ]
        }

    return create_agent(
        model="openai:gpt-4o",
        tools=[track_browse_behavior, get_personalized_recommendations],
        middleware=[enhance_with_shopping_context],
        state_schema=EcommerceState,
        checkpointer=InMemorySaver(),
    )
最佳实践
1. 状态设计原则
class WellDesignedState(AgentState):
    """Example of a well-structured custom state.

    AgentState is a TypedDict: at runtime the state is a plain dict. Fields
    therefore cannot have default values; a mutable default like {} or [] is
    never applied. The class also cannot usefully carry methods that use
    attribute access (state.tokens_used fails on a dict). Read optional fields
    with state.get(...) and keep logic in plain functions such as
    should_summarize() below.
    """

    # Required core field
    user_id: str
    # Session fields
    session_start: str
    interaction_count: int
    # Business fields
    user_preferences: Dict[str, Any]
    recent_actions: List[str]
    # Performance-related fields
    last_summary: Optional[str]
    tokens_used: int


def should_summarize(state: dict) -> bool:
    """Return True when the state has more than 10 interactions or more than 3000 tokens used."""
    return state.get("interaction_count", 0) > 10 or state.get("tokens_used", 0) > 3000
2. 内存管理策略
def create_memory_optimized_agent():
    """Create an agent that keeps its message history bounded and tracks usage."""
    from langchain.messages import RemoveMessage
    from langgraph.graph.message import REMOVE_ALL_MESSAGES

    def estimate_tokens(messages) -> int:
        """Roughly estimate tokens as characters // 4; non-string content is stringified."""
        text = " ".join(
            msg.content if isinstance(msg.content, str) else str(msg.content)
            for msg in messages
        )
        return len(text) // 4

    def keep_last(messages, count: int) -> dict:
        """Replace the history with its last `count` messages.

        The window is widened backwards while it would start on a tool result,
        so no tool result is kept without the AI message that requested it.
        """
        cut = max(len(messages) - count, 0)
        while cut > 0 and messages[cut].type == "tool":
            cut -= 1
        # A plain list under "messages" would be merged into the existing
        # history by the messages reducer and trim nothing; clear it first.
        return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *messages[cut:]]}

    @before_model
    def smart_memory_management(state: AgentState, runtime: Runtime) -> dict | None:
        """Trim the history when it has too many messages or too many estimated tokens."""
        messages = state["messages"]
        if len(messages) > 20:
            return keep_last(messages, 10)
        # Decide on the size of the CURRENT history: the cumulative tokens_used
        # counter never decreases, so using it here would trim on every call
        # once it crossed the limit.
        if estimate_tokens(messages) > 4000:
            return keep_last(messages, 8)
        return None

    @after_model
    def update_usage_metrics(state: AgentState, runtime: Runtime) -> dict | None:
        """Count model calls and add the estimated size of the history sent this call."""
        return {
            "interaction_count": state.get("interaction_count", 0) + 1,
            "tokens_used": state.get("tokens_used", 0) + estimate_tokens(state["messages"]),
        }

    return create_agent(
        model="openai:gpt-4o",
        tools=[],
        middleware=[smart_memory_management, update_usage_metrics],
        state_schema=WellDesignedState,
        checkpointer=InMemorySaver(),
    )
3. 错误处理和恢复
def create_robust_agent():
    """Create an agent whose history is capped after each model call."""
    import logging

    from langchain.messages import RemoveMessage
    from langgraph.graph.message import REMOVE_ALL_MESSAGES

    logger = logging.getLogger(__name__)

    @after_model
    def handle_memory_errors(state: AgentState, runtime: Runtime) -> dict | None:
        """Cut the history back to the most recent ~20 messages once it exceeds 100.

        Best effort: if anything goes wrong, the error is logged with its
        traceback and the state is left unchanged instead of failing the run.
        """
        try:
            messages = state["messages"]
            if len(messages) <= 100:
                return None
            # Keep the last 20, widened so the window does not start on a tool
            # result separated from the AI message that requested it.
            cut = len(messages) - 20
            while cut > 0 and messages[cut].type == "tool":
                cut -= 1
            # A plain list under "messages" would be merged into the history by
            # the messages reducer and trim nothing; clear the history first.
            return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *messages[cut:]]}
        except Exception:
            logger.exception("内存处理错误")
            return None

    return create_agent(
        model="openai:gpt-4o",
        tools=[],
        middleware=[handle_memory_errors],
        checkpointer=InMemorySaver(),
    )
总结
LangChain 的短期记忆系统提供了强大的对话状态管理能力:
- 灵活的状态设计:支持自定义状态字段
- 多种存储后端:内存、PostgreSQL 等
- 智能内存管理:修剪、删除、总结等策略
- 全方位访问:通过工具、中间件等访问和修改状态
- 生产级可靠性:错误处理和性能优化
通过合理使用短期记忆,可以构建出真正理解用户上下文、提供个性化体验的智能应用。