This notebook demonstrates how to manage an agent's short-term and long-term memory using LangGraph and Redis. We'll build two versions of a travel agent: one that manages long-term memory manually, and one that manages long-term memory with tools the LLM calls.
Here are two diagrams showing the components used in the two agents:
pip install -q langchain-openai langgraph-checkpoint langgraph-checkpoint-redis "langchain-community>=0.2.11" tavily-python langchain-redis pydantic ulid
You'll need to add an OpenAI API key with billing set up. You'll also need a Tavily API key; at the time of writing, Tavily API keys come with free credits.
# NBVAL_SKIP
import getpass
import os
def _set_env(key: str):
if key not in os.environ:
os.environ[key] = getpass.getpass(f"{key}:")
_set_env("OPENAI_API_KEY")
# Uncomment this if you have a Tavily API key and want to
# use the web search tool.
# _set_env("TAVILY_API_KEY")
Convert the following cell to a code cell to run it in Colab.
# Exit if this is not running in Colab
if [ -z "$COLAB_RELEASE_TAG" ]; then
exit 0
fi
curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update > /dev/null 2>&1
sudo apt-get install redis-stack-server > /dev/null 2>&1
redis-stack-server --daemonize yes
There are several ways to run the required redis-stack instance.
Using Docker: docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest
import os
from redis import Redis
# Use the environment variable if set, otherwise default to localhost
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
redis_client = Redis.from_url(REDIS_URL)
redis_client.ping()
The agent uses both short-term and long-term memory. The two are implemented differently, and the agent uses them differently. Let's dig into the details; we'll get back to the code shortly.
For short-term memory, the agent tracks conversation history with Redis. Because this is a LangGraph agent, we use the RedisSaver class to achieve this. RedisSaver is what LangGraph calls a checkpointer. You can read more about checkpointers in the LangGraph documentation. In short, they store state for every node in the graph, which for this agent includes the conversation history.
Here's a diagram showing how the agent uses Redis for short-term memory. Every node in the graph (retrieve memories, respond, summarize conversation) persists its "state" to Redis. The state object contains the agent's conversation history of messages for the current thread.
If Redis persistence is enabled, Redis persists short-term memory to disk. That means if you quit the agent and come back with the same thread ID and user ID, you'll resume the same conversation.
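For example, here's a minimal sketch (with hypothetical IDs) of what resuming looks like once the graph is compiled later in this notebook: passing the same thread_id and user_id in the config makes the checkpointer load the saved conversation history from Redis.
# A minimal sketch with hypothetical IDs. `graph` and `HumanMessage` come from
# cells later in this notebook; reusing the same config resumes the conversation.
resume_config = {"configurable": {"thread_id": "demo_thread", "user_id": "demo_user"}}
# First session:
# graph.invoke({"messages": [HumanMessage(content="I prefer window seats.")]}, config=resume_config)
# A later session with the same config picks up where the first one left off:
# graph.invoke({"messages": [HumanMessage(content="Book that flight to Tokyo.")]}, config=resume_config)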
Conversation histories can grow long and pollute the LLM's context window. To manage this, after every "turn" of a conversation, the agent summarizes the messages once the conversation exceeds a configurable threshold. Checkpointers don't do this by default, so we create a node in the graph dedicated to summarization.
NOTE: We'll see example code for the summarization node later in this notebook.
Aside from conversation history, the agent also stores long-term memories in a search index in Redis, using RedisVL. Here's a diagram showing the components involved:
The agent tracks two types of long-term memories: episodic (user-specific experiences and preferences) and semantic (general knowledge beyond the user's preferences and the LLM's training data).
NOTE: If you're familiar with the CoALA paper, the terms "episodic memory" and "semantic memory" here match the same concepts from the paper. CoALA also discusses a third type of memory, procedural memory. In our example, we consider the logic encoded in Python in the agent's codebase to be its procedural memory.
We use a few Pydantic models to represent long-term memories, both before and after they're stored in Redis:
from datetime import datetime
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field
import ulid
class MemoryType(str, Enum):
"""
The type of a long-term memory.
EPISODIC: User specific experiences and preferences
SEMANTIC: General knowledge on top of the user's preferences and LLM's
training data.
"""
EPISODIC = "episodic"
SEMANTIC = "semantic"
class Memory(BaseModel):
"""Represents a single long-term memory."""
content: str
memory_type: MemoryType
metadata: str
class Memories(BaseModel):
"""
A list of memories extracted from a conversation by an LLM.
NOTE: OpenAI's structured output requires us to wrap the list in an object.
"""
memories: List[Memory]
class StoredMemory(Memory):
"""A stored long-term memory"""
id: str # The redis key
memory_id: ulid.ULID = Field(default_factory=lambda: ulid.ULID())
created_at: datetime = Field(default_factory=datetime.now)
user_id: Optional[str] = None
thread_id: Optional[str] = None
memory_type: Optional[MemoryType] = None
class MemoryStrategy(str, Enum):
"""
Supported strategies for managing long-term memory.
This notebook supports two strategies for working with long-term memory:
TOOLS: The LLM decides when to store and retrieve long-term memories, using
tools (AKA, function-calling) to do so.
MANUAL: The agent manually retrieves long-term memories relevant to the
current conversation before sending every message and analyzes every
response to extract memories to store.
NOTE: In both cases, the agent runs a background thread to consolidate
memories, and a workflow step to summarize conversations after the history
grows past a threshold.
"""
TOOLS = "tools"
MANUAL = "manual"
# By default, we'll use the manual strategy
memory_strategy = MemoryStrategy.MANUAL
We'll get back to these models shortly to see how they work.
The RedisSaver class handles the groundwork of short-term memory storage for us, so there's nothing we need to do here.
We use RedisVL to store and retrieve long-term memories with vector embeddings. This enables semantic search over past experiences and knowledge.
Let's set up a new search index to store and query memories:
from redisvl.index import SearchIndex
from redisvl.schema.schema import IndexSchema
# Define schema for long-term memory index
memory_schema = IndexSchema.from_dict({
"index": {
"name": "agent_memories",
"prefix": "memory:",
"key_separator": ":",
"storage_type": "json",
},
"fields": [
{"name": "content", "type": "text"},
{"name": "memory_type", "type": "tag"},
{"name": "metadata", "type": "text"},
{"name": "created_at", "type": "text"},
{"name": "user_id", "type": "tag"},
{"name": "memory_id", "type": "tag"},
{
"name": "embedding",
"type": "vector",
"attrs": {
"algorithm": "flat",
"dims": 1536, # OpenAI embedding dimension
"distance_metric": "cosine",
"datatype": "float32",
},
},
],
}
)
# Create search index
try:
long_term_memory_index = SearchIndex(
schema=memory_schema, redis_client=redis_client, overwrite=True
)
long_term_memory_index.create()
print("Long-term memory index ready")
except Exception as e:
print(f"Error creating index: {e}")
Now that we have a search index in Redis, we can write functions to store and retrieve memories, using RedisVL.
First, we'll write a utility function that checks whether a memory similar to a given memory already exists in the index. Later, we'll use this function to avoid storing duplicate memories.
import logging
from redisvl.query import VectorRangeQuery
from redisvl.query.filter import Tag
from redisvl.utils.vectorize.text.openai import OpenAITextVectorizer
logger = logging.getLogger(__name__)
# If we have any memories that aren't associated with a user, we'll use this ID.
SYSTEM_USER_ID = "system"
openai_embed = OpenAITextVectorizer(model="text-embedding-ada-002")
# Change this to MemoryStrategy.TOOLS to use function-calling to store and
# retrieve memories.
memory_strategy = MemoryStrategy.MANUAL
def similar_memory_exists(
content: str,
memory_type: MemoryType,
user_id: str = SYSTEM_USER_ID,
thread_id: Optional[str] = None,
distance_threshold: float = 0.1,
) -> bool:
"""Check if a similar long-term memory already exists in Redis."""
query_embedding = openai_embed.embed(content)
filters = (Tag("user_id") == user_id) & (Tag("memory_type") == memory_type)
if thread_id:
filters = filters & (Tag("thread_id") == thread_id)
# Search for similar memories
vector_query = VectorRangeQuery(
vector=query_embedding,
num_results=1,
vector_field_name="embedding",
filter_expression=filters,
distance_threshold=distance_threshold,
return_fields=["id"],
)
results = long_term_memory_index.query(vector_query)
logger.debug(f"Similar memory search results: {results}")
if results:
logger.debug(
f"{len(results)} similar {'memory' if results.count == 1 else 'memories'} found. First: "
f"{results[0]['id']}. Skipping storage."
)
return True
return False
We'll use the similar_memory_exists() function when we store memories:
from datetime import datetime
from typing import List, Optional, Union
import ulid
def store_memory(
content: str,
memory_type: MemoryType,
user_id: str = SYSTEM_USER_ID,
thread_id: Optional[str] = None,
metadata: Optional[str] = None,
):
"""Store a long-term memory in Redis, avoiding duplicates."""
if metadata is None:
metadata = "{}"
logger.info(f"Preparing to store memory: {content}")
if similar_memory_exists(content, memory_type, user_id, thread_id):
logger.info("Similar memory found, skipping storage")
return
embedding = openai_embed.embed(content)
memory_data = {
"user_id": user_id or SYSTEM_USER_ID,
"content": content,
"memory_type": memory_type.value,
"metadata": metadata,
"created_at": datetime.now().isoformat(),
"embedding": embedding,
"memory_id": str(ulid.ULID()),
"thread_id": thread_id,
}
try:
long_term_memory_index.load([memory_data])
except Exception as e:
logger.error(f"Error storing memory: {e}")
return
logger.info(f"Stored {memory_type} memory: {content}")
Now that we're storing memories, we can retrieve them:
def retrieve_memories(
query: str,
memory_type: Union[Optional[MemoryType], List[MemoryType]] = None,
user_id: str = SYSTEM_USER_ID,
thread_id: Optional[str] = None,
distance_threshold: float = 0.1,
limit: int = 5,
) -> List[StoredMemory]:
"""Retrieve relevant memories from Redis"""
# Create vector query
logger.debug(f"Retrieving memories for query: {query}")
vector_query = VectorRangeQuery(
vector=openai_embed.embed(query),
return_fields=[
"content",
"memory_type",
"metadata",
"created_at",
"memory_id",
"thread_id",
"user_id",
],
num_results=limit,
vector_field_name="embedding",
dialect=2,
distance_threshold=distance_threshold,
)
base_filters = [f"@user_id:{{{user_id or SYSTEM_USER_ID}}}"]
if memory_type:
if isinstance(memory_type, list):
base_filters.append(f"@memory_type:{{{'|'.join(memory_type)}}}")
else:
base_filters.append(f"@memory_type:{{{memory_type.value}}}")
if thread_id:
base_filters.append(f"@thread_id:{{{thread_id}}}")
vector_query.set_filter(" ".join(base_filters))
# Execute search
results = long_term_memory_index.query(vector_query)
# Parse results
memories = []
for doc in results:
try:
memory = StoredMemory(
id=doc["id"],
memory_id=doc["memory_id"],
user_id=doc["user_id"],
thread_id=doc.get("thread_id", None),
memory_type=MemoryType(doc["memory_type"]),
content=doc["content"],
created_at=doc["created_at"],
metadata=doc["metadata"],
)
memories.append(memory)
except Exception as e:
logger.error(f"Error parsing memory: {e}")
continue
return memories
There are two ways (among others, but these are the two we'll cover) the agent can store and retrieve relevant long-term memories when querying the LLM:
Both approaches involve trade-offs.
Tool-calling leaves the decision of whether to store a memory or look up relevant memories to the LLM. This adds latency to requests. It usually results in fewer calls to Redis, but it will also sometimes miss retrieving potentially relevant context and/or extracting relevant memories from the conversation.
Manual memory management results in more calls to Redis but fewer round-trip requests to the LLM, reducing latency. Manually extracting memories generally captures more memories than tool calls do, which stores more data in Redis and should add more context to LLM requests. More context means more context-awareness, but also higher token usage.
You can test both approaches by changing the memory_strategy variable.
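For example, a one-line sketch of the toggle (set it before running the agent and graph cells below):
# Uncomment to let the LLM store and retrieve memories via tool calls instead of
# the manual extraction/retrieval steps this notebook uses by default:
# memory_strategy = MemoryStrategy.TOOLS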
With the manual memory management strategy, we extract memories after every interaction between the user and the agent. In future interactions, we then retrieve those memories before sending the query.
We'll manually call the extract_memories function after every interaction:
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.graph.message import MessagesState
class RuntimeState(MessagesState):
"""Agent state (just messages for now)"""
pass
memory_llm = ChatOpenAI(model="gpt-4o", temperature=0.3).with_structured_output(
Memories
)
def extract_memories(
last_processed_message_id: Optional[str],
state: RuntimeState,
config: RunnableConfig,
) -> Optional[str]:
"""Extract and store memories in long-term memory"""
logger.debug(f"Last message ID is: {last_processed_message_id}")
if len(state["messages"]) < 3: # Need at least a user message and agent response
logger.debug("Not enough messages to extract memories")
return last_processed_message_id
user_id = config.get("configurable", {}).get("user_id", None)
if not user_id:
logger.warning("No user ID found in config when extracting memories")
return last_processed_message_id
# Get the messages
messages = state["messages"]
# Find the newest message ID (or None if no IDs)
newest_message_id = None
for msg in reversed(messages):
if hasattr(msg, "id") and msg.id:
newest_message_id = msg.id
break
logger.debug(f"Newest message ID is: {newest_message_id}")
# If we've already processed up to this message ID, skip
if (
last_processed_message_id
and newest_message_id
and last_processed_message_id == newest_message_id
):
logger.debug(f"Already processed messages up to ID {newest_message_id}")
return last_processed_message_id
# Find the index of the message with last_processed_message_id
start_index = 0
if last_processed_message_id:
for i, msg in enumerate(messages):
if hasattr(msg, "id") and msg.id == last_processed_message_id:
start_index = i + 1 # Start processing from the next message
break
# Check if there are messages to process
if start_index >= len(messages):
logger.debug("No new messages to process since last processed message")
return newest_message_id
# Get only the messages after the last processed message
messages_to_process = messages[start_index:]
# If there are not enough messages to process, include some context
if len(messages_to_process) < 3 and start_index > 0:
# Include up to 3 messages before the start_index for context
context_start = max(0, start_index - 3)
messages_to_process = messages[context_start:]
# Format messages for the memory agent
message_history = "\n".join(
[
f"{'User' if isinstance(msg, HumanMessage) else 'Assistant'}: {msg.content}"
for msg in messages_to_process
]
)
prompt = f"""
You are a long-memory manager. Your job is to analyze this message history
and extract information that might be useful in future conversations.
Extract two types of memories:
1. EPISODIC: Personal experiences and preferences specific to this user
Example: "User prefers window seats" or "User had a bad experience in Paris"
2. SEMANTIC: General facts and knowledge about travel that could be useful
Example: "The best time to visit Japan is during cherry blossom season in April"
For each memory, provide:
- Type: The memory type (EPISODIC/SEMANTIC)
- Content: The actual information to store
- Metadata: Relevant tags and context (as JSON)
IMPORTANT RULES:
1. Only extract information that would be genuinely useful for future interactions.
2. Do not extract procedural knowledge - that is handled by the system's built-in tools and prompts.
3. You are a large language model, not a human - do not extract facts that you already know.
Message history:
{message_history}
Extracted memories:
"""
memories_to_store: Memories = memory_llm.invoke([HumanMessage(content=prompt)]) # type: ignore
# Store each extracted memory
for memory_data in memories_to_store.memories:
store_memory(
content=memory_data.content,
memory_type=memory_data.memory_type,
user_id=user_id,
metadata=memory_data.metadata,
)
# Return data with the newest processed message ID
return newest_message_id
We'll use this function in a background thread. We'll start the thread in manual memory mode but not in tools mode, running it as a worker that pulls message histories to process from a Queue:
import time
from queue import Queue
DEFAULT_MEMORY_WORKER_INTERVAL = 5 * 60 # 5 minutes
DEFAULT_MEMORY_WORKER_BACKOFF_INTERVAL = 10 * 60 # 10 minutes
def memory_worker(
memory_queue: Queue,
user_id: str,
interval: int = DEFAULT_MEMORY_WORKER_INTERVAL,
backoff_interval: int = DEFAULT_MEMORY_WORKER_BACKOFF_INTERVAL,
):
"""Worker function that processes long-term memory extraction requests"""
key = f"memory_worker:{user_id}:last_processed_message_id"
last_processed_message_id = redis_client.get(key)
logger.debug(f"Last processed message ID: {last_processed_message_id}")
    # Redis returns bytes; decode so the value can be compared against message IDs.
    last_processed_message_id = (
        last_processed_message_id.decode("utf-8")
        if last_processed_message_id
        else None
    )
while True:
try:
# Get the next state and config from the queue (blocks until an item is available)
state, config = memory_queue.get()
# Extract long-term memories from the conversation history
last_processed_message_id = extract_memories(
last_processed_message_id, state, config
)
logger.debug(
f"Memory worker extracted memories. Last processed message ID: {last_processed_message_id}"
)
if last_processed_message_id:
logger.debug(
f"Setting last processed message ID: {last_processed_message_id}"
)
redis_client.set(key, last_processed_message_id)
# Mark the task as done
memory_queue.task_done()
logger.debug("Memory extraction completed for queue item")
# Wait before processing next item
time.sleep(interval)
except Exception as e:
# Wait before processing next item after an error
logger.exception(f"Error in memory worker thread: {e}")
time.sleep(backoff_interval)
# NOTE: We'll actually start the worker thread later, in the main loop.
For every interaction between the user and the agent, we query for relevant memories and add them to the LLM prompt using retrieve_relevant_memories().
NOTE: We only run this node with the "manual" memory management strategy. With "tools," the LLM decides when to retrieve memories.
def retrieve_relevant_memories(
state: RuntimeState, config: RunnableConfig
) -> RuntimeState:
"""Retrieve relevant memories based on the current conversation."""
if not state["messages"]:
logger.debug("No messages in state")
return state
latest_message = state["messages"][-1]
if not isinstance(latest_message, HumanMessage):
logger.debug("Latest message is not a HumanMessage: ", latest_message)
return state
user_id = config.get("configurable", {}).get("user_id", SYSTEM_USER_ID)
query = str(latest_message.content)
relevant_memories = retrieve_memories(
query=query,
memory_type=[MemoryType.EPISODIC, MemoryType.SEMANTIC],
limit=5,
user_id=user_id,
distance_threshold=0.3,
)
logger.debug(f"All relevant memories: {relevant_memories}")
# We'll augment the latest human message with the relevant memories.
if relevant_memories:
memory_context = "\n\n### Relevant memories from previous conversations:\n"
# Group by memory type
memory_types = {
MemoryType.EPISODIC: "User Preferences & History",
MemoryType.SEMANTIC: "Travel Knowledge",
}
for mem_type, type_label in memory_types.items():
memories_of_type = [
m for m in relevant_memories if m.memory_type == mem_type
]
if memories_of_type:
memory_context += f"\n**{type_label}**:\n"
for mem in memories_of_type:
memory_context += f"- {mem.content}\n"
augmented_message = HumanMessage(content=f"{query}\n{memory_context}")
state["messages"][-1] = augmented_message
logger.debug(f"Augmented message: {augmented_message.content}")
return state.copy()
This is the first function we've seen that represents a node in the LangGraph graph we're going to build. As a node, this function receives a state object containing the graph's runtime state, which is where the conversation history lives. Its config parameter contains data such as the user and thread IDs.
This will be the starting node of the graph we assemble later. When a user invokes the graph with a message, the first thing we do (when using the "manual" memory strategy) is augment that message with potentially relevant memories.
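As a rough sketch of that contract, here's a hypothetical no-op node (not part of the agent) showing the signature LangGraph expects and how a node can read IDs from the config:
# Hypothetical example node, for illustration only.
def example_node(state: RuntimeState, config: RunnableConfig) -> RuntimeState:
    # The config carries per-run data such as the user and thread IDs.
    user_id = config.get("configurable", {}).get("user_id", SYSTEM_USER_ID)
    thread_id = config.get("configurable", {}).get("thread_id")
    logger.debug(f"Example node for user {user_id}, thread {thread_id}")
    # A node returns the (possibly updated) graph state; this one returns it unchanged.
    return state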
Now that we've defined our storage functions, we can create tools. We'll need them later to set up our agent. These tools are only used when the agent runs in "tools" memory management mode.
from langchain_core.tools import tool
from typing import Dict, Optional
@tool
def store_memory_tool(
content: str,
memory_type: MemoryType,
metadata: Optional[Dict[str, str]] = None,
config: Optional[RunnableConfig] = None,
) -> str:
"""
Store a long-term memory in the system.
Use this tool to save important information about user preferences,
experiences, or general knowledge that might be useful in future
interactions.
"""
config = config or RunnableConfig()
    user_id = config.get("configurable", {}).get("user_id", SYSTEM_USER_ID)
    thread_id = config.get("configurable", {}).get("thread_id")
try:
# Store in long-term memory
store_memory(
content=content,
memory_type=memory_type,
user_id=user_id,
thread_id=thread_id,
metadata=str(metadata) if metadata else None,
)
return f"Successfully stored {memory_type} memory: {content}"
except Exception as e:
return f"Error storing memory: {str(e)}"
@tool
def retrieve_memories_tool(
query: str,
memory_type: List[MemoryType],
limit: int = 5,
config: Optional[RunnableConfig] = None,
) -> str:
"""
Retrieve long-term memories relevant to the query.
Use this tool to access previously stored information about user
preferences, experiences, or general knowledge.
"""
config = config or RunnableConfig()
    user_id = config.get("configurable", {}).get("user_id", SYSTEM_USER_ID)
try:
# Get long-term memories
stored_memories = retrieve_memories(
query=query,
memory_type=memory_type,
user_id=user_id,
limit=limit,
distance_threshold=0.3,
)
# Format the response
response = []
if stored_memories:
response.append("Long-term memories:")
for memory in stored_memories:
response.append(f"- [{memory.memory_type}] {memory.content}")
return "\n".join(response) if response else "No relevant memories found."
except Exception as e:
return f"Error retrieving memories: {str(e)}"
Because we're using different LLM objects configured for different purposes, plus a prebuilt ReAct agent, we need a node that invokes the agent and returns its response. But before we can invoke the agent, we need to set it up, which requires defining the tools the agent will use.
import json
from typing import Dict, List, Optional, Tuple, Union
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.callbacks.manager import CallbackManagerForToolRun
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage
from langgraph.prebuilt.chat_agent_executor import create_react_agent
from langgraph.checkpoint.redis import RedisSaver
class CachingTavilySearchResults(TavilySearchResults):
"""
An interface to Tavily search that caches results in Redis.
Caching the results of the web search allows us to avoid rate limiting,
improve latency, and reduce costs.
"""
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Tuple[Union[List[Dict[str, str]], str], Dict]:
"""Use the tool."""
cache_key = f"tavily_search:{query}"
cached_result: Optional[str] = redis_client.get(cache_key) # type: ignore
if cached_result:
return json.loads(cached_result), {}
else:
result, raw_results = super()._run(query, run_manager)
redis_client.set(cache_key, json.dumps(result), ex=60 * 60)
return result, raw_results
# Create a checkpoint saver for short-term memory. This keeps track of the
# conversation history for each thread. Later, we'll continually summarize the
# conversation history to keep the context window manageable, while we also
# extract long-term memories from the conversation history to store in the
# long-term memory index.
redis_saver = RedisSaver(redis_client=redis_client)
redis_saver.setup()
# Configure an LLM for the agent with a more creative temperature.
llm = ChatOpenAI(model="gpt-4o", temperature=0.7)
# Uncomment these lines if you have a Tavily API key and want to use the web
# search tool. The agent is much more useful with this tool.
# web_search_tool = CachingTavilySearchResults(max_results=2)
# base_tools = [web_search_tool]
base_tools = []
if memory_strategy == MemoryStrategy.TOOLS:
tools = base_tools + [store_memory_tool, retrieve_memories_tool]
elif memory_strategy == MemoryStrategy.MANUAL:
tools = base_tools
travel_agent = create_react_agent(
model=llm,
tools=tools,
checkpointer=redis_saver, # Short-term memory: the conversation history
prompt=SystemMessage(
content="""
You are a travel assistant helping users plan their trips. You remember user preferences
and provide personalized recommendations based on past interactions.
You have access to the following types of memory:
1. Short-term memory: The current conversation thread
2. Long-term memory:
- Episodic: User preferences and past trip experiences (e.g., "User prefers window seats")
- Semantic: General knowledge about travel destinations and requirements
Your procedural knowledge (how to search, book flights, etc.) is built into your tools and prompts.
Always be helpful, personal, and context-aware in your responses.
"""
),
)
Now we can write the node that invokes the agent and responds to the user:
def respond_to_user(state: RuntimeState, config: RunnableConfig) -> RuntimeState:
"""Invoke the travel agent to generate a response."""
human_messages = [m for m in state["messages"] if isinstance(m, HumanMessage)]
if not human_messages:
logger.warning("No HumanMessage found in state")
return state
try:
for result in travel_agent.stream(
{"messages": state["messages"]}, config=config, stream_mode="messages"
):
result_messages = result.get("messages", [])
ai_messages = [
m
for m in result_messages
if isinstance(m, AIMessage) or isinstance(m, AIMessageChunk)
]
if ai_messages:
agent_response = ai_messages[-1]
# Append only the agent's response to the original state
state["messages"].append(agent_response)
except Exception as e:
logger.error(f"Error invoking travel agent: {e}")
        agent_response = AIMessage(
            content="I'm sorry, I encountered an error processing your request."
        )
        # Surface the error to the user by appending it to the conversation state.
        state["messages"].append(agent_response)
return state
We've been focused on long-term memory, but let's return to short-term memory for a moment. With RedisSaver, LangGraph manages our message history automatically. However, the message history will grow endlessly until it overflows the LLM's token context window.
To solve this, we'll add a node to the graph that summarizes the conversation if the history has grown past a threshold.
from langchain_core.messages import RemoveMessage
# An LLM configured for summarization.
summarizer = ChatOpenAI(model="gpt-4o", temperature=0.3)
# The number of messages after which we'll summarize the conversation.
MESSAGE_SUMMARIZATION_THRESHOLD = 10
def summarize_conversation(
state: RuntimeState, config: RunnableConfig
) -> Optional[RuntimeState]:
"""
Summarize a list of messages into a concise summary to reduce context length
while preserving important information.
"""
messages = state["messages"]
current_message_count = len(messages)
if current_message_count < MESSAGE_SUMMARIZATION_THRESHOLD:
logger.debug(f"Not summarizing conversation: {current_message_count}")
return state
system_prompt = """
You are a conversation summarizer. Create a concise summary of the previous
conversation between a user and a travel assistant.
The summary should:
1. Highlight key topics, preferences, and decisions
2. Include any specific trip details (destinations, dates, preferences)
3. Note any outstanding questions or topics that need follow-up
4. Be concise but informative
Format your summary as a brief narrative paragraph.
"""
message_content = "\n".join(
[
f"{'User' if isinstance(msg, HumanMessage) else 'Assistant'}: {msg.content}"
for msg in messages
]
)
# Invoke the summarizer
summary_messages = [
SystemMessage(content=system_prompt),
HumanMessage(
content=f"Please summarize this conversation:\n\n{message_content}"
),
]
summary_response = summarizer.invoke(summary_messages)
logger.info(f"Summarized {len(messages)} messages into a conversation summary")
summary_message = SystemMessage(
content=f"""
Summary of the conversation so far:
{summary_response.content}
Please continue the conversation based on this summary and the recent messages.
"""
)
remove_messages = [
RemoveMessage(id=msg.id) for msg in messages if msg.id is not None
]
state["messages"] = [ # type: ignore
*remove_messages,
summary_message,
state["messages"][-1],
]
return state.copy()
Now it's time to assemble our graph.
from langgraph.graph import StateGraph, END, START
workflow = StateGraph(RuntimeState)
workflow.add_node("respond", respond_to_user)
workflow.add_node("summarize_conversation", summarize_conversation)
if memory_strategy == MemoryStrategy.MANUAL:
# In manual memory mode, we'll retrieve relevant memories before
# responding to the user, and then augment the user's message with the
# relevant memories.
workflow.add_node("retrieve_memories", retrieve_relevant_memories)
workflow.add_edge(START, "retrieve_memories")
workflow.add_edge("retrieve_memories", "respond")
else:
# In tool-calling mode, we'll respond to the user and let the LLM
# decide when to retrieve and store memories, using tool calls.
workflow.add_edge(START, "respond")
# Regardless of memory strategy, we'll summarize the conversation after
# responding to the user, to keep the context window manageable.
workflow.add_edge("respond", "summarize_conversation")
workflow.add_edge("summarize_conversation", END)
# Finally, compile the graph.
graph = workflow.compile(checkpointer=redis_saver)
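To sanity-check the wiring before building the interactive loop, here's a minimal single-turn sketch (placeholder IDs; assumes your OpenAI key is set). The invocation is left commented so it only runs if you choose to uncomment it:
# A minimal single-turn sketch of invoking the compiled graph (placeholder IDs).
demo_config = RunnableConfig(
    configurable={"thread_id": "demo_thread", "user_id": "demo_user"}
)
# demo_state = graph.invoke(
#     {"messages": [HumanMessage(content="I want to plan a trip to Japan.")]},
#     config=demo_config,
# )
# print(demo_state["messages"][-1].content)  # the agent's latest reply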
We're almost ready to create the main loop that runs our graph. First, though, let's create a worker that periodically consolidates similar memories using semantic search. Later, we'll run this worker as a background thread from the main loop.
from redisvl.query import FilterQuery
def consolidate_memories(user_id: str, batch_size: int = 10):
"""
Periodically merge similar long-term memories for a user.
"""
logger.info(f"Starting memory consolidation for user {user_id}")
# For each memory type, consolidate separately
for memory_type in MemoryType:
all_memories = []
# Get all memories of this type for the user
of_type_for_user = (Tag("user_id") == user_id) & (
Tag("memory_type") == memory_type
)
filter_query = FilterQuery(filter_expression=of_type_for_user)
for batch in long_term_memory_index.paginate(filter_query, page_size=batch_size):
all_memories.extend(batch)
if not all_memories:
continue
# Group similar memories
processed_ids = set()
for memory in all_memories:
if memory["id"] in processed_ids:
continue
memory_embedding = memory["embedding"]
vector_query = VectorRangeQuery(
vector=memory_embedding,
num_results=10,
vector_field_name="embedding",
filter_expression=of_type_for_user
& (Tag("memory_id") != memory["memory_id"]),
distance_threshold=0.1,
return_fields=[
"content",
"metadata",
],
)
similar_memories = long_term_memory_index.query(vector_query)
# If we found similar memories, consolidate them
if similar_memories:
combined_content = memory["content"]
combined_metadata = memory["metadata"]
if combined_metadata:
try:
combined_metadata = json.loads(combined_metadata)
except Exception as e:
logger.error(f"Error parsing metadata: {e}")
combined_metadata = {}
for similar in similar_memories:
# Merge the content of similar memories
combined_content += f" {similar['content']}"
if similar["metadata"]:
try:
similar_metadata = json.loads(similar["metadata"])
except Exception as e:
logger.error(f"Error parsing metadata: {e}")
similar_metadata = {}
combined_metadata = {**combined_metadata, **similar_metadata}
# Create a consolidated memory
new_metadata = {
"consolidated": True,
"source_count": len(similar_memories) + 1,
**combined_metadata,
}
consolidated_memory = {
"content": summarize_memories(combined_content, memory_type),
"memory_type": memory_type.value,
"metadata": json.dumps(new_metadata),
"user_id": user_id,
}
# Delete the old memories
delete_memory(memory["id"])
for similar in similar_memories:
delete_memory(similar["id"])
# Store the new consolidated memory
store_memory(
content=consolidated_memory["content"],
memory_type=memory_type,
user_id=user_id,
metadata=consolidated_memory["metadata"],
)
logger.info(
f"Consolidated {len(similar_memories) + 1} memories into one"
)
def delete_memory(memory_id: str):
"""Delete a memory from Redis"""
try:
result = long_term_memory_index.drop_keys([memory_id])
    except Exception as e:
        logger.error(f"Deleting memory {memory_id} failed: {e}")
        return
if result == 0:
logger.debug(f"Deleting memory {memory_id} failed: memory not found")
else:
logger.info(f"Deleted memory {memory_id}")
def summarize_memories(combined_content: str, memory_type: MemoryType) -> str:
"""Use the LLM to create a concise summary of similar memories"""
try:
system_prompt = f"""
You are a memory consolidation assistant. Your task is to create a single,
concise memory from these similar memory fragments. The new memory should
be a {memory_type.value} memory.
Combine the information without repetition while preserving all important details.
"""
messages = [
SystemMessage(content=system_prompt),
HumanMessage(
content=f"Consolidate these similar memories into one:\n\n{combined_content}"
),
]
response = summarizer.invoke(messages)
return str(response.content)
except Exception as e:
logger.error(f"Error summarizing memories: {e}")
# Fall back to just using the combined content
return combined_content
def memory_consolidation_worker(user_id: str):
"""
Worker that periodically consolidates memories for the active user.
NOTE: In production, this would probably use a background task framework, such
as rq or Celery, and run on a schedule.
"""
while True:
try:
consolidate_memories(user_id)
# Run every 10 minutes
time.sleep(10 * 60)
except Exception as e:
logger.exception(f"Error in memory consolidation worker: {e}")
# If there's an error, wait an hour and try again
time.sleep(60 * 60)
Now we can put everything together and run the main loop.
Running this cell should prompt you for your OpenAI and Tavily keys, then for a user name and thread ID. You'll enter a loop in which you can type queries and see the agent's responses printed below the cell.
import threading
def main(thread_id: str = "book_flight", user_id: str = "demo_user"):
"""Main interaction loop for the travel agent"""
print("Welcome to the Travel Assistant! (Type 'exit' to quit)")
config = RunnableConfig(configurable={"thread_id": thread_id, "user_id": user_id})
state = RuntimeState(messages=[])
# If we're using the manual memory strategy, we need to create a queue for
# memory processing and start a worker thread. After every 'round' of a
# conversation, the main loop will add the current state and config to the
# queue for memory processing.
if memory_strategy == MemoryStrategy.MANUAL:
# Create a queue for memory processing
memory_queue = Queue()
# Start a worker thread that will process memory extraction tasks
memory_thread = threading.Thread(
target=memory_worker, args=(memory_queue, user_id), daemon=True
)
memory_thread.start()
# We always run consolidation in the background, regardless of memory strategy.
consolidation_thread = threading.Thread(
target=memory_consolidation_worker, args=(user_id,), daemon=True
)
consolidation_thread.start()
while True:
user_input = input("\nYou (type 'quit' to quit): ")
if not user_input:
continue
if user_input.lower() in ["exit", "quit"]:
print("Thank you for using the Travel Assistant. Goodbye!")
break
state["messages"].append(HumanMessage(content=user_input))
try:
# Process user input through the graph
for result in graph.stream(state, config=config, stream_mode="values"):
state = RuntimeState(**result)
logger.debug(f"# of messages after run: {len(state['messages'])}")
# Find the most recent AI message, so we can print the response
ai_messages = [m for m in state["messages"] if isinstance(m, AIMessage)]
if ai_messages:
message = ai_messages[-1].content
else:
logger.error("No AI messages after run")
message = "I'm sorry, I couldn't process your request properly."
# Add the error message to the state
state["messages"].append(AIMessage(content=message))
print(f"\nAssistant: {message}")
# Add the current state to the memory processing queue
if memory_strategy == MemoryStrategy.MANUAL:
memory_queue.put((state.copy(), config))
except Exception as e:
logger.exception(f"Error processing request: {e}")
error_message = "I'm sorry, I encountered an error processing your request."
print(f"\nAssistant: {error_message}")
# Add the error message to the state
state["messages"].append(AIMessage(content=error_message))
try:
user_id = input("Enter a user ID: ") or "demo_user"
thread_id = input("Enter a thread ID: ") or "demo_thread"
except Exception:
# If we're running in CI, we don't have a terminal to input from, so just exit
exit()
else:
main(thread_id, user_id)
Want to build your own agent? Try the LangGraph quickstart. Then add our Redis checkpointer to give your agent fast, persistent memory.
By managing your AI agent's memory with Redis, you can build a flexible, scalable system that stores and retrieves memories quickly. Check out the resources below to start building with Redis today, or reach out to our team to talk about AI agents.
Redis Cloud: The easiest way to deploy Redis. Try it for free on AWS, Azure, or GCP.