
What is agent memory? An example using LangGraph and Redis

This notebook demonstrates how to manage an agent's short-term and long-term memory using LangGraph and Redis. We'll explore:

  1. Managing short-term memory with LangGraph's checkpointer
  2. Storing and retrieving long-term memories with RedisVL
  3. Managing long-term memory manually vs. exposing access through tools (i.e., function-calling)
  4. Managing conversation history size through summarization
  5. Memory consolidation

What we'll build

We'll build two versions of a travel agent: one that manages long-term memory manually, and one that manages long-term memory using tools the LLM calls.

Here are two diagrams showing the components used in each of the two agents.

Setup

Packages
pip install -q langchain-openai langgraph-checkpoint langgraph-checkpoint-redis "langchain-community>=0.2.11" tavily-python langchain-redis pydantic ulid

Required API keys

You'll need an OpenAI API key with billing information attached for this lesson. You'll also need a Tavily API key. At the time of writing, Tavily API keys come with free credits.

# NBVAL_SKIP
import getpass
import os

def _set_env(key: str):
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}:")

_set_env("OPENAI_API_KEY")

# Uncomment this if you have a Tavily API key and want to
# use the web search tool.
# _set_env("TAVILY_API_KEY")

Running Redis

For Colab

Convert the following cell to a code cell to run it in Colab.

# Exit if this is not running in Colab
if [ -z "$COLAB_RELEASE_TAG" ]; then
  exit 0
fi

curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update  > /dev/null 2>&1
sudo apt-get install redis-stack-server  > /dev/null 2>&1
redis-stack-server --daemonize yes

For other environments

There are many ways to run the required redis-stack instance:

  1. In the cloud: deploy a free Redis Cloud instance. Or, if you have your own Redis Software deployment running, that works too!
  2. Per operating system: see the docs.

With Docker: docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest

Test the Redis connection

import os
from redis import Redis

# Use the environment variable if set, otherwise default to localhost
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")

redis_client = Redis.from_url(REDIS_URL)
redis_client.ping()

Short-term vs. long-term memory

Agents use short-term memory and long-term memory. Short-term and long-term memory are implemented differently, and agents use them differently. Let's dig into the details. We'll get back to code soon.

Short-term memory

For short-term memory, the agent uses Redis to track conversation history. Because this is a LangGraph agent, we use the RedisSaver class for this. RedisSaver is what LangGraph calls a checkpointer. You can read more about checkpointers in the LangGraph documentation. In short, they store state for every node in the graph, which for this agent includes the conversation history.

Here's a diagram showing how the agent uses Redis for short-term memory. Every node in the graph (retrieve memories, respond, summarize conversation) persists its "state" to Redis. The state object contains the agent's conversation message history for the current thread.

If Redis persistence is enabled, Redis persists short-term memory to disk. This means that if you quit the agent and come back with the same thread ID and user ID, you'll resume the same conversation.
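If you want checkpoints to survive a Redis restart, one option is append-only-file persistence. Here's a minimal, hypothetical sketch using the redis_client connection created above; it assumes your Redis deployment allows the CONFIG command (managed or Docker-based instances may already persist data, or may disable CONFIG).

# Hypothetical: enable AOF persistence on a local Redis instance so that
# checkpoints written by the agent survive a restart.
redis_client.config_set("appendonly", "yes")
print(redis_client.config_get("appendonly"))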

Conversation histories can grow long and pollute the LLM's context window. To manage this, after every "turn" of a conversation, the agent summarizes the messages once the conversation grows past a configurable threshold. Checkpointers don't do this by default, so we create a node in the graph for summarization.

NOTE: We'll see example code for the summarization node later in this notebook.
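To make the thread-scoped behavior concrete, here's a minimal, hypothetical sketch of a tiny graph compiled with a Redis checkpointer. It's separate from the travel agent we build below: the echo node and the demo_short_term thread ID are placeholder names, and it assumes the redis_client created above.

from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.redis import RedisSaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import MessagesState


def echo(state: MessagesState):
    # A trivial node; a real agent would call an LLM here.
    last = state["messages"][-1]
    return {"messages": [AIMessage(content=f"You said: {last.content}")]}


builder = StateGraph(MessagesState)
builder.add_node("echo", echo)
builder.add_edge(START, "echo")
builder.add_edge("echo", END)

# The checkpointer persists each node's state to Redis, keyed by thread ID.
demo_checkpointer = RedisSaver(redis_client=redis_client)
demo_checkpointer.setup()
demo_graph = builder.compile(checkpointer=demo_checkpointer)

# Invoking later with the same thread_id resumes the same conversation history.
config = {"configurable": {"thread_id": "demo_short_term"}}
demo_graph.invoke({"messages": [HumanMessage(content="Hi!")]}, config=config)

Because the checkpointer stores state per thread, a second invoke with the same config would already see the earlier messages in state["messages"].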

Long-term memory

Aside from conversation history, the agent stores long-term memories in a search index in Redis, using RedisVL. Here's a diagram showing the components involved:

The agent tracks two types of long-term memories:

  • Episodic: User-specific experiences and preferences
  • Semantic: General knowledge about travel destinations and requirements

NOTE: If you're familiar with the CoALA paper, the terms "episodic" and "semantic" memory here refer to the same concepts in the paper. CoALA discusses a third type of memory, procedural memory. In our example, we consider the logic encoded in Python in the agent's codebase to be its procedural memory.

Representing long-term memory in Python

We use a couple of Pydantic models to represent long-term memories, both before and after they're stored in Redis:

from datetime import datetime
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field
import ulid

class MemoryType(str, Enum):
    """
    The type of a long-term memory.

    EPISODIC: User specific experiences and preferences

    SEMANTIC: General knowledge on top of the user's preferences and LLM's
    training data.
    """

    EPISODIC = "episodic"
    SEMANTIC = "semantic"

class Memory(BaseModel):
    """Represents a single long-term memory."""

    content: str
    memory_type: MemoryType
    metadata: str
    
class Memories(BaseModel):
    """
    A list of memories extracted from a conversation by an LLM.

    NOTE: OpenAI's structured output requires us to wrap the list in an object.
    """
    memories: List[Memory]

class StoredMemory(Memory):
    """A stored long-term memory"""

    id: str  # The redis key
    memory_id: ulid.ULID = Field(default_factory=lambda: ulid.ULID())
    created_at: datetime = Field(default_factory=datetime.now)
    user_id: Optional[str] = None
    thread_id: Optional[str] = None
    memory_type: Optional[MemoryType] = None
    
class MemoryStrategy(str, Enum):
    """
    Supported strategies for managing long-term memory.
    
    This notebook supports two strategies for working with long-term memory:

    TOOLS: The LLM decides when to store and retrieve long-term memories, using
    tools (AKA, function-calling) to do so.

    MANUAL: The agent manually retrieves long-term memories relevant to the
    current conversation before sending every message and analyzes every
    response to extract memories to store.

    NOTE: In both cases, the agent runs a background thread to consolidate
    memories, and a workflow step to summarize conversations after the history
    grows past a threshold.
    """

    TOOLS = "tools"
    MANUAL = "manual"
    
# By default, we'll use the manual strategy
memory_strategy = MemoryStrategy.MANUAL

We'll return to these models shortly to see how they work.
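As a quick illustration of how these models fit together, here's a hypothetical example of the kind of object the extraction step will later produce. The content and metadata are made up for illustration.

# Hypothetical example of the shape of data the extraction LLM returns.
example_memories = Memories(
    memories=[
        Memory(
            content="User prefers window seats on long-haul flights",
            memory_type=MemoryType.EPISODIC,
            metadata='{"category": "flight_preferences"}',
        )
    ]
)
print(example_memories)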

Storing and retrieving short-term memory

The RedisSaver class handles the groundwork of short-term memory storage for us, so there's nothing we need to do here.

Storing and retrieving long-term memory

We use RedisVL to store and retrieve long-term memories with vector embeddings. This enables semantic search over past experiences and knowledge.

Let's set up a new search index to store and query memories:

from redisvl.index import SearchIndex
from redisvl.schema.schema import IndexSchema

# Define schema for long-term memory index
memory_schema = IndexSchema.from_dict({
        "index": {
            "name": "agent_memories",
            "prefix": "memory:",
            "key_separator": ":",
            "storage_type": "json",
        },
        "fields": [
            {"name": "content", "type": "text"},
            {"name": "memory_type", "type": "tag"},
            {"name": "metadata", "type": "text"},
            {"name": "created_at", "type": "text"},
            {"name": "user_id", "type": "tag"},
            {"name": "memory_id", "type": "tag"},
            {
                "name": "embedding",
                "type": "vector",
                "attrs": {
                    "algorithm": "flat",
                    "dims": 1536,  # OpenAI embedding dimension
                    "distance_metric": "cosine",
                    "datatype": "float32",
                },
            },
        ],
    }
)

# Create search index
try:
    long_term_memory_index = SearchIndex(
        schema=memory_schema, redis_client=redis_client, overwrite=True
    )
    long_term_memory_index.create()
    print("Long-term memory index ready")
except Exception as e:
    print(f"Error creating index: {e}")

Storage and retrieval functions

Now that we have a search index in Redis, we can write functions to store and retrieve memories, using RedisVL.

First, we'll write a utility function that checks whether a memory similar to a given memory already exists in the index. Later, we can use this function to avoid storing duplicate memories.

Checking if a similar memory exists

import logging

from redisvl.query import VectorRangeQuery
from redisvl.query.filter import Tag
from redisvl.utils.vectorize.text.openai import OpenAITextVectorizer


logger = logging.getLogger(__name__)

# If we have any memories that aren't associated with a user, we'll use this ID.
SYSTEM_USER_ID = "system"

openai_embed = OpenAITextVectorizer(model="text-embedding-ada-002")

# Change this to MemoryStrategy.TOOLS to use function-calling to store and
# retrieve memories.
memory_strategy = MemoryStrategy.MANUAL


def similar_memory_exists(
    content: str,
    memory_type: MemoryType,
    user_id: str = SYSTEM_USER_ID,
    thread_id: Optional[str] = None,
    distance_threshold: float = 0.1,
) -> bool:
    """Check if a similar long-term memory already exists in Redis."""
    query_embedding = openai_embed.embed(content)
    filters = (Tag("user_id") == user_id) & (Tag("memory_type") == memory_type)
    if thread_id:
        filters = filters & (Tag("thread_id") == thread_id)

    # Search for similar memories
    vector_query = VectorRangeQuery(
        vector=query_embedding,
        num_results=1,
        vector_field_name="embedding",
        filter_expression=filters,
        distance_threshold=distance_threshold,
        return_fields=["id"],
    )
    results = long_term_memory_index.query(vector_query)
    logger.debug(f"Similar memory search results: {results}")

    if results:
        logger.debug(
            f"{len(results)} similar {'memory' if results.count == 1 else 'memories'} found. First: "
            f"{results[0]['id']}. Skipping storage."
        )
        return True

    return False

Storing and retrieving long-term memories

We'll use the similar_memory_exists() function when we store memories:

from datetime import datetime
from typing import List, Optional, Union

import ulid


def store_memory(
    content: str,
    memory_type: MemoryType,
    user_id: str = SYSTEM_USER_ID,
    thread_id: Optional[str] = None,
    metadata: Optional[str] = None,
):
    """Store a long-term memory in Redis, avoiding duplicates."""
    if metadata is None:
        metadata = "{}"

    logger.info(f"Preparing to store memory: {content}")

    if similar_memory_exists(content, memory_type, user_id, thread_id):
        logger.info("Similar memory found, skipping storage")
        return

    embedding = openai_embed.embed(content)

    memory_data = {
        "user_id": user_id or SYSTEM_USER_ID,
        "content": content,
        "memory_type": memory_type.value,
        "metadata": metadata,
        "created_at": datetime.now().isoformat(),
        "embedding": embedding,
        "memory_id": str(ulid.ULID()),
        "thread_id": thread_id,
    }

    try:
        long_term_memory_index.load([memory_data])
    except Exception as e:
        logger.error(f"Error storing memory: {e}")
        return

    logger.info(f"Stored {memory_type} memory: {content}")

Now that we're storing memories, we can retrieve them:

def retrieve_memories(
    query: str,
    memory_type: Union[Optional[MemoryType], List[MemoryType]] = None,
    user_id: str = SYSTEM_USER_ID,
    thread_id: Optional[str] = None,
    distance_threshold: float = 0.1,
    limit: int = 5,
) -> List[StoredMemory]:
    """Retrieve relevant memories from Redis"""
    # Create vector query
    logger.debug(f"Retrieving memories for query: {query}")
    vector_query = VectorRangeQuery(
        vector=openai_embed.embed(query),
        return_fields=[
            "content",
            "memory_type",
            "metadata",
            "created_at",
            "memory_id",
            "thread_id",
            "user_id",
        ],
        num_results=limit,
        vector_field_name="embedding",
        dialect=2,
        distance_threshold=distance_threshold,
    )

    base_filters = [f"@user_id:{{{user_id or SYSTEM_USER_ID}}}"]

    if memory_type:
        if isinstance(memory_type, list):
            base_filters.append(f"@memory_type:{{{'|'.join(memory_type)}}}")
        else:
            base_filters.append(f"@memory_type:{{{memory_type.value}}}")

    if thread_id:
        base_filters.append(f"@thread_id:{{{thread_id}}}")

    vector_query.set_filter(" ".join(base_filters))

    # Execute search
    results = long_term_memory_index.query(vector_query)

    # Parse results
    memories = []
    for doc in results:
        try:
            memory = StoredMemory(
                id=doc["id"],
                memory_id=doc["memory_id"],
                user_id=doc["user_id"],
                thread_id=doc.get("thread_id", None),
                memory_type=MemoryType(doc["memory_type"]),
                content=doc["content"],
                created_at=doc["created_at"],
                metadata=doc["metadata"],
            )
            memories.append(memory)
        except Exception as e:
            logger.error(f"Error parsing memory: {e}")
            continue
    return memories
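
As a quick, hypothetical smoke test of these two functions, you could store a memory and then query it back. The content and user ID here are made up, and both calls use your OpenAI key because they create embeddings.

# Hypothetical smoke test for the storage and retrieval functions above.
store_memory(
    content="User is planning a trip to Japan in April",
    memory_type=MemoryType.EPISODIC,
    user_id="demo_user",
)

for memory in retrieve_memories(
    query="Where is the user thinking of traveling?",
    memory_type=MemoryType.EPISODIC,
    user_id="demo_user",
    distance_threshold=0.3,
):
    print(f"{memory.memory_type.value}: {memory.content}")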

Managing long-term memory manually vs. calling tools

When making LLM queries, the agent can store and retrieve relevant long-term memories in these two ways (among others, but these are the two we'll discuss):

  1. Expose memory retrieval and storage as "tools" the LLM can decide to call based on context.
  2. Manually augment prompts with relevant memories, and manually extract and store relevant memories.

Both approaches involve trade-offs.

Tool-calling leaves the decision to store a memory or look up relevant memories to the LLM. This adds latency to requests. It usually results in fewer calls to Redis, but it also sometimes misses opportunities to retrieve potentially relevant context and/or extract relevant memories from the conversation.

Manual memory management results in more calls to Redis, but fewer round-trip requests to the LLM, reducing latency. Manually extracting memories generally captures more memories than tool-calling does, which stores more data in Redis and should add more context to LLM requests. More context means more context-awareness, but also higher token usage.

You can test both approaches by changing the memory_strategy variable.

Managing memory manually

With the manual memory management strategy, we extract memories after every interaction the user has with the agent. Then, in future interactions, we retrieve those memories before sending the query to the LLM.

Extracting memories

We'll call the extract_memories function manually after every interaction:

from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.graph.message import MessagesState

class RuntimeState(MessagesState):
    """Agent state (just messages for now)"""

    pass

memory_llm = ChatOpenAI(model="gpt-4o", temperature=0.3).with_structured_output(
    Memories
)

def extract_memories(
    last_processed_message_id: Optional[str],
    state: RuntimeState,
    config: RunnableConfig,
) -> Optional[str]:
    """Extract and store memories in long-term memory"""
    logger.debug(f"Last message ID is: {last_processed_message_id}")

    if len(state["messages"]) < 3:  # Need at least a user message and agent response
        logger.debug("Not enough messages to extract memories")
        return last_processed_message_id

    user_id = config.get("configurable", {}).get("user_id", None)
    if not user_id:
        logger.warning("No user ID found in config when extracting memories")
        return last_processed_message_id

    # Get the messages
    messages = state["messages"]

    # Find the newest message ID (or None if no IDs)
    newest_message_id = None
    for msg in reversed(messages):
        if hasattr(msg, "id") and msg.id:
            newest_message_id = msg.id
            break

    logger.debug(f"Newest message ID is: {newest_message_id}")

    # If we've already processed up to this message ID, skip
    if (
        last_processed_message_id
        and newest_message_id
        and last_processed_message_id == newest_message_id
    ):
        logger.debug(f"Already processed messages up to ID {newest_message_id}")
        return last_processed_message_id

    # Find the index of the message with last_processed_message_id
    start_index = 0
    if last_processed_message_id:
        for i, msg in enumerate(messages):
            if hasattr(msg, "id") and msg.id == last_processed_message_id:
                start_index = i + 1  # Start processing from the next message
                break

    # Check if there are messages to process
    if start_index >= len(messages):
        logger.debug("No new messages to process since last processed message")
        return newest_message_id

    # Get only the messages after the last processed message
    messages_to_process = messages[start_index:]

    # If there are not enough messages to process, include some context
    if len(messages_to_process) < 3 and start_index > 0:
        # Include up to 3 messages before the start_index for context
        context_start = max(0, start_index - 3)
        messages_to_process = messages[context_start:]

    # Format messages for the memory agent
    message_history = "\n".join(
        [
            f"{'User' if isinstance(msg, HumanMessage) else 'Assistant'}: {msg.content}"
            for msg in messages_to_process
        ]
    )

    prompt = f"""
    You are a long-memory manager. Your job is to analyze this message history
    and extract information that might be useful in future conversations.
    
    Extract two types of memories:
    1. EPISODIC: Personal experiences and preferences specific to this user
       Example: "User prefers window seats" or "User had a bad experience in Paris"
    
    2. SEMANTIC: General facts and knowledge about travel that could be useful
       Example: "The best time to visit Japan is during cherry blossom season in April"
    
    For each memory, provide:
    - Type: The memory type (EPISODIC/SEMANTIC)
    - Content: The actual information to store
    - Metadata: Relevant tags and context (as JSON)
    
    IMPORTANT RULES:
    1. Only extract information that would be genuinely useful for future interactions.
    2. Do not extract procedural knowledge - that is handled by the system's built-in tools and prompts.
    3. You are a large language model, not a human - do not extract facts that you already know.
    
    Message history:
    {message_history}
    
    Extracted memories:
    """

    memories_to_store: Memories = memory_llm.invoke([HumanMessage(content=prompt)])  # type: ignore

    # Store each extracted memory
    for memory_data in memories_to_store.memories:
        store_memory(
            content=memory_data.content,
            memory_type=memory_data.memory_type,
            user_id=user_id,
            metadata=memory_data.metadata,
        )
    # Return data with the newest processed message ID
    return newest_message_id

We'll use this function in a background thread. We'll start that thread in manual memory mode but not in tools mode, running it as a worker that pulls message histories to process from a Queue:

import time
from queue import Queue

DEFAULT_MEMORY_WORKER_INTERVAL = 5 * 60  # 5 minutes
DEFAULT_MEMORY_WORKER_BACKOFF_INTERVAL = 10 * 60  # 10 minutes

def memory_worker(
    memory_queue: Queue,
    user_id: str,
    interval: int = DEFAULT_MEMORY_WORKER_INTERVAL,
    backoff_interval: int = DEFAULT_MEMORY_WORKER_BACKOFF_INTERVAL,
):
    """Worker function that processes long-term memory extraction requests"""
    key = f"memory_worker:{user_id}:last_processed_message_id"

    last_processed_message_id = redis_client.get(key)
    logger.debug(f"Last processed message ID: {last_processed_message_id}")
    last_processed_message_id = (
        last_processed_message_id.decode("utf-8")
        if isinstance(last_processed_message_id, bytes)
        else last_processed_message_id
    )

    while True:
        try:
            # Get the next state and config from the queue (blocks until an item is available)
            state, config = memory_queue.get()

            # Extract long-term memories from the conversation history
            last_processed_message_id = extract_memories(
                last_processed_message_id, state, config
            )
            logger.debug(
                f"Memory worker extracted memories. Last processed message ID: {last_processed_message_id}"
            )

            if last_processed_message_id:
                logger.debug(
                    f"Setting last processed message ID: {last_processed_message_id}"
                )
                redis_client.set(key, last_processed_message_id)

            # Mark the task as done
            memory_queue.task_done()
            logger.debug("Memory extraction completed for queue item")
            # Wait before processing next item
            time.sleep(interval)
        except Exception as e:
            # Wait before processing next item after an error
            logger.exception(f"Error in memory worker thread: {e}")
            time.sleep(backoff_interval)

# NOTE: We'll actually start the worker thread later, in the main loop.

Augmenting queries with relevant memories

For every interaction between the user and the agent, we'll query for relevant memories and add them to the LLM prompt using retrieve_relevant_memories().

NOTE: We only run this node with the "manual" memory management strategy. With "tools", the LLM decides when to retrieve memories.

def retrieve_relevant_memories(
    state: RuntimeState, config: RunnableConfig
) -> RuntimeState:
    """Retrieve relevant memories based on the current conversation."""
    if not state["messages"]:
        logger.debug("No messages in state")
        return state

    latest_message = state["messages"][-1]
    if not isinstance(latest_message, HumanMessage):
        logger.debug("Latest message is not a HumanMessage: ", latest_message)
        return state

    user_id = config.get("configurable", {}).get("user_id", SYSTEM_USER_ID)

    query = str(latest_message.content)
    relevant_memories = retrieve_memories(
        query=query,
        memory_type=[MemoryType.EPISODIC, MemoryType.SEMANTIC],
        limit=5,
        user_id=user_id,
        distance_threshold=0.3,
    )

    logger.debug(f"All relevant memories: {relevant_memories}")

    # We'll augment the latest human message with the relevant memories.
    if relevant_memories:
        memory_context = "\n\n### Relevant memories from previous conversations:\n"

        # Group by memory type
        memory_types = {
            MemoryType.EPISODIC: "User Preferences & History",
            MemoryType.SEMANTIC: "Travel Knowledge",
        }

        for mem_type, type_label in memory_types.items():
            memories_of_type = [
                m for m in relevant_memories if m.memory_type == mem_type
            ]
            if memories_of_type:
                memory_context += f"\n**{type_label}**:\n"
                for mem in memories_of_type:
                    memory_context += f"- {mem.content}\n"

        augmented_message = HumanMessage(content=f"{query}\n{memory_context}")
        state["messages"][-1] = augmented_message

        logger.debug(f"Augmented message: {augmented_message.content}")

    return state.copy()

This is the first function we've seen that represents a node in the LangGraph graph we're going to build. As a node, this function receives a state object containing the graph's runtime state, which is where the conversation history is stored. Its config parameter contains data such as the user and thread IDs.

This will be the starting node of the graph we assemble later. When a user invokes the graph with a message, the first thing we do (when using the "manual" memory strategy) is augment that message with potentially relevant memories.
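Because nodes are plain functions, you can also exercise this one on its own before wiring it into the graph. Here's a hypothetical standalone call: the user ID, thread ID, and message are placeholders, and embedding the query uses your OpenAI key.

# Hypothetical standalone invocation of the node, outside of any graph.
test_config = RunnableConfig(
    configurable={"user_id": "demo_user", "thread_id": "demo_thread"}
)
test_state = RuntimeState(
    messages=[HumanMessage(content="Any tips for my trip to Japan?")]
)

augmented_state = retrieve_relevant_memories(test_state, test_config)
# If any stored memories are similar enough, the latest message now has a
# "Relevant memories" section appended to it.
print(augmented_state["messages"][-1].content)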

Defining tools

Now that we've defined our storage functions, we can create tools. We'll need these later when we set up our agent. These tools are only used when the agent runs in the "tools" memory management mode.

from langchain_core.tools import tool
from typing import Dict, Optional


@tool
def store_memory_tool(
    content: str,
    memory_type: MemoryType,
    metadata: Optional[Dict[str, str]] = None,
    config: Optional[RunnableConfig] = None,
) -> str:
    """
    Store a long-term memory in the system.

    Use this tool to save important information about user preferences,
    experiences, or general knowledge that might be useful in future
    interactions.
    """
    config = config or RunnableConfig()
    user_id = config.get("user_id", SYSTEM_USER_ID)
    thread_id = config.get("thread_id")

    try:
        # Store in long-term memory
        store_memory(
            content=content,
            memory_type=memory_type,
            user_id=user_id,
            thread_id=thread_id,
            metadata=str(metadata) if metadata else None,
        )

        return f"Successfully stored {memory_type} memory: {content}"
    except Exception as e:
        return f"Error storing memory: {str(e)}"


@tool
def retrieve_memories_tool(
    query: str,
    memory_type: List[MemoryType],
    limit: int = 5,
    config: Optional[RunnableConfig] = None,
) -> str:
    """
    Retrieve long-term memories relevant to the query.

    Use this tool to access previously stored information about user
    preferences, experiences, or general knowledge.
    """
    config = config or RunnableConfig()
    user_id = config.get("user_id", SYSTEM_USER_ID)

    try:
        # Get long-term memories
        stored_memories = retrieve_memories(
            query=query,
            memory_type=memory_type,
            user_id=user_id,
            limit=limit,
            distance_threshold=0.3,
        )

        # Format the response
        response = []

        if stored_memories:
            response.append("Long-term memories:")
            for memory in stored_memories:
                response.append(f"- [{memory.memory_type}] {memory.content}")

        return "\n".join(response) if response else "No relevant memories found."

    except Exception as e:
        return f"Error retrieving memories: {str(e)}"

Creating the agent

Because we're using different LLM objects configured for different purposes, plus a prebuilt ReAct agent, we need a node that invokes the agent and returns its response. Before we can invoke the agent, though, we need to set it up. That requires defining the tools the agent will need.

import json
from typing import Dict, List, Optional, Tuple, Union

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.callbacks.manager import CallbackManagerForToolRun
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage
from langgraph.prebuilt.chat_agent_executor import create_react_agent
from langgraph.checkpoint.redis import RedisSaver


class CachingTavilySearchResults(TavilySearchResults):
    """
    An interface to Tavily search that caches results in Redis.
    
    Caching the results of the web search allows us to avoid rate limiting,
    improve latency, and reduce costs.
    """

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Union[List[Dict[str, str]], str], Dict]:
        """Use the tool."""
        cache_key = f"tavily_search:{query}"
        cached_result: Optional[str] = redis_client.get(cache_key)  # type: ignore
        if cached_result:
            return json.loads(cached_result), {}
        else:
            result, raw_results = super()._run(query, run_manager)
            redis_client.set(cache_key, json.dumps(result), ex=60 * 60)
            return result, raw_results


# Create a checkpoint saver for short-term memory. This keeps track of the
# conversation history for each thread. Later, we'll continually summarize the
# conversation history to keep the context window manageable, while we also
# extract long-term memories from the conversation history to store in the
# long-term memory index.
redis_saver = RedisSaver(redis_client=redis_client)
redis_saver.setup()

# Configure an LLM for the agent with a more creative temperature.
llm = ChatOpenAI(model="gpt-4o", temperature=0.7)

# Uncomment these lines if you have a Tavily API key and want to use the web
# search tool. The agent is much more useful with this tool.
# web_search_tool = CachingTavilySearchResults(max_results=2)
# base_tools = [web_search_tool]
base_tools = []

if memory_strategy == MemoryStrategy.TOOLS:
    tools = base_tools + [store_memory_tool, retrieve_memories_tool]
elif memory_strategy == MemoryStrategy.MANUAL:
    tools = base_tools


travel_agent = create_react_agent(
    model=llm,
    tools=tools,
    checkpointer=redis_saver,  # Short-term memory: the conversation history
    prompt=SystemMessage(
        content="""
        You are a travel assistant helping users plan their trips. You remember user preferences
        and provide personalized recommendations based on past interactions.
        
        You have access to the following types of memory:
        1. Short-term memory: The current conversation thread
        2. Long-term memory: 
           - Episodic: User preferences and past trip experiences (e.g., "User prefers window seats")
           - Semantic: General knowledge about travel destinations and requirements
           
        Your procedural knowledge (how to search, book flights, etc.) is built into your tools and prompts.
        
        Always be helpful, personal, and context-aware in your responses.
        """
    ),
)

Responding to the user

Now we can write the node that invokes the agent and responds to the user:

def respond_to_user(state: RuntimeState, config: RunnableConfig) -> RuntimeState:
    """Invoke the travel agent to generate a response."""
    human_messages = [m for m in state["messages"] if isinstance(m, HumanMessage)]
    if not human_messages:
        logger.warning("No HumanMessage found in state")
        return state

    try:
        for result in travel_agent.stream(
            {"messages": state["messages"]}, config=config, stream_mode="messages"
        ):
            result_messages = result.get("messages", [])

            ai_messages = [
                m
                for m in result_messages
                if isinstance(m, AIMessage) or isinstance(m, AIMessageChunk)
            ]
            if ai_messages:
                agent_response = ai_messages[-1]
                # Append only the agent's response to the original state
                state["messages"].append(agent_response)

    except Exception as e:
        logger.error(f"Error invoking travel agent: {e}")
        agent_response = AIMessage(
            content="I'm sorry, I encountered an error processing your request."
        )
        state["messages"].append(agent_response)
    return state

Summarizing conversation history

We've been focusing on long-term memory, but let's return to short-term memory for a moment. With RedisSaver, LangGraph manages our message history automatically. However, the message history will grow without bound until it overflows the LLM's token context window.

To solve this problem, we'll add a node to the graph that summarizes the conversation once its history grows past a threshold.

from langchain_core.messages import RemoveMessage

# An LLM configured for summarization.
summarizer = ChatOpenAI(model="gpt-4o", temperature=0.3)

# The number of messages after which we'll summarize the conversation.
MESSAGE_SUMMARIZATION_THRESHOLD = 10

def summarize_conversation(
    state: RuntimeState, config: RunnableConfig
) -> Optional[RuntimeState]:
    """
    Summarize a list of messages into a concise summary to reduce context length
    while preserving important information.
    """
    messages = state["messages"]
    current_message_count = len(messages)
    if current_message_count < MESSAGE_SUMMARIZATION_THRESHOLD:
        logger.debug(f"Not summarizing conversation: {current_message_count}")
        return state

    system_prompt = """
    You are a conversation summarizer. Create a concise summary of the previous
    conversation between a user and a travel assistant.
    
    The summary should:
    1. Highlight key topics, preferences, and decisions
    2. Include any specific trip details (destinations, dates, preferences)
    3. Note any outstanding questions or topics that need follow-up
    4. Be concise but informative
    
    Format your summary as a brief narrative paragraph.
    """

    message_content = "\n".join(
        [
            f"{'User' if isinstance(msg, HumanMessage) else 'Assistant'}: {msg.content}"
            for msg in messages
        ]
    )

    # Invoke the summarizer
    summary_messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(
            content=f"Please summarize this conversation:\n\n{message_content}"
        ),
    ]

    summary_response = summarizer.invoke(summary_messages)

    logger.info(f"Summarized {len(messages)} messages into a conversation summary")

    summary_message = SystemMessage(
        content=f"""
        Summary of the conversation so far:
        
        {summary_response.content}
        
        Please continue the conversation based on this summary and the recent messages.
        """
    )
    remove_messages = [
        RemoveMessage(id=msg.id) for msg in messages if msg.id is not None
    ]

    state["messages"] = [  # type: ignore
        *remove_messages,
        summary_message,
        state["messages"][-1],
    ]

    return state.copy()

Assembling the graph

Now it's time to assemble our graph.

from langgraph.graph import StateGraph, END, START

workflow = StateGraph(RuntimeState)

workflow.add_node("respond", respond_to_user)
workflow.add_node("summarize_conversation", summarize_conversation)

if memory_strategy == MemoryStrategy.MANUAL:
    # In manual memory mode, we'll retrieve relevant memories before
    # responding to the user, and then augment the user's message with the
    # relevant memories.
    workflow.add_node("retrieve_memories", retrieve_relevant_memories)
    workflow.add_edge(START, "retrieve_memories")
    workflow.add_edge("retrieve_memories", "respond")
else:
    # In tool-calling mode, we'll respond to the user and let the LLM
    # decide when to retrieve and store memories, using tool calls.
    workflow.add_edge(START, "respond")

# Regardless of memory strategy, we'll summarize the conversation after
# responding to the user, to keep the context window manageable.
workflow.add_edge("respond", "summarize_conversation")
workflow.add_edge("summarize_conversation", END)

# Finally, compile the graph.
graph = workflow.compile(checkpointer=redis_saver)

Consolidating memories in a background thread

We're almost ready to create the main loop that runs our graph. First, though, let's create a worker that periodically consolidates similar memories using semantic search. Later, we'll run this worker as a background thread from the main loop.

from redisvl.query import FilterQuery

def consolidate_memories(user_id: str, batch_size: int = 10):
    """
    Periodically merge similar long-term memories for a user.
    """
    logger.info(f"Starting memory consolidation for user {user_id}")
    
    # For each memory type, consolidate separately

    for memory_type in MemoryType:
        all_memories = []

        # Get all memories of this type for the user
        of_type_for_user = (Tag("user_id") == user_id) & (
            Tag("memory_type") == memory_type
        )
        filter_query = FilterQuery(filter_expression=of_type_for_user)
        
        for batch in long_term_memory_index.paginate(filter_query, page_size=batch_size):
            all_memories.extend(batch)

        if not all_memories:
            continue

        # Group similar memories
        processed_ids = set()
        for memory in all_memories:
            if memory["id"] in processed_ids:
                continue

            memory_embedding = memory["embedding"]
            vector_query = VectorRangeQuery(
                vector=memory_embedding,
                num_results=10,
                vector_field_name="embedding",
                filter_expression=of_type_for_user
                & (Tag("memory_id") != memory["memory_id"]),
                distance_threshold=0.1,
                return_fields=[
                    "content",
                    "metadata",
                ],
            )
            similar_memories = long_term_memory_index.query(vector_query)

            # If we found similar memories, consolidate them
            if similar_memories:
                combined_content = memory["content"]
                combined_metadata = memory["metadata"]

                if combined_metadata:
                    try:
                        combined_metadata = json.loads(combined_metadata)
                    except Exception as e:
                        logger.error(f"Error parsing metadata: {e}")
                        combined_metadata = {}

                for similar in similar_memories:
                    # Merge the content of similar memories
                    combined_content += f" {similar['content']}"

                    if similar["metadata"]:
                        try:
                            similar_metadata = json.loads(similar["metadata"])
                        except Exception as e:
                            logger.error(f"Error parsing metadata: {e}")
                        similar_metadata = {}

                    combined_metadata = {**combined_metadata, **similar_metadata}

                # Create a consolidated memory
                new_metadata = {
                    "consolidated": True,
                    "source_count": len(similar_memories) + 1,
                    **combined_metadata,
                }
                consolidated_memory = {
                    "content": summarize_memories(combined_content, memory_type),
                    "memory_type": memory_type.value,
                    "metadata": json.dumps(new_metadata),
                    "user_id": user_id,
                }

                # Delete the old memories
                delete_memory(memory["id"])
                for similar in similar_memories:
                    delete_memory(similar["id"])

                # Store the new consolidated memory
                store_memory(
                    content=consolidated_memory["content"],
                    memory_type=memory_type,
                    user_id=user_id,
                    metadata=consolidated_memory["metadata"],
                )

                logger.info(
                    f"Consolidated {len(similar_memories) + 1} memories into one"
                )

def delete_memory(memory_id: str):
    """Delete a memory from Redis"""
    try:
        result = long_term_memory_index.drop_keys([memory_id])
    except Exception as e:
        logger.error(f"Deleting memory {memory_id} failed: {e}")
        return
    if result == 0:
        logger.debug(f"Deleting memory {memory_id} failed: memory not found")
    else:
        logger.info(f"Deleted memory {memory_id}")

def summarize_memories(combined_content: str, memory_type: MemoryType) -> str:
    """Use the LLM to create a concise summary of similar memories"""
    try:
        system_prompt = f"""
        You are a memory consolidation assistant. Your task is to create a single, 
        concise memory from these similar memory fragments. The new memory should
        be a {memory_type.value} memory.
        
        Combine the information without repetition while preserving all important details.
        """

        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(
                content=f"Consolidate these similar memories into one:\n\n{combined_content}"
            ),
        ]

        response = summarizer.invoke(messages)
        return str(response.content)
    except Exception as e:
        logger.error(f"Error summarizing memories: {e}")
        # Fall back to just using the combined content
        return combined_content


def memory_consolidation_worker(user_id: str):
    """
    Worker that periodically consolidates memories for the active user.

    NOTE: In production, this would probably use a background task framework, such
          as rq or Celery, and run on a schedule.
    """
    while True:
        try:
            consolidate_memories(user_id)
            # Run every 10 minutes
            time.sleep(10 * 60)
        except Exception as e:
            logger.exception(f"Error in memory consolidation worker: {e}")
            # If there's an error, wait an hour and try again
            time.sleep(60 * 60)

The main loop

Now we can put everything together and run the main loop.

Running this cell should ask you for your OpenAI and Tavily keys, then a user ID and thread ID. You'll enter a loop in which you can type queries and see the agent's responses printed below the following cell.

import threading


def main(thread_id: str = "book_flight", user_id: str = "demo_user"):
    """Main interaction loop for the travel agent"""
    print("Welcome to the Travel Assistant! (Type 'exit' to quit)")

    config = RunnableConfig(configurable={"thread_id": thread_id, "user_id": user_id})
    state = RuntimeState(messages=[])

    # If we're using the manual memory strategy, we need to create a queue for
    # memory processing and start a worker thread. After every 'round' of a
    # conversation, the main loop will add the current state and config to the
    # queue for memory processing.
    if memory_strategy == MemoryStrategy.MANUAL:
        # Create a queue for memory processing
        memory_queue = Queue()

        # Start a worker thread that will process memory extraction tasks
        memory_thread = threading.Thread(
            target=memory_worker, args=(memory_queue, user_id), daemon=True
        )
        memory_thread.start()

    # We always run consolidation in the background, regardless of memory strategy.
    consolidation_thread = threading.Thread(
        target=memory_consolidation_worker, args=(user_id,), daemon=True
    )
    consolidation_thread.start()

    while True:
        user_input = input("\nYou (type 'quit' to quit): ")

        if not user_input:
            continue

        if user_input.lower() in ["exit", "quit"]:
            print("Thank you for using the Travel Assistant. Goodbye!")
            break

        state["messages"].append(HumanMessage(content=user_input))

        try:
            # Process user input through the graph
            for result in graph.stream(state, config=config, stream_mode="values"):
                state = RuntimeState(**result)

            logger.debug(f"# of messages after run: {len(state['messages'])}")

            # Find the most recent AI message, so we can print the response
            ai_messages = [m for m in state["messages"] if isinstance(m, AIMessage)]
            if ai_messages:
                message = ai_messages[-1].content
            else:
                logger.error("No AI messages after run")
                message = "I'm sorry, I couldn't process your request properly."
                # Add the error message to the state
                state["messages"].append(AIMessage(content=message))

            print(f"\nAssistant: {message}")

            # Add the current state to the memory processing queue
            if memory_strategy == MemoryStrategy.MANUAL:
                memory_queue.put((state.copy(), config))

        except Exception as e:
            logger.exception(f"Error processing request: {e}")
            error_message = "I'm sorry, I encountered an error processing your request."
            print(f"\nAssistant: {error_message}")
            # Add the error message to the state
            state["messages"].append(AIMessage(content=error_message))

try:
    user_id = input("Enter a user ID: ") or "demo_user"
    thread_id = input("Enter a thread ID: ") or "demo_thread"
except Exception:
    # If we're running in CI, we don't have a terminal to input from, so just exit
    exit()
else:
    main(thread_id, user_id)

Wrapping up: let's start building

Want to build your own agent? Try the LangGraph quickstart. Then add our Redis checkpointer to give your agent fast, durable memory.

With Redis managing your AI agent's memory, you can build a flexible, scalable system that stores and retrieves memories quickly. Check out the resources below to start building with Redis today, or contact our team to talk about AI agents.

  • Redis AI resources: A GitHub repo with code examples and notebooks to help you build AI applications.
  • Redis AI docs: Quickstarts and tutorials to get you up and running fast.

Redis Cloud: The easiest way to deploy Redis. Try it free on AWS, Azure, or GCP.