LangGraph API

状态图工作流与循环编排

概述

LangGraph 是 LangChain 生态中的状态图编排库,用于构建有状态、多参与者(multi-actor)且可以包含循环的应用,支持复杂工作流、条件路由以及基于检查点的状态持久化。

graph TD
    A[StateGraph] --> B[add_node]
    A --> C[add_edge]
    A --> D[add_conditional_edges]
    A --> E[set_entry_point]
    A --> F[set_finish_point]

    C --> G[普通边]
    D --> H[条件边]
    E --> I[起始点]
    F --> J[结束点]

    A --> K[compile]
    K --> L[CompiledGraph]
    L --> M[invoke]
    L --> N[stream]
    L --> O[batch]

    style A fill:#e1f5fe
    style L fill:#c8e6c9

StateGraph

StateGraph

状态图核心类:在其上注册节点和边,再通过 compile() 编译为可执行的图。

from langgraph.graph import StateGraph

class StateGraph:
    """状态图类"""

    def __init__(
        self,
        schema: Type[TypedDict] | Dict[str, Any],
        *,
        config_schema: Optional[Type[TypedDict]] = None,
    ):
        """
        初始化状态图

        Args:
            schema: 状态模式(TypedDict 或字典)
            config_schema: 可选的配置模式
        """

核心方法

def add_node(
    self,
    name: str,
    func: Callable,
    *,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Add a node to the graph.

    Args:
        name: Node name. Edges, the entry/finish point and conditional-edge
            targets refer to the node by this name.
        func: The node's action. Usually a function that receives the current
            state and returns a dict of state updates. A Runnable such as a
            ``ToolNode`` or a compiled subgraph can be passed as well.
            Only the returned keys change: a key without a reducer is
            replaced by the new value, while a key declared as
            ``Annotated[..., reducer]`` is combined with the old value.
        metadata: Optional metadata attached to the node.
    """

def add_edge(self, start_key: str, end_key: str) -> None:
    """
    Add an unconditional edge: once ``start_key`` has run, control always
    moves on to ``end_key``.

    Args:
        start_key: Name of the source node.
        end_key: Name of the destination node; ``END`` ends the run.
    """

def add_conditional_edges(
    self,
    source: str,
    path: Callable | Mapping[str, str],
    path_map: Optional[Mapping[str, str]] = None,
) -> None:
    """
    添加条件边

    Args:
        source: 源节点
        path: 路由函数或边映射
        path_map: 可选的边名到节点映射
    """

def set_entry_point(self, entry_point: str) -> None:
    """
    Choose the node where execution starts.

    Args:
        entry_point: Name of the first node to run.
    """

def set_finish_point(self, finish_point: str) -> None:
    """
    Mark a node as a finish point of the graph.

    Args:
        finish_point: Name of the node after which execution ends.
    """

def compile(
    self,
    *,
    checkpointer: Optional[BaseCheckpointSaver] = None,
    interrupt_before: Optional[list[str]] = None,
    interrupt_after: Optional[list[str]] = None,
    debug: bool = False,
) -> CompiledGraph:
    """
    Compile the state graph into an executable graph.

    Args:
        checkpointer: Checkpoint saver used for persistence (e.g.
            ``MemorySaver``). It is required for ``get_state`` and for
            resuming after an interrupt. Calls must then pass a config such
            as ``{"configurable": {"thread_id": "..."}}``.
        interrupt_before: List of node names to pause before. Pausing is only
            useful together with a ``checkpointer``, which stores the paused
            state so the run can be resumed.
        interrupt_after: List of node names to pause after (same
            requirement as ``interrupt_before``).
        debug: Whether to enable debug mode.

    Returns:
        The compiled graph, which exposes ``invoke``, ``stream`` and ``batch``.
    """

使用示例

python
from langgraph.graph import StateGraph, END
from typing import TypedDict


# State shared by all nodes of the graph.
class GraphState(TypedDict):
    messages: list[str]
    current_step: str


# Each node returns a partial update that only touches "current_step".
def node_a(state: GraphState) -> dict:
    return {"current_step": "A completed"}


def node_b(state: GraphState) -> dict:
    return {"current_step": "B completed"}


def node_c(state: GraphState) -> dict:
    return {"current_step": "C completed"}


# Build a linear pipeline a -> b -> c -> END.
graph = StateGraph(GraphState)
for node_name, node_fn in (("a", node_a), ("b", node_b), ("c", node_c)):
    graph.add_node(node_name, node_fn)

graph.set_entry_point("a")

pipeline = ["a", "b", "c", END]
for src, dst in zip(pipeline, pipeline[1:]):
    graph.add_edge(src, dst)

# Turn the builder into a runnable app.
app = graph.compile()

节点

节点函数

节点函数定义每个节点的行为。

from typing import Annotated
from langgraph.graph import StateGraph

# ========== Basic node ==========
def simple_node(state: GraphState) -> dict:
    """
    Minimal node function.

    Args:
        state: The current graph state.

    Returns:
        A partial state update. Only the returned keys change. A key without
        a reducer, such as ``messages`` in ``GraphState``, is replaced by the
        new value rather than appended to. Declare the field as
        ``Annotated[..., reducer]`` to accumulate values instead.
    """
    return {"messages": ["new message"]}

# ========== Updating with a reducer ==========
# TypedDict and Annotated are imported here so this snippet runs on its own.
from typing import Annotated, Sequence, TypedDict
from operator import add


class StateWithReduce(TypedDict):
    # Annotated attaches a reducer that decides how a node's update is
    # combined with the current value; operator.add concatenates the lists.
    messages: Annotated[Sequence[str], add]


def reduce_node(state: StateWithReduce) -> dict:
    # Because of the reducer, this list is appended to the existing
    # messages instead of replacing them.
    return {"messages": ["appended message"]}

# ========== Node that calls an LLM ==========
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o")


def llm_node(state: GraphState) -> dict:
    """Send the current messages to the LLM and return its reply as an update."""
    reply = llm.invoke(state["messages"])
    # GraphState.messages has no reducer, so this update replaces the list.
    return {"messages": [reply.content]}

# ========== Node that calls a tool ==========
from typing import TypedDict

from langchain_core.tools import tool


@tool
def search(query: str) -> str:
    """搜索工具"""
    return f"搜索结果: {query}"


class ToolState(TypedDict):
    # Keys read and written by tools_node. GraphState has neither of them,
    # so reading state["query"] from a GraphState would raise KeyError.
    query: str
    output: str


def tools_node(state: ToolState) -> dict:
    """Run the search tool on ``state["query"]`` and store the result in ``output``."""
    result = search.invoke(state["query"])
    return {"output": result}

ToolNode

预构建的工具节点:从状态中的消息里读取 tool_calls,执行对应的工具,并以工具消息的形式返回结果。

from langgraph.prebuilt import ToolNode

class ToolNode:
    """Prebuilt node that executes tool calls.

    It runs the requested tools and hands their results back as a state
    update.
    """

    def __init__(self, tools: Sequence[BaseTool]):
        """
        Create the node.

        Args:
            tools: The tools this node may execute.
        """

    def __call__(self, state: GraphState, config: RunnableConfig) -> dict:
        """
        Execute the tool calls found in ``state``.

        Args:
            state: State that carries the ``tool_calls`` to run.
            config: Runnable configuration for this invocation.

        Returns:
            A state update holding the resulting tool messages.
        """

使用示例

python
from langgraph.prebuilt import ToolNode
from langchain_core.tools import tool


# @tool uses the function's docstring as the tool description that the model
# sees. Without a docstring (or an explicit description) langchain_core
# raises ValueError, so every tool below needs one.
@tool
def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"{city} 晴天,25°C"


@tool
def get_time() -> str:
    """Get the current local time as HH:MM:SS."""
    from datetime import datetime
    return datetime.now().strftime("%H:%M:%S")


# Create the tool node
tools = [get_weather, get_time]
tool_node = ToolNode(tools)

# Add it to the graph. ToolNode reads tool calls from the state's messages,
# so the state it runs on must carry chat messages that contain tool_calls.
graph.add_node("tools", tool_node)

普通边

无条件转移的边。

from langgraph.graph import StateGraph, END

graph = StateGraph(GraphState)

# node_a -> node_b: a direct connection, then node_b -> END finishes the run.
graph.add_edge("node_a", "node_b")
graph.add_edge("node_b", END)

# A longer chain: start -> process -> validate -> output -> END
chain = ["start", "process", "validate", "output", END]
for upstream, downstream in zip(chain, chain[1:]):
    graph.add_edge(upstream, downstream)

条件边

根据当前状态决定下一步走向的边:路由函数接收状态,返回目标节点名或 path_map 中的键。注意 path 必须是可调用对象,不能直接传入字典。

def route_function(
    state: GraphState,
) -> str:
    """
    Routing function for a conditional edge.

    Args:
        state: The current state.

    Returns:
        A routing key, "continue" or "end". It is a key of the ``path_map``
        passed to ``add_conditional_edges`` below, not a node name.
    """
    # NOTE(review): "should_continue" is not a GraphState field. It stands in
    # for whatever flag your own state schema defines.
    if state.get("should_continue"):
        return "continue"
    else:
        return "end"

# Route with a function: its return value is looked up in the mapping.
graph.add_conditional_edges(
    "decision_node",
    route_function,
    {
        "continue": "process_node",
        "end": END,
    },
)

# ========== Routing on a state field ==========
# Alternative to the function-based routing above. `path` must be callable:
# a bare dict is not accepted as a router. To map a field's value straight
# to a node, read the field in a tiny function and put the value -> node
# mapping in `path_map`. This assumes the state has a "decision" field.
graph.add_conditional_edges(
    "decision_node",
    lambda state: state["decision"],
    {
        "approve": "approved_node",
        "reject": "rejected_node",
        "review": "review_node",
    },
)

使用示例

python
# ========== Example 1: a single-node flow ==========
from langgraph.graph import StateGraph, END
from typing import TypedDict


class ProcessState(TypedDict):
    input: str
    processed: bool
    output: str


def process_node(state: ProcessState) -> dict:
    # Flag the input as processed and store its upper-cased form.
    text = state["input"]
    return {"processed": True, "output": text.upper()}


graph = StateGraph(ProcessState)
graph.add_node("process", process_node)
graph.set_entry_point("process")
graph.add_edge("process", END)

app = graph.compile()
result = app.invoke({"input": "hello"})
# -> {"input": "hello", "processed": True, "output": "HELLO"}

# ========== Example 2: conditional routing ==========
class RoutingState(TypedDict):
    score: float
    result: str


def evaluate(state: RoutingState) -> dict:
    # Toy score: the length of the current result string. evaluate is the
    # entry node, so the graph input must already contain "result".
    return {"score": len(state["result"])}


def route_logic(state: RoutingState) -> str:
    score = state["score"]
    if score > 10:
        return "high"
    if score > 5:
        return "medium"
    return "low"


def high_handler(state: RoutingState) -> dict:
    return {"result": "高分处理"}


def medium_handler(state: RoutingState) -> dict:
    return {"result": "中等分处理"}


def low_handler(state: RoutingState) -> dict:
    return {"result": "低分处理"}


# Every label returned by route_logic is also the name of its handler node.
handlers = {"high": high_handler, "medium": medium_handler, "low": low_handler}

graph = StateGraph(RoutingState)
graph.add_node("evaluate", evaluate)
for tier, handler in handlers.items():
    graph.add_node(tier, handler)

graph.set_entry_point("evaluate")
graph.add_conditional_edges(
    "evaluate",
    route_logic,
    {tier: tier for tier in handlers}
)
for tier in handlers:
    graph.add_edge(tier, END)

# ========== Example 3: a loop ==========
class LoopState(TypedDict):
    count: int
    max: int
    result: str


def increment(state: LoopState) -> dict:
    new_count = state["count"] + 1
    return {"count": new_count}


def should_continue(state: LoopState) -> str:
    # Route back into "increment" until count reaches max.
    if state["count"] < state["max"]:
        return "continue"
    return "end"


graph = StateGraph(LoopState)
graph.add_node("increment", increment)
graph.set_entry_point("increment")

graph.add_conditional_edges(
    "increment",
    should_continue,
    {"continue": "increment", "end": END}
)

# Runs "increment" 5 times. Each pass counts toward the run's recursion
# limit (25 steps by default), so much longer loops need a larger
# "recursion_limit" in the config.
app = graph.compile()
result = app.invoke({"count": 0, "max": 5})
# -> {"count": 5, "max": 5}
# "result" is absent: no node ever writes it, and unset keys are not returned.

# ========== Example 4: agent loop ==========
from typing import Annotated

from langchain_openai import ChatOpenAI
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langchain_core.tools import tool


# @tool needs a docstring: it becomes the description the model sees, and
# langchain_core raises ValueError without one.
@tool
def search(query: str) -> str:
    """Search for information about the query."""
    return "搜索结果"


class AgentState(TypedDict):
    # add_messages appends each node's returned messages to the history
    # instead of replacing it. Without a reducer, the ToolNode's output
    # would overwrite the AI message that requested the tool call. The
    # model would then receive a tool result with no preceding call.
    messages: Annotated[list, add_messages]


llm = ChatOpenAI(model="gpt-4o").bind_tools([search])


def agent_node(state: AgentState) -> dict:
    response = llm.invoke(state["messages"])
    return {"messages": [response]}


def should_continue(state: AgentState) -> str:
    # Keep going to the tools node while the model requests tool calls.
    last_message = state["messages"][-1]
    if hasattr(last_message, "tool_calls") and last_message.tool_calls:
        return "tools"
    return END


graph = StateGraph(AgentState)
graph.add_node("agent", agent_node)
graph.add_node("tools", ToolNode([search]))

graph.set_entry_point("agent")
# should_continue returns node names directly; list its possible targets.
graph.add_conditional_edges("agent", should_continue, ["tools", END])
graph.add_edge("tools", "agent")

# ========== Example 5: graph with a checkpointer ==========
from langgraph.checkpoint.memory import MemorySaver


class State(TypedDict):
    step: int
    data: str


def step_node(state: State) -> dict:
    return {"step": state["step"] + 1}


graph = StateGraph(State)
# A node may not share its name with a state key: calling this node "step"
# (the same as State's "step" key) makes add_node raise ValueError.
graph.add_node("advance", step_node)
graph.set_entry_point("advance")
graph.add_edge("advance", END)

# Add a checkpointer so state is persisted per thread
checkpointer = MemorySaver()
app = graph.compile(checkpointer=checkpointer)

# Every call that belongs to the same session uses the same thread ID
config = {"configurable": {"thread_id": "session-1"}}
result = app.invoke({"step": 0, "data": "test"}, config)

# Read back the saved state for that thread
state = app.get_state(config)
print(state.values)  # {"step": 1, "data": "test"}

# ========== Example 6: subgraph (nested graph) ==========
# A compiled graph can be added as a node of another graph. Parent and child
# exchange data through the state keys they share (here: "data").
class SubState(TypedDict):
    data: str


def sub_node_a(state: SubState) -> dict:
    return {"data": state["data"] + " -> sub_a"}


def sub_node_b(state: SubState) -> dict:
    return {"data": state["data"] + " -> sub_b"}


def subgraph_builder():
    """Build and compile the child graph sub_a -> sub_b -> END."""
    sub_graph = StateGraph(SubState)
    sub_graph.add_node("sub_a", sub_node_a)
    sub_graph.add_node("sub_b", sub_node_b)
    sub_graph.set_entry_point("sub_a")
    sub_graph.add_edge("sub_a", "sub_b")
    sub_graph.add_edge("sub_b", END)
    return sub_graph.compile()


class ParentState(TypedDict):
    data: str


def start_node(state: ParentState) -> dict:
    return {"data": state["data"] + " -> start"}


# Add the subgraph to a fresh parent graph: start -> subgraph -> END
graph = StateGraph(ParentState)
graph.add_node("start", start_node)
graph.add_node("subgraph", subgraph_builder())
graph.set_entry_point("start")
graph.add_edge("start", "subgraph")
graph.add_edge("subgraph", END)

app = graph.compile()
result = app.invoke({"data": "input"})
# -> {"data": "input -> start -> sub_a -> sub_b"}

# ========== Example 7: human review (interrupt) ==========
from langgraph.checkpoint.memory import MemorySaver


class HumanState(TypedDict):
    input: str
    approved: bool
    result: str  # written by final_output; it must be declared in the schema


def prepare(state: HumanState) -> dict:
    return {"input": state["input"]}


def human_review(state: HumanState) -> dict:
    # Runs only after the graph is resumed. The reviewer's decision has
    # already been written into the state via update_state (see below), so
    # this node does not change anything.
    print(f"审核: {state['input']}")
    return {}


def final_output(state: HumanState) -> dict:
    return {"result": "已批准" if state["approved"] else "已拒绝"}


graph = StateGraph(HumanState)
graph.add_node("prepare", prepare)
graph.add_node("human_review", human_review)
graph.add_node("output", final_output)

graph.set_entry_point("prepare")
graph.add_edge("prepare", "human_review")
graph.add_edge("human_review", "output")
graph.add_edge("output", END)

# Pause before "human_review". An interrupt needs a checkpointer to keep the
# paused state, and every call must use the same thread_id to resume it.
app = graph.compile(checkpointer=MemorySaver(), interrupt_before=["human_review"])
config = {"configurable": {"thread_id": "review-1"}}

# First call: runs "prepare", then stops before "human_review".
app.invoke({"input": "待审核内容", "approved": False}, config)

# Collect the decision outside the graph and write it into the saved state...
choice = input("批准? (y/n): ")
app.update_state(config, {"approved": choice.lower() == "y"})

# ...then resume from the interrupt by invoking with None.
result = app.invoke(None, config)

相关 API