class ReactAgent[StateT: AgentState]:
def __init__(
self,
state: StateT | None = None,
context_manager: BaseContextManager[StateT] | None = None,
publisher: BasePublisher | None = None,
id_generator: BaseIDGenerator = DefaultIDGenerator(),
container: InjectQ | None = None,
):
self._graph = StateGraph[StateT](
state=state,
context_manager=context_manager,
publisher=publisher,
id_generator=id_generator,
container=container,
)
def compile(
self,
main_node: tuple[Callable, str] | Callable,
tool_node: tuple[Callable, str] | Callable,
checkpointer: BaseCheckpointer[StateT] | None = None,
store: BaseStore | None = None,
interrupt_before: list[str] | None = None,
interrupt_after: list[str] | None = None,
callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
# Determine main node function and name
if isinstance(main_node, tuple):
main_func, main_name = main_node
if not callable(main_func):
raise ValueError("main_node[0] must be a callable function")
else:
main_func = main_node
main_name = "MAIN"
if not callable(main_func):
raise ValueError("main_node must be a callable function")
# Determine tool node function and name
if isinstance(tool_node, tuple):
tool_func, tool_name = tool_node
# Accept both callable functions and ToolNode instances
if not callable(tool_func) and not hasattr(tool_func, "invoke"):
raise ValueError("tool_node[0] must be a callable function or ToolNode")
else:
tool_func = tool_node
tool_name = "TOOL"
# Accept both callable functions and ToolNode instances
# ToolNode instances have an 'invoke' method but are not callable
if not callable(tool_func) and not hasattr(tool_func, "invoke"):
raise ValueError("tool_node must be a callable function or ToolNode instance")
self._graph.add_node(main_name, main_func)
self._graph.add_node(tool_name, tool_func)
# Now create edges
self._graph.add_conditional_edges(
main_name,
_should_use_tools,
{tool_name: tool_name, END: END},
)
self._graph.add_edge(tool_name, main_name)
self._graph.set_entry_point(main_name)
return self._graph.compile(
checkpointer=checkpointer,
store=store,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
callback_manager=callback_manager,
)