
PyAgenity: A lightweight Python framework for building intelligent agents and multi-agent workflows.

Modules:

adapters: Integration adapters for optional third-party SDKs.
checkpointer: Checkpointer adapters for agent state persistence in PyAgenity.
exceptions: Custom exception classes for graph operations in PyAgenity.
graph: PyAgenity Graph Module - Core Workflow Engine.
prebuilt
publisher: Publisher module for PyAgenity events.
state: State management for PyAgenity agent graphs.
store
utils: Unified utility exports for PyAgenity agent graphs.

Modules

adapters

Integration adapters for optional third-party SDKs.

This package provides unified wrappers and converters for integrating external tool registries, LLM SDKs, and other third-party services with PyAgenity agent graphs. Adapters expose registry-based discovery, function-calling schemas, and normalized execution for supported providers.

Modules:

llm: Integration adapters for optional third-party LLM SDKs.
tools: Integration adapters for optional third-party SDKs.

Modules

llm

Integration adapters for optional third-party LLM SDKs.

This module exposes universal converter APIs to normalize responses and streaming outputs from popular LLM SDKs (e.g., LiteLLM, OpenAI) for use in PyAgenity agent graphs. Adapters provide unified conversion, streaming, and message normalization for agent workflows.

Exports

BaseConverter: Abstract base class for LLM response converters.
ConverterType: Enum of supported converter types.
LiteLLMConverter: Converter for LiteLLM responses and streams.
OpenAIConverter: commented out in the source; available once implemented.

Modules:

base_converter
litellm_converter
model_response_converter

Classes:

BaseConverter: Abstract base class for all LLM response converters.
ConverterType: Enumeration of supported converter types for LLM responses.
LiteLLMConverter: Converter for LiteLLM responses to PyAgenity Message format.

Attributes
__all__ module-attribute
__all__ = ['BaseConverter', 'ConverterType', 'LiteLLMConverter']
Classes
BaseConverter

Bases: ABC

Abstract base class for all LLM response converters.

Subclasses should implement methods to convert standard and streaming LLM responses into PyAgenity's internal message/event formats.

Attributes:

state (AgentState | None): Optional agent state for context during conversion.

Methods:

__init__: Initialize the converter.
convert_response: Convert a standard agent response to a Message.
convert_streaming_response: Convert a streaming agent response to an async generator of EventModel or Message.

Source code in pyagenity/adapters/llm/base_converter.py
class BaseConverter(ABC):
    """
    Abstract base class for all LLM response converters.

    Subclasses should implement methods to convert standard and streaming
    LLM responses into PyAgenity's internal message/event formats.

    Attributes:
        state (AgentState | None): Optional agent state for context during conversion.
    """

    def __init__(self, state: AgentState | None = None) -> None:
        """
        Initialize the converter.

        Args:
            state (AgentState | None): Optional agent state for context during conversion.
        """
        self.state = state

    @abstractmethod
    async def convert_response(self, response: Any) -> Message:
        """
        Convert a standard agent response to a Message.

        Args:
            response (Any): The raw response from the LLM or agent.

        Returns:
            Message: The converted message object.

        Raises:
            NotImplementedError: If not implemented in subclass.
        """
        raise NotImplementedError("Conversion not implemented for this converter")

    @abstractmethod
    async def convert_streaming_response(
        self,
        config: dict,
        node_name: str,
        response: Any,
        meta: dict | None = None,
    ) -> AsyncGenerator[EventModel | Message, None]:
        """
        Convert a streaming agent response to an async generator of EventModel or Message.

        Args:
            config (dict): Node configuration parameters.
            node_name (str): Name of the node processing the response.
            response (Any): The raw streaming response from the LLM or agent.
            meta (dict | None): Optional metadata for conversion.

        Yields:
            EventModel | Message: Chunks of the converted streaming response.

        Raises:
            NotImplementedError: If not implemented in subclass.
        """
        raise NotImplementedError("Streaming not implemented for this converter")
Attributes
state instance-attribute
state = state
Functions
__init__
__init__(state=None)

Initialize the converter.

Parameters:

state (AgentState | None, default None): Optional agent state for context during conversion.
Source code in pyagenity/adapters/llm/base_converter.py
def __init__(self, state: AgentState | None = None) -> None:
    """
    Initialize the converter.

    Args:
        state (AgentState | None): Optional agent state for context during conversion.
    """
    self.state = state
convert_response abstractmethod async
convert_response(response)

Convert a standard agent response to a Message.

Parameters:

response (Any, required): The raw response from the LLM or agent.

Returns:

Message: The converted message object.

Raises:

NotImplementedError: If not implemented in subclass.

Source code in pyagenity/adapters/llm/base_converter.py
@abstractmethod
async def convert_response(self, response: Any) -> Message:
    """
    Convert a standard agent response to a Message.

    Args:
        response (Any): The raw response from the LLM or agent.

    Returns:
        Message: The converted message object.

    Raises:
        NotImplementedError: If not implemented in subclass.
    """
    raise NotImplementedError("Conversion not implemented for this converter")
convert_streaming_response abstractmethod async
convert_streaming_response(config, node_name, response, meta=None)

Convert a streaming agent response to an async generator of EventModel or Message.

Parameters:

config (dict, required): Node configuration parameters.
node_name (str, required): Name of the node processing the response.
response (Any, required): The raw streaming response from the LLM or agent.
meta (dict | None, default None): Optional metadata for conversion.

Yields:

EventModel | Message: Chunks of the converted streaming response.

Raises:

NotImplementedError: If not implemented in subclass.

Source code in pyagenity/adapters/llm/base_converter.py
@abstractmethod
async def convert_streaming_response(
    self,
    config: dict,
    node_name: str,
    response: Any,
    meta: dict | None = None,
) -> AsyncGenerator[EventModel | Message, None]:
    """
    Convert a streaming agent response to an async generator of EventModel or Message.

    Args:
        config (dict): Node configuration parameters.
        node_name (str): Name of the node processing the response.
        response (Any): The raw streaming response from the LLM or agent.
        meta (dict | None): Optional metadata for conversion.

    Yields:
        EventModel | Message: Chunks of the converted streaming response.

    Raises:
        NotImplementedError: If not implemented in subclass.
    """
    raise NotImplementedError("Streaming not implemented for this converter")
ConverterType

Bases: Enum

Enumeration of supported converter types for LLM responses.

Attributes: ANTHROPIC, CUSTOM, GOOGLE, LITELLM, OPENAI
Source code in pyagenity/adapters/llm/base_converter.py
class ConverterType(Enum):
    """Enumeration of supported converter types for LLM responses."""

    OPENAI = "openai"
    LITELLM = "litellm"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    CUSTOM = "custom"
Attributes
ANTHROPIC class-attribute instance-attribute
ANTHROPIC = 'anthropic'
CUSTOM class-attribute instance-attribute
CUSTOM = 'custom'
GOOGLE class-attribute instance-attribute
GOOGLE = 'google'
LITELLM class-attribute instance-attribute
LITELLM = 'litellm'
OPENAI class-attribute instance-attribute
OPENAI = 'openai'
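Because the enum members carry provider-name string values, a converter type can be selected from a plain config string via value lookup. The sketch redeclares the enum exactly as shown above so it runs stand-alone:

```python
from enum import Enum


class ConverterType(Enum):
    """Enumeration of supported converter types for LLM responses."""

    OPENAI = "openai"
    LITELLM = "litellm"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    CUSTOM = "custom"


# Look up a member from a config-supplied provider string.
selected = ConverterType("litellm")
```

An unknown string (e.g. `ConverterType("mistral")`) raises ValueError, which makes typos in configuration fail fast.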
LiteLLMConverter

Bases: BaseConverter

Converter for LiteLLM responses to PyAgenity Message format.

Handles both standard and streaming responses, extracting content, reasoning, tool calls, and token usage details.

Methods:

__init__: Initialize the converter.
convert_response: Convert a LiteLLM ModelResponse to a Message.
convert_streaming_response: Convert a LiteLLM streaming or standard response to Message(s).

Attributes:

state
Source code in pyagenity/adapters/llm/litellm_converter.py
class LiteLLMConverter(BaseConverter):
    """
    Converter for LiteLLM responses to PyAgenity Message format.

    Handles both standard and streaming responses, extracting content, reasoning,
    tool calls, and token usage details.
    """

    async def convert_response(self, response: ModelResponse) -> Message:
        """
        Convert a LiteLLM ModelResponse to a Message.

        Args:
            response (ModelResponse): The LiteLLM model response object.

        Returns:
            Message: The converted message object.

        Raises:
            ImportError: If LiteLLM is not installed.
        """
        if not HAS_LITELLM:
            raise ImportError("litellm is not installed. Please install it to use this converter.")

        data = response.model_dump()

        usages_data = data.get("usage", {})

        usages = TokenUsages(
            completion_tokens=usages_data.get("completion_tokens", 0),
            prompt_tokens=usages_data.get("prompt_tokens", 0),
            total_tokens=usages_data.get("total_tokens", 0),
            cache_creation_input_tokens=usages_data.get("cache_creation_input_tokens", 0),
            cache_read_input_tokens=usages_data.get("cache_read_input_tokens", 0),
            reasoning_tokens=usages_data.get("prompt_tokens_details", {}).get(
                "reasoning_tokens", 0
            ),
        )

        created_date = data.get("created", datetime.now())

        # Extract tool calls from response
        tools_calls = data.get("choices", [{}])[0].get("message", {}).get("tool_calls", []) or []

        logger.debug("Creating message from model response with id: %s", response.id)
        content = data.get("choices", [{}])[0].get("message", {}).get("content", "") or ""
        reasoning_content = (
            data.get("choices", [{}])[0].get("message", {}).get("reasoning_content", "") or ""
        )

        blocks = []
        if content:
            blocks.append(TextBlock(text=content))
        if reasoning_content:
            blocks.append(ReasoningBlock(summary=reasoning_content))
        final_tool_calls = []
        for tool_call in tools_calls:
            tool_id = tool_call.get("id", None)
            args = tool_call.get("function", {}).get("arguments", None)
            name = tool_call.get("function", {}).get("name", None)

            if not tool_id or not args or not name:
                continue

            blocks.append(
                ToolCallBlock(
                    name=name,
                    args=json.loads(args),
                    id=tool_id,
                )
            )

            if hasattr(tool_call, "model_dump"):
                final_tool_calls.append(tool_call.model_dump())
            else:
                final_tool_calls.append(tool_call)

        return Message(
            message_id=generate_id(response.id),
            role="assistant",
            content=blocks,
            reasoning=reasoning_content,
            timestamp=created_date,
            metadata={
                "provider": "litellm",
                "model": data.get("model", ""),
                "finish_reason": data.get("choices", [{}])[0].get("finish_reason", "UNKNOWN"),
                "object": data.get("object", ""),
                "prompt_tokens_details": usages_data.get("prompt_tokens_details", {}),
                "completion_tokens_details": usages_data.get("completion_tokens_details", {}),
            },
            usages=usages,
            raw=data,
            tools_calls=final_tool_calls if final_tool_calls else None,
        )

    def _process_chunk(
        self,
        chunk: ModelResponseStream | None,
        seq: int,
        accumulated_content: str,
        accumulated_reasoning_content: str,
        tool_calls: list,
        tool_ids: set,
    ) -> tuple[str, str, list, int, Message | None]:
        """
        Process a single chunk from a LiteLLM streaming response.

        Args:
            chunk (ModelResponseStream | None): The current chunk from the stream.
            seq (int): Sequence number of the chunk.
            accumulated_content (str): Accumulated text content so far.
            accumulated_reasoning_content (str): Accumulated reasoning content so far.
            tool_calls (list): List of tool calls detected so far.
            tool_ids (set): Set of tool call IDs to avoid duplicates.

        Returns:
            tuple: Updated accumulated content, reasoning, tool calls, sequence,
                and Message (if any).
        """
        if not chunk:
            return accumulated_content, accumulated_reasoning_content, tool_calls, seq, None

        msg: ModelResponseStream = chunk  # type: ignore
        if msg is None:
            return accumulated_content, accumulated_reasoning_content, tool_calls, seq, None
        if msg.choices is None or len(msg.choices) == 0:
            return accumulated_content, accumulated_reasoning_content, tool_calls, seq, None
        delta = msg.choices[0].delta
        if delta is None:
            return accumulated_content, accumulated_reasoning_content, tool_calls, seq, None

        # update text delta
        text_part = delta.content or ""
        content_blocks = []
        if text_part:
            content_blocks.append(TextBlock(text=text_part))
        reasoning_part = getattr(delta, "reasoning_content", "") or ""
        if reasoning_part:
            content_blocks.append(ReasoningBlock(summary=reasoning_part))
        accumulated_content += text_part
        accumulated_reasoning_content += reasoning_part
        # handle tool calls if present
        if getattr(delta, "tool_calls", None):
            for tc in delta.tool_calls:  # type: ignore[attr-defined]
                if not tc:
                    continue
                if tc.id in tool_ids:
                    continue
                tool_ids.add(tc.id)
                tool_calls.append(tc.model_dump())
                content_blocks.append(
                    ToolCallBlock(
                        name=tc.function.name,  # type: ignore
                        args=json.loads(tc.function.arguments),  # type: ignore
                        id=tc.id,  # type: ignore
                    )
                )

        output_message = Message(
            message_id=generate_id(msg.id),
            role="assistant",
            content=content_blocks,
            reasoning=accumulated_reasoning_content,
            tools_calls=tool_calls,
            delta=True,
        )

        return accumulated_content, accumulated_reasoning_content, tool_calls, seq, output_message

    async def _handle_stream(
        self,
        config: dict,
        node_name: str,
        stream: CustomStreamWrapper,
        meta: dict | None = None,
    ) -> AsyncGenerator[Message]:
        """
        Handle a LiteLLM streaming response and yield Message objects for each chunk.

        Args:
            config (dict): Node configuration parameters.
            node_name (str): Name of the node processing the response.
            stream (CustomStreamWrapper): The LiteLLM streaming response object.
            meta (dict | None): Optional metadata for conversion.

        Yields:
            Message: Converted message chunk from the stream.
        """
        accumulated_content = ""
        tool_calls = []
        tool_ids = set()
        accumulated_reasoning_content = ""
        seq = 0

        is_awaitable = inspect.isawaitable(stream)

        # Await stream if necessary
        if is_awaitable:
            stream = await stream

        # Try async iteration (acompletion)
        try:
            async for chunk in stream:
                accumulated_content, accumulated_reasoning_content, tool_calls, seq, message = (
                    self._process_chunk(
                        chunk,
                        seq,
                        accumulated_content,
                        accumulated_reasoning_content,
                        tool_calls,
                        tool_ids,
                    )
                )

                if message:
                    yield message
        except Exception:  # noqa: S110 # nosec B110
            pass

        # Try sync iteration (completion)
        try:
            for chunk in stream:
                accumulated_content, accumulated_reasoning_content, tool_calls, seq, message = (
                    self._process_chunk(
                        chunk,
                        seq,
                        accumulated_content,
                        accumulated_reasoning_content,
                        tool_calls,
                        tool_ids,
                    )
                )

                if message:
                    yield message
        except Exception:  # noqa: S110 # nosec B110
            pass

        # After streaming, yield final message
        metadata = meta or {}
        metadata["provider"] = "litellm"
        metadata["node_name"] = node_name
        metadata["thread_id"] = config.get("thread_id")

        blocks = []
        if accumulated_content:
            blocks.append(TextBlock(text=accumulated_content))
        if accumulated_reasoning_content:
            blocks.append(ReasoningBlock(summary=accumulated_reasoning_content))
        if tool_calls:
            for tc in tool_calls:
                blocks.append(
                    ToolCallBlock(
                        name=tc.get("function", {}).get("name", ""),
                        args=json.loads(tc.get("function", {}).get("arguments", "{}")),
                        id=tc.get("id", ""),
                    )
                )

        logger.debug(
            "Loop done Content: %s  Reasoning: %s Tool Calls: %s",
            accumulated_content,
            accumulated_reasoning_content,
            len(tool_calls),
        )
        message = Message(
            role="assistant",
            message_id=generate_id(None),
            content=blocks,
            delta=False,
            reasoning=accumulated_reasoning_content,
            tools_calls=tool_calls,
            metadata=metadata,
        )
        yield message

    async def convert_streaming_response(  # type: ignore
        self,
        config: dict,
        node_name: str,
        response: Any,
        meta: dict | None = None,
    ) -> AsyncGenerator[Message]:
        """
        Convert a LiteLLM streaming or standard response to Message(s).

        Args:
            config (dict): Node configuration parameters.
            node_name (str): Name of the node processing the response.
            response (Any): The LiteLLM response object (stream or standard).
            meta (dict | None): Optional metadata for conversion.

        Yields:
            Message: Converted message(s) from the response.

        Raises:
            ImportError: If LiteLLM is not installed.
            Exception: If response type is unsupported.
        """
        if not HAS_LITELLM:
            raise ImportError("litellm is not installed. Please install it to use this converter.")

        if isinstance(response, CustomStreamWrapper):  # type: ignore[possibly-unbound]
            stream = cast(CustomStreamWrapper, response)
            async for event in self._handle_stream(
                config or {},
                node_name or "",
                stream,
                meta,
            ):
                yield event
        elif isinstance(response, ModelResponse):  # type: ignore[possibly-unbound]
            message = await self.convert_response(cast(ModelResponse, response))
            yield message
        else:
            raise Exception("Unsupported response type for LiteLLMConverter")
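`convert_response` walks the `model_dump()` dict defensively with chained `.get` calls and `or` fallbacks. That extraction can be seen in isolation with a hand-built payload; the dict below is illustrative and only mimics the shape of LiteLLM's `ModelResponse.model_dump()`, it is not real SDK output.

```python
# Illustrative payload mimicking the shape of litellm's ModelResponse.model_dump()
data = {
    "model": "gpt-4o-mini",
    "usage": {
        "completion_tokens": 12,
        "prompt_tokens": 30,
        "total_tokens": 42,
        "prompt_tokens_details": {"reasoning_tokens": 5},
    },
    "choices": [
        {
            "finish_reason": "stop",
            "message": {"content": "Hi there", "tool_calls": None},
        }
    ],
}

usages_data = data.get("usage", {})
reasoning_tokens = usages_data.get("prompt_tokens_details", {}).get("reasoning_tokens", 0)

# The `or []` / `or ""` guards turn explicit None values into safe defaults,
# exactly as the converter does for tool_calls and content.
tool_calls = data.get("choices", [{}])[0].get("message", {}).get("tool_calls", []) or []
content = data.get("choices", [{}])[0].get("message", {}).get("content", "") or ""
```

The `or` guards matter because providers may return `"tool_calls": None` rather than omitting the key, in which case a plain `.get(..., [])` would still yield `None`.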
Attributes
state instance-attribute
state = state
Functions
__init__
__init__(state=None)

Initialize the converter.

Parameters:

state (AgentState | None, default None): Optional agent state for context during conversion.
Source code in pyagenity/adapters/llm/base_converter.py
def __init__(self, state: AgentState | None = None) -> None:
    """
    Initialize the converter.

    Args:
        state (AgentState | None): Optional agent state for context during conversion.
    """
    self.state = state
convert_response async
convert_response(response)

Convert a LiteLLM ModelResponse to a Message.

Parameters:

response (ModelResponse, required): The LiteLLM model response object.

Returns:

Message: The converted message object.

Raises:

ImportError: If LiteLLM is not installed.

Source code in pyagenity/adapters/llm/litellm_converter.py
async def convert_response(self, response: ModelResponse) -> Message:
    """
    Convert a LiteLLM ModelResponse to a Message.

    Args:
        response (ModelResponse): The LiteLLM model response object.

    Returns:
        Message: The converted message object.

    Raises:
        ImportError: If LiteLLM is not installed.
    """
    if not HAS_LITELLM:
        raise ImportError("litellm is not installed. Please install it to use this converter.")

    data = response.model_dump()

    usages_data = data.get("usage", {})

    usages = TokenUsages(
        completion_tokens=usages_data.get("completion_tokens", 0),
        prompt_tokens=usages_data.get("prompt_tokens", 0),
        total_tokens=usages_data.get("total_tokens", 0),
        cache_creation_input_tokens=usages_data.get("cache_creation_input_tokens", 0),
        cache_read_input_tokens=usages_data.get("cache_read_input_tokens", 0),
        reasoning_tokens=usages_data.get("prompt_tokens_details", {}).get(
            "reasoning_tokens", 0
        ),
    )

    created_date = data.get("created", datetime.now())

    # Extract tool calls from response
    tools_calls = data.get("choices", [{}])[0].get("message", {}).get("tool_calls", []) or []

    logger.debug("Creating message from model response with id: %s", response.id)
    content = data.get("choices", [{}])[0].get("message", {}).get("content", "") or ""
    reasoning_content = (
        data.get("choices", [{}])[0].get("message", {}).get("reasoning_content", "") or ""
    )

    blocks = []
    if content:
        blocks.append(TextBlock(text=content))
    if reasoning_content:
        blocks.append(ReasoningBlock(summary=reasoning_content))
    final_tool_calls = []
    for tool_call in tools_calls:
        tool_id = tool_call.get("id", None)
        args = tool_call.get("function", {}).get("arguments", None)
        name = tool_call.get("function", {}).get("name", None)

        if not tool_id or not args or not name:
            continue

        blocks.append(
            ToolCallBlock(
                name=name,
                args=json.loads(args),
                id=tool_id,
            )
        )

        if hasattr(tool_call, "model_dump"):
            final_tool_calls.append(tool_call.model_dump())
        else:
            final_tool_calls.append(tool_call)

    return Message(
        message_id=generate_id(response.id),
        role="assistant",
        content=blocks,
        reasoning=reasoning_content,
        timestamp=created_date,
        metadata={
            "provider": "litellm",
            "model": data.get("model", ""),
            "finish_reason": data.get("choices", [{}])[0].get("finish_reason", "UNKNOWN"),
            "object": data.get("object", ""),
            "prompt_tokens_details": usages_data.get("prompt_tokens_details", {}),
            "completion_tokens_details": usages_data.get("completion_tokens_details", {}),
        },
        usages=usages,
        raw=data,
        tools_calls=final_tool_calls if final_tool_calls else None,
    )
convert_streaming_response async
convert_streaming_response(config, node_name, response, meta=None)

Convert a LiteLLM streaming or standard response to Message(s).

Parameters:

config (dict, required): Node configuration parameters.
node_name (str, required): Name of the node processing the response.
response (Any, required): The LiteLLM response object (stream or standard).
meta (dict | None, default None): Optional metadata for conversion.

Yields:

Message: Converted message(s) from the response.

Raises:

ImportError: If LiteLLM is not installed.
Exception: If response type is unsupported.

Source code in pyagenity/adapters/llm/litellm_converter.py
async def convert_streaming_response(  # type: ignore
    self,
    config: dict,
    node_name: str,
    response: Any,
    meta: dict | None = None,
) -> AsyncGenerator[Message]:
    """
    Convert a LiteLLM streaming or standard response to Message(s).

    Args:
        config (dict): Node configuration parameters.
        node_name (str): Name of the node processing the response.
        response (Any): The LiteLLM response object (stream or standard).
        meta (dict | None): Optional metadata for conversion.

    Yields:
        Message: Converted message(s) from the response.

    Raises:
        ImportError: If LiteLLM is not installed.
        Exception: If response type is unsupported.
    """
    if not HAS_LITELLM:
        raise ImportError("litellm is not installed. Please install it to use this converter.")

    if isinstance(response, CustomStreamWrapper):  # type: ignore[possibly-unbound]
        stream = cast(CustomStreamWrapper, response)
        async for event in self._handle_stream(
            config or {},
            node_name or "",
            stream,
            meta,
        ):
            yield event
    elif isinstance(response, ModelResponse):  # type: ignore[possibly-unbound]
        message = await self.convert_response(cast(ModelResponse, response))
        yield message
    else:
        raise Exception("Unsupported response type for LiteLLMConverter")
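Callers of `convert_streaming_response` receive a sequence of delta Messages (one per chunk) followed by a single final non-delta Message carrying the accumulated content, mirroring `_handle_stream` above. The consumption pattern can be sketched with a fake stream; all names here are stdlib stand-ins for illustration, not PyAgenity imports.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Message:
    """Stand-in for pyagenity's Message; only content and delta are modeled."""

    content: str
    delta: bool


async def fake_stream():
    """Mimics _handle_stream: per-chunk deltas, then one final aggregate."""
    accumulated = ""
    for part in ["Hel", "lo", "!"]:
        accumulated += part
        yield Message(content=part, delta=True)
    yield Message(content=accumulated, delta=False)


async def consume() -> tuple[list[str], str]:
    deltas: list[str] = []
    final = ""
    async for msg in fake_stream():
        if msg.delta:
            deltas.append(msg.content)  # e.g. render incrementally in a UI
        else:
            final = msg.content  # persist the complete message
    return deltas, final


deltas, final = asyncio.run(consume())
```

Checking the `delta` flag rather than stream position is the safer pattern, since the number of delta chunks before the final Message is provider-dependent.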
Modules
base_converter

Classes:

Name Description
BaseConverter

Abstract base class for all LLM response converters.

ConverterType

Enumeration of supported converter types for LLM responses.

Classes
BaseConverter

Bases: ABC

Abstract base class for all LLM response converters.

Subclasses should implement methods to convert standard and streaming LLM responses into PyAgenity's internal message/event formats.

Attributes:

Name Type Description
state AgentState | None

Optional agent state for context during conversion.

Methods:

Name Description
__init__

Initialize the converter.

convert_response

Convert a standard agent response to a Message.

convert_streaming_response

Convert a streaming agent response to an async generator of EventModel or Message.

Source code in pyagenity/adapters/llm/base_converter.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
class BaseConverter(ABC):
    """
    Abstract base class for all LLM response converters.

    Subclasses should implement methods to convert standard and streaming
    LLM responses into PyAgenity's internal message/event formats.

    Attributes:
        state (AgentState | None): Optional agent state for context during conversion.
    """

    def __init__(self, state: AgentState | None = None) -> None:
        """
        Initialize the converter.

        Args:
            state (AgentState | None): Optional agent state for context during conversion.
        """
        self.state = state

    @abstractmethod
    async def convert_response(self, response: Any) -> Message:
        """
        Convert a standard agent response to a Message.

        Args:
            response (Any): The raw response from the LLM or agent.

        Returns:
            Message: The converted message object.

        Raises:
            NotImplementedError: If not implemented in subclass.
        """
        raise NotImplementedError("Conversion not implemented for this converter")

    @abstractmethod
    async def convert_streaming_response(
        self,
        config: dict,
        node_name: str,
        response: Any,
        meta: dict | None = None,
    ) -> AsyncGenerator[EventModel | Message, None]:
        """
        Convert a streaming agent response to an async generator of EventModel or Message.

        Args:
            config (dict): Node configuration parameters.
            node_name (str): Name of the node processing the response.
            response (Any): The raw streaming response from the LLM or agent.
            meta (dict | None): Optional metadata for conversion.

        Yields:
            EventModel | Message: Chunks of the converted streaming response.

        Raises:
            NotImplementedError: If not implemented in subclass.
        """
        raise NotImplementedError("Streaming not implemented for this converter")
Attributes
state instance-attribute
state = state
Functions
__init__
__init__(state=None)

Initialize the converter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `AgentState \| None` | Optional agent state for context during conversion. | `None` |
Source code in pyagenity/adapters/llm/base_converter.py
def __init__(self, state: AgentState | None = None) -> None:
    """
    Initialize the converter.

    Args:
        state (AgentState | None): Optional agent state for context during conversion.
    """
    self.state = state
convert_response abstractmethod async
convert_response(response)

Convert a standard agent response to a Message.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `response` | `Any` | The raw response from the LLM or agent. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Message` | `Message` | The converted message object. |

Raises:

| Type | Description |
| --- | --- |
| `NotImplementedError` | If not implemented in subclass. |

Source code in pyagenity/adapters/llm/base_converter.py
@abstractmethod
async def convert_response(self, response: Any) -> Message:
    """
    Convert a standard agent response to a Message.

    Args:
        response (Any): The raw response from the LLM or agent.

    Returns:
        Message: The converted message object.

    Raises:
        NotImplementedError: If not implemented in subclass.
    """
    raise NotImplementedError("Conversion not implemented for this converter")
convert_streaming_response abstractmethod async
convert_streaming_response(config, node_name, response, meta=None)

Convert a streaming agent response to an async generator of EventModel or Message.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Node configuration parameters. | *required* |
| `node_name` | `str` | Name of the node processing the response. | *required* |
| `response` | `Any` | The raw streaming response from the LLM or agent. | *required* |
| `meta` | `dict \| None` | Optional metadata for conversion. | `None` |

Yields:

| Type | Description |
| --- | --- |
| `AsyncGenerator[EventModel \| Message, None]` | Chunks of the converted streaming response. |

Raises:

| Type | Description |
| --- | --- |
| `NotImplementedError` | If not implemented in subclass. |

Source code in pyagenity/adapters/llm/base_converter.py
@abstractmethod
async def convert_streaming_response(
    self,
    config: dict,
    node_name: str,
    response: Any,
    meta: dict | None = None,
) -> AsyncGenerator[EventModel | Message, None]:
    """
    Convert a streaming agent response to an async generator of EventModel or Message.

    Args:
        config (dict): Node configuration parameters.
        node_name (str): Name of the node processing the response.
        response (Any): The raw streaming response from the LLM or agent.
        meta (dict | None): Optional metadata for conversion.

    Yields:
        EventModel | Message: Chunks of the converted streaming response.

    Raises:
        NotImplementedError: If not implemented in subclass.
    """
    raise NotImplementedError("Streaming not implemented for this converter")
ConverterType

Bases: Enum

Enumeration of supported converter types for LLM responses.

Attributes:

| Name | Value |
| --- | --- |
| `ANTHROPIC` | `'anthropic'` |
| `CUSTOM` | `'custom'` |
| `GOOGLE` | `'google'` |
| `LITELLM` | `'litellm'` |
| `OPENAI` | `'openai'` |
Source code in pyagenity/adapters/llm/base_converter.py
class ConverterType(Enum):
    """Enumeration of supported converter types for LLM responses."""

    OPENAI = "openai"
    LITELLM = "litellm"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    CUSTOM = "custom"
Attributes
ANTHROPIC class-attribute instance-attribute
ANTHROPIC = 'anthropic'
CUSTOM class-attribute instance-attribute
CUSTOM = 'custom'
GOOGLE class-attribute instance-attribute
GOOGLE = 'google'
LITELLM class-attribute instance-attribute
LITELLM = 'litellm'
OPENAI class-attribute instance-attribute
OPENAI = 'openai'
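Because each member carries a string value, a configuration file can select a converter by name and the enum performs the value-to-member lookup. A minimal stand-in (mirroring the source above) shows the round-trip:

```python
from enum import Enum


# Mirror of ConverterType from the source above.
class ConverterType(Enum):
    OPENAI = "openai"
    LITELLM = "litellm"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    CUSTOM = "custom"


selected = ConverterType("litellm")  # value -> member lookup
```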
litellm_converter

Classes:

Name Description
LiteLLMConverter

Converter for LiteLLM responses to PyAgenity Message format.

Attributes:

| Name | Description |
| --- | --- |
| `HAS_LITELLM` | Flag indicating whether the optional `litellm` dependency is installed. |
| `logger` | Module-level logger. |
Attributes
HAS_LITELLM module-attribute
HAS_LITELLM = True
logger module-attribute
logger = getLogger(__name__)
Classes
LiteLLMConverter

Bases: BaseConverter

Converter for LiteLLM responses to PyAgenity Message format.

Handles both standard and streaming responses, extracting content, reasoning, tool calls, and token usage details.

Methods:

Name Description
__init__

Initialize the converter.

convert_response

Convert a LiteLLM ModelResponse to a Message.

convert_streaming_response

Convert a LiteLLM streaming or standard response to Message(s).

Attributes:

| Name | Description |
| --- | --- |
| `state` | Optional agent state for context during conversion. |
Source code in pyagenity/adapters/llm/litellm_converter.py
class LiteLLMConverter(BaseConverter):
    """
    Converter for LiteLLM responses to PyAgenity Message format.

    Handles both standard and streaming responses, extracting content, reasoning,
    tool calls, and token usage details.
    """

    async def convert_response(self, response: ModelResponse) -> Message:
        """
        Convert a LiteLLM ModelResponse to a Message.

        Args:
            response (ModelResponse): The LiteLLM model response object.

        Returns:
            Message: The converted message object.

        Raises:
            ImportError: If LiteLLM is not installed.
        """
        if not HAS_LITELLM:
            raise ImportError("litellm is not installed. Please install it to use this converter.")

        data = response.model_dump()

        usages_data = data.get("usage", {})

        usages = TokenUsages(
            completion_tokens=usages_data.get("completion_tokens", 0),
            prompt_tokens=usages_data.get("prompt_tokens", 0),
            total_tokens=usages_data.get("total_tokens", 0),
            cache_creation_input_tokens=usages_data.get("cache_creation_input_tokens", 0),
            cache_read_input_tokens=usages_data.get("cache_read_input_tokens", 0),
            reasoning_tokens=usages_data.get("prompt_tokens_details", {}).get(
                "reasoning_tokens", 0
            ),
        )

        created_date = data.get("created", datetime.now())

        # Extract tool calls from response
        tools_calls = data.get("choices", [{}])[0].get("message", {}).get("tool_calls", []) or []

        logger.debug("Creating message from model response with id: %s", response.id)
        content = data.get("choices", [{}])[0].get("message", {}).get("content", "") or ""
        reasoning_content = (
            data.get("choices", [{}])[0].get("message", {}).get("reasoning_content", "") or ""
        )

        blocks = []
        if content:
            blocks.append(TextBlock(text=content))
        if reasoning_content:
            blocks.append(ReasoningBlock(summary=reasoning_content))
        final_tool_calls = []
        for tool_call in tools_calls:
            tool_id = tool_call.get("id", None)
            args = tool_call.get("function", {}).get("arguments", None)
            name = tool_call.get("function", {}).get("name", None)

            if not tool_id or not args or not name:
                continue

            blocks.append(
                ToolCallBlock(
                    name=name,
                    args=json.loads(args),
                    id=tool_id,
                )
            )

            if hasattr(tool_call, "model_dump"):
                final_tool_calls.append(tool_call.model_dump())
            else:
                final_tool_calls.append(tool_call)

        return Message(
            message_id=generate_id(response.id),
            role="assistant",
            content=blocks,
            reasoning=reasoning_content,
            timestamp=created_date,
            metadata={
                "provider": "litellm",
                "model": data.get("model", ""),
                "finish_reason": data.get("choices", [{}])[0].get("finish_reason", "UNKNOWN"),
                "object": data.get("object", ""),
                "prompt_tokens_details": usages_data.get("prompt_tokens_details", {}),
                "completion_tokens_details": usages_data.get("completion_tokens_details", {}),
            },
            usages=usages,
            raw=data,
            tools_calls=final_tool_calls if final_tool_calls else None,
        )

    def _process_chunk(
        self,
        chunk: ModelResponseStream | None,
        seq: int,
        accumulated_content: str,
        accumulated_reasoning_content: str,
        tool_calls: list,
        tool_ids: set,
    ) -> tuple[str, str, list, int, Message | None]:
        """
        Process a single chunk from a LiteLLM streaming response.

        Args:
            chunk (ModelResponseStream | None): The current chunk from the stream.
            seq (int): Sequence number of the chunk.
            accumulated_content (str): Accumulated text content so far.
            accumulated_reasoning_content (str): Accumulated reasoning content so far.
            tool_calls (list): List of tool calls detected so far.
            tool_ids (set): Set of tool call IDs to avoid duplicates.

        Returns:
            tuple: Updated accumulated content, reasoning, tool calls, sequence,
                and Message (if any).
        """
        if not chunk:
            return accumulated_content, accumulated_reasoning_content, tool_calls, seq, None

        msg: ModelResponseStream = chunk  # type: ignore
        if msg is None:
            return accumulated_content, accumulated_reasoning_content, tool_calls, seq, None
        if msg.choices is None or len(msg.choices) == 0:
            return accumulated_content, accumulated_reasoning_content, tool_calls, seq, None
        delta = msg.choices[0].delta
        if delta is None:
            return accumulated_content, accumulated_reasoning_content, tool_calls, seq, None

        # update text delta
        text_part = delta.content or ""
        content_blocks = []
        if text_part:
            content_blocks.append(TextBlock(text=text_part))
        reasoning_part = getattr(delta, "reasoning_content", "") or ""
        if reasoning_part:
            content_blocks.append(ReasoningBlock(summary=reasoning_part))
        accumulated_content += text_part
        accumulated_reasoning_content += reasoning_part
        # handle tool calls if present
        if getattr(delta, "tool_calls", None):
            for tc in delta.tool_calls:  # type: ignore[attr-defined]
                if not tc:
                    continue
                if tc.id in tool_ids:
                    continue
                tool_ids.add(tc.id)
                tool_calls.append(tc.model_dump())
                content_blocks.append(
                    ToolCallBlock(
                        name=tc.function.name,  # type: ignore
                        args=json.loads(tc.function.arguments),  # type: ignore
                        id=tc.id,  # type: ignore
                    )
                )

        output_message = Message(
            message_id=generate_id(msg.id),
            role="assistant",
            content=content_blocks,
            reasoning=accumulated_reasoning_content,
            tools_calls=tool_calls,
            delta=True,
        )

        return accumulated_content, accumulated_reasoning_content, tool_calls, seq, output_message

    async def _handle_stream(
        self,
        config: dict,
        node_name: str,
        stream: CustomStreamWrapper,
        meta: dict | None = None,
    ) -> AsyncGenerator[Message]:
        """
        Handle a LiteLLM streaming response and yield Message objects for each chunk.

        Args:
            config (dict): Node configuration parameters.
            node_name (str): Name of the node processing the response.
            stream (CustomStreamWrapper): The LiteLLM streaming response object.
            meta (dict | None): Optional metadata for conversion.

        Yields:
            Message: Converted message chunk from the stream.
        """
        accumulated_content = ""
        tool_calls = []
        tool_ids = set()
        accumulated_reasoning_content = ""
        seq = 0

        is_awaitable = inspect.isawaitable(stream)

        # Await stream if necessary
        if is_awaitable:
            stream = await stream

        # Try async iteration (acompletion)
        try:
            async for chunk in stream:
                accumulated_content, accumulated_reasoning_content, tool_calls, seq, message = (
                    self._process_chunk(
                        chunk,
                        seq,
                        accumulated_content,
                        accumulated_reasoning_content,
                        tool_calls,
                        tool_ids,
                    )
                )

                if message:
                    yield message
        except Exception:  # noqa: S110 # nosec B110
            pass

        # Try sync iteration (completion)
        try:
            for chunk in stream:
                accumulated_content, accumulated_reasoning_content, tool_calls, seq, message = (
                    self._process_chunk(
                        chunk,
                        seq,
                        accumulated_content,
                        accumulated_reasoning_content,
                        tool_calls,
                        tool_ids,
                    )
                )

                if message:
                    yield message
        except Exception:  # noqa: S110 # nosec B110
            pass

        # After streaming, yield final message
        metadata = meta or {}
        metadata["provider"] = "litellm"
        metadata["node_name"] = node_name
        metadata["thread_id"] = config.get("thread_id")

        blocks = []
        if accumulated_content:
            blocks.append(TextBlock(text=accumulated_content))
        if accumulated_reasoning_content:
            blocks.append(ReasoningBlock(summary=accumulated_reasoning_content))
        if tool_calls:
            for tc in tool_calls:
                blocks.append(
                    ToolCallBlock(
                        name=tc.get("function", {}).get("name", ""),
                        args=json.loads(tc.get("function", {}).get("arguments", "{}")),
                        id=tc.get("id", ""),
                    )
                )

        logger.debug(
            "Loop done Content: %s  Reasoning: %s Tool Calls: %s",
            accumulated_content,
            accumulated_reasoning_content,
            len(tool_calls),
        )
        message = Message(
            role="assistant",
            message_id=generate_id(None),
            content=blocks,
            delta=False,
            reasoning=accumulated_reasoning_content,
            tools_calls=tool_calls,
            metadata=metadata,
        )
        yield message

    async def convert_streaming_response(  # type: ignore
        self,
        config: dict,
        node_name: str,
        response: Any,
        meta: dict | None = None,
    ) -> AsyncGenerator[Message]:
        """
        Convert a LiteLLM streaming or standard response to Message(s).

        Args:
            config (dict): Node configuration parameters.
            node_name (str): Name of the node processing the response.
            response (Any): The LiteLLM response object (stream or standard).
            meta (dict | None): Optional metadata for conversion.

        Yields:
            Message: Converted message(s) from the response.

        Raises:
            ImportError: If LiteLLM is not installed.
            Exception: If response type is unsupported.
        """
        if not HAS_LITELLM:
            raise ImportError("litellm is not installed. Please install it to use this converter.")

        if isinstance(response, CustomStreamWrapper):  # type: ignore[possibly-unbound]
            stream = cast(CustomStreamWrapper, response)
            async for event in self._handle_stream(
                config or {},
                node_name or "",
                stream,
                meta,
            ):
                yield event
        elif isinstance(response, ModelResponse):  # type: ignore[possibly-unbound]
            message = await self.convert_response(cast(ModelResponse, response))
            yield message
        else:
            raise Exception("Unsupported response type for LiteLLMConverter")
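`convert_response` operates on `response.model_dump()`, so the same defensive lookups can be exercised against a hand-written dict shaped like a LiteLLM chat completion (all values below are illustrative, not real API output):

```python
import json

# Hand-written dict shaped like ModelResponse.model_dump(); values are made up.
data = {
    "id": "chatcmpl-123",
    "model": "example-model",
    "created": 1700000000,
    "usage": {
        "completion_tokens": 12,
        "prompt_tokens": 30,
        "total_tokens": 42,
        "prompt_tokens_details": {"reasoning_tokens": 0},
    },
    "choices": [{
        "finish_reason": "tool_calls",
        "message": {
            "content": None,
            "reasoning_content": "Need the weather first.",
            "tool_calls": [{
                "id": "call_1",
                "function": {
                    "name": "get_weather",
                    "arguments": json.dumps({"city": "Paris"}),
                },
            }],
        },
    }],
}

# Same defensive lookups convert_response applies to the dumped response.
message = data.get("choices", [{}])[0].get("message", {})
content = message.get("content", "") or ""
reasoning = message.get("reasoning_content", "") or ""
tool_calls = message.get("tool_calls", []) or []

# Tool calls missing an id, name, or arguments are skipped, as in the source.
parsed = [
    {
        "id": tc["id"],
        "name": tc["function"]["name"],
        "args": json.loads(tc["function"]["arguments"]),
    }
    for tc in tool_calls
    if tc.get("id")
    and tc.get("function", {}).get("name")
    and tc.get("function", {}).get("arguments")
]
```

Note the `or ""` guards: providers may return `content: null` on tool-call turns, and the converter normalizes that to an empty string before building content blocks.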
Attributes
state instance-attribute
state = state
Functions
__init__
__init__(state=None)

Initialize the converter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `AgentState \| None` | Optional agent state for context during conversion. | `None` |
Source code in pyagenity/adapters/llm/base_converter.py
def __init__(self, state: AgentState | None = None) -> None:
    """
    Initialize the converter.

    Args:
        state (AgentState | None): Optional agent state for context during conversion.
    """
    self.state = state
convert_response async
convert_response(response)

Convert a LiteLLM ModelResponse to a Message.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `response` | `ModelResponse` | The LiteLLM model response object. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Message` | `Message` | The converted message object. |

Raises:

| Type | Description |
| --- | --- |
| `ImportError` | If LiteLLM is not installed. |

Source code in pyagenity/adapters/llm/litellm_converter.py
async def convert_response(self, response: ModelResponse) -> Message:
    """
    Convert a LiteLLM ModelResponse to a Message.

    Args:
        response (ModelResponse): The LiteLLM model response object.

    Returns:
        Message: The converted message object.

    Raises:
        ImportError: If LiteLLM is not installed.
    """
    if not HAS_LITELLM:
        raise ImportError("litellm is not installed. Please install it to use this converter.")

    data = response.model_dump()

    usages_data = data.get("usage", {})

    usages = TokenUsages(
        completion_tokens=usages_data.get("completion_tokens", 0),
        prompt_tokens=usages_data.get("prompt_tokens", 0),
        total_tokens=usages_data.get("total_tokens", 0),
        cache_creation_input_tokens=usages_data.get("cache_creation_input_tokens", 0),
        cache_read_input_tokens=usages_data.get("cache_read_input_tokens", 0),
        reasoning_tokens=usages_data.get("prompt_tokens_details", {}).get(
            "reasoning_tokens", 0
        ),
    )

    created_date = data.get("created", datetime.now())

    # Extract tool calls from response
    tools_calls = data.get("choices", [{}])[0].get("message", {}).get("tool_calls", []) or []

    logger.debug("Creating message from model response with id: %s", response.id)
    content = data.get("choices", [{}])[0].get("message", {}).get("content", "") or ""
    reasoning_content = (
        data.get("choices", [{}])[0].get("message", {}).get("reasoning_content", "") or ""
    )

    blocks = []
    if content:
        blocks.append(TextBlock(text=content))
    if reasoning_content:
        blocks.append(ReasoningBlock(summary=reasoning_content))
    final_tool_calls = []
    for tool_call in tools_calls:
        tool_id = tool_call.get("id", None)
        args = tool_call.get("function", {}).get("arguments", None)
        name = tool_call.get("function", {}).get("name", None)

        if not tool_id or not args or not name:
            continue

        blocks.append(
            ToolCallBlock(
                name=name,
                args=json.loads(args),
                id=tool_id,
            )
        )

        if hasattr(tool_call, "model_dump"):
            final_tool_calls.append(tool_call.model_dump())
        else:
            final_tool_calls.append(tool_call)

    return Message(
        message_id=generate_id(response.id),
        role="assistant",
        content=blocks,
        reasoning=reasoning_content,
        timestamp=created_date,
        metadata={
            "provider": "litellm",
            "model": data.get("model", ""),
            "finish_reason": data.get("choices", [{}])[0].get("finish_reason", "UNKNOWN"),
            "object": data.get("object", ""),
            "prompt_tokens_details": usages_data.get("prompt_tokens_details", {}),
            "completion_tokens_details": usages_data.get("completion_tokens_details", {}),
        },
        usages=usages,
        raw=data,
        tools_calls=final_tool_calls if final_tool_calls else None,
    )
convert_streaming_response async
convert_streaming_response(config, node_name, response, meta=None)

Convert a LiteLLM streaming or standard response to Message(s).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Node configuration parameters. | *required* |
| `node_name` | `str` | Name of the node processing the response. | *required* |
| `response` | `Any` | The LiteLLM response object (stream or standard). | *required* |
| `meta` | `dict \| None` | Optional metadata for conversion. | `None` |

Yields:

| Name | Type | Description |
| --- | --- | --- |
| `Message` | `AsyncGenerator[Message]` | Converted message(s) from the response. |

Raises:

| Type | Description |
| --- | --- |
| `ImportError` | If LiteLLM is not installed. |
| `Exception` | If response type is unsupported. |

Source code in pyagenity/adapters/llm/litellm_converter.py
async def convert_streaming_response(  # type: ignore
    self,
    config: dict,
    node_name: str,
    response: Any,
    meta: dict | None = None,
) -> AsyncGenerator[Message]:
    """
    Convert a LiteLLM streaming or standard response to Message(s).

    Args:
        config (dict): Node configuration parameters.
        node_name (str): Name of the node processing the response.
        response (Any): The LiteLLM response object (stream or standard).
        meta (dict | None): Optional metadata for conversion.

    Yields:
        Message: Converted message(s) from the response.

    Raises:
        ImportError: If LiteLLM is not installed.
        Exception: If response type is unsupported.
    """
    if not HAS_LITELLM:
        raise ImportError("litellm is not installed. Please install it to use this converter.")

    if isinstance(response, CustomStreamWrapper):  # type: ignore[possibly-unbound]
        stream = cast(CustomStreamWrapper, response)
        async for event in self._handle_stream(
            config or {},
            node_name or "",
            stream,
            meta,
        ):
            yield event
    elif isinstance(response, ModelResponse):  # type: ignore[possibly-unbound]
        message = await self.convert_response(cast(ModelResponse, response))
        yield message
    else:
        raise Exception("Unsupported response type for LiteLLMConverter")
Functions
model_response_converter

Classes:

Name Description
ModelResponseConverter

Wrap an LLM SDK call and normalize its output via a converter.

Classes
ModelResponseConverter

Wrap an LLM SDK call and normalize its output via a converter.

Supports sync/async invocation and streaming. Use invoke() for non-streaming calls and stream() for streaming calls.

Methods:

Name Description
__init__

Initialize ModelResponseConverter.

invoke

Call the underlying function and convert a non-streaming response to Message.

stream

Call the underlying function and yield normalized streaming events and final Message.

Attributes:

| Name | Description |
| --- | --- |
| `converter` | The resolved converter instance. |
| `response` | The LLM response, or a callable returning one. |
Source code in pyagenity/adapters/llm/model_response_converter.py
class ModelResponseConverter:
    """Wrap an LLM SDK call and normalize its output via a converter.

    Supports sync/async invocation and streaming. Use `invoke()` for
    non-streaming calls and `stream()` for streaming calls.
    """

    def __init__(
        self,
        response: Any | Callable[..., Any],
        converter: BaseConverter | str,
    ) -> None:
        """
        Initialize ModelResponseConverter.

        Args:
            response (Any | Callable[..., Any]): The LLM response or a callable returning
                a response.
            converter (BaseConverter | str): Converter instance or string identifier
                (e.g., "litellm").

        Raises:
            ValueError: If the converter is not supported.
        """
        self.response = response

        if isinstance(converter, str) and converter == "litellm":
            from .litellm_converter import LiteLLMConverter

            self.converter = LiteLLMConverter()
        elif isinstance(converter, BaseConverter):
            self.converter = converter
        else:
            raise ValueError(f"Unsupported converter: {converter}")

    async def invoke(self) -> Message:
        """
        Call the underlying function and convert a non-streaming response to Message.

        Returns:
            Message: The normalized message from the LLM response.

        Raises:
            Exception: If the underlying function or converter fails.
        """
        if callable(self.response):
            if inspect.iscoroutinefunction(self.response):
                response = await self.response()
            else:
                response = self.response()
        else:
            response = self.response

        return await self.converter.convert_response(response)  # type: ignore

    async def stream(
        self,
        config: dict,
        node_name: str,
        meta: dict | None = None,
    ) -> AsyncGenerator[Message]:
        """
        Call the underlying function and yield normalized streaming events and final Message.

        Args:
            config (dict): Node configuration parameters for streaming.
            node_name (str): Name of the node processing the response.
            meta (dict | None): Optional metadata for conversion.

        Yields:
            Message: Normalized streaming message events from the LLM response.

        Raises:
            ValueError: If config is not provided.
            Exception: If the underlying function or converter fails.
        """
        if not config:
            raise ValueError("Config must be provided for streaming conversion")

        if callable(self.response):
            if inspect.iscoroutinefunction(self.response):
                response = await self.response()
            else:
                response = self.response()
        else:
            response = self.response

        async for item in self.converter.convert_streaming_response(  # type: ignore
            config,
            node_name=node_name,
            response=response,
            meta=meta,
        ):
            yield item
Attributes
converter instance-attribute
converter = LiteLLMConverter()
response instance-attribute
response = response
Functions
__init__
__init__(response, converter)

Initialize ModelResponseConverter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `response` | `Any \| Callable[..., Any]` | The LLM response or a callable returning a response. | *required* |
| `converter` | `BaseConverter \| str` | Converter instance or string identifier (e.g., "litellm"). | *required* |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the converter is not supported. |

Source code in pyagenity/adapters/llm/model_response_converter.py
def __init__(
    self,
    response: Any | Callable[..., Any],
    converter: BaseConverter | str,
) -> None:
    """
    Initialize ModelResponseConverter.

    Args:
        response (Any | Callable[..., Any]): The LLM response or a callable returning
            a response.
        converter (BaseConverter | str): Converter instance or string identifier
            (e.g., "litellm").

    Raises:
        ValueError: If the converter is not supported.
    """
    self.response = response

    if isinstance(converter, str) and converter == "litellm":
        from .litellm_converter import LiteLLMConverter

        self.converter = LiteLLMConverter()
    elif isinstance(converter, BaseConverter):
        self.converter = converter
    else:
        raise ValueError(f"Unsupported converter: {converter}")
invoke async
invoke()

Call the underlying function and convert a non-streaming response to Message.

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Message` | `Message` | The normalized message from the LLM response. |

Raises:

| Type | Description |
| --- | --- |
| `Exception` | If the underlying function or converter fails. |

Source code in pyagenity/adapters/llm/model_response_converter.py
async def invoke(self) -> Message:
    """
    Call the underlying function and convert a non-streaming response to Message.

    Returns:
        Message: The normalized message from the LLM response.

    Raises:
        Exception: If the underlying function or converter fails.
    """
    if callable(self.response):
        if inspect.iscoroutinefunction(self.response):
            response = await self.response()
        else:
            response = self.response()
    else:
        response = self.response

    return await self.converter.convert_response(response)  # type: ignore
stream async
stream(config, node_name, meta=None)

Call the underlying function and yield normalized streaming events and final Message.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Node configuration parameters for streaming. | *required* |
| `node_name` | `str` | Name of the node processing the response. | *required* |
| `meta` | `dict \| None` | Optional metadata for conversion. | `None` |

Yields:

| Name | Type | Description |
| --- | --- | --- |
| `Message` | `AsyncGenerator[Message]` | Normalized streaming message events from the LLM response. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If config is not provided. |
| `Exception` | If the underlying function or converter fails. |

Source code in pyagenity/adapters/llm/model_response_converter.py
async def stream(
    self,
    config: dict,
    node_name: str,
    meta: dict | None = None,
) -> AsyncGenerator[Message]:
    """
    Call the underlying function and yield normalized streaming events and final Message.

    Args:
        config (dict): Node configuration parameters for streaming.
        node_name (str): Name of the node processing the response.
        meta (dict | None): Optional metadata for conversion.

    Yields:
        Message: Normalized streaming message events from the LLM response.

    Raises:
        ValueError: If config is not provided.
        Exception: If the underlying function or converter fails.
    """
    if not config:
        raise ValueError("Config must be provided for streaming conversion")

    if callable(self.response):
        if inspect.iscoroutinefunction(self.response):
            response = await self.response()
        else:
            response = self.response()
    else:
        response = self.response

    async for item in self.converter.convert_streaming_response(  # type: ignore
        config,
        node_name=node_name,
        response=response,
        meta=meta,
    ):
        yield item
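Consuming `stream()` follows the standard async-generator pattern, the same `async for ... yield` relay the method itself uses. A self-contained sketch (the chunk source here is a hypothetical stand-in for a converter's streaming output):

```python
import asyncio
from typing import AsyncGenerator


# Hypothetical stand-in for convert_streaming_response(): yields text chunks
# the way a converter would yield incremental Message events.
async def fake_stream() -> AsyncGenerator[str, None]:
    for chunk in ("Hel", "lo", "!"):
        yield chunk


async def collect() -> str:
    parts: list[str] = []
    # Mirrors the `async for item in ...: yield item` relay in stream().
    async for item in fake_stream():
        parts.append(item)
    return "".join(parts)


text = asyncio.run(collect())
print(text)  # Hello!
```

In real usage the yielded items are `Message` objects rather than strings, and `stream()` raises `ValueError` up front if `config` is empty.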
tools

Integration adapters for optional third-party SDKs.

This module exposes unified wrappers for integrating external tool registries and SDKs with PyAgenity agent graphs. The adapters provide registry-based discovery, function-calling schemas, and normalized execution for supported tool providers.

Exports

ComposioAdapter: Adapter for the Composio Python SDK.
LangChainAdapter: Adapter for LangChain tool registry and execution.

Modules:

Name Description
composio_adapter

Composio adapter for PyAgenity.

langchain_adapter

LangChain adapter for PyAgenity (generic wrapper, registry-based).

Classes:

Name Description
ComposioAdapter

Adapter around Composio Python SDK.

LangChainAdapter

Generic registry-based LangChain adapter.

Attributes
__all__ module-attribute
__all__ = ['ComposioAdapter', 'LangChainAdapter']
Classes
ComposioAdapter

Adapter around Composio Python SDK.

Notes on SDK methods used (from docs):

- composio.tools.get(user_id=..., tools=[...]/toolkits=[...]/search=..., scopes=..., limit=...): returns tools formatted for providers or agent frameworks; includes schema.
- composio.tools.get_raw_composio_tools(...): returns raw tool schemas including input_parameters.
- composio.tools.execute(slug, arguments, user_id=..., connected_account_id=..., ...): executes a tool and returns a dict like {data, successful, error}.

Methods:

Name Description
__init__

Initialize the ComposioAdapter.

execute

Execute a Composio tool and return a normalized response dict.

is_available

Return True if composio SDK is importable.

list_raw_tools_for_llm

Return raw Composio tool schemas mapped to function-calling format.

list_tools_for_llm

Return tools formatted for LLM function-calling.

Source code in pyagenity/adapters/tools/composio_adapter.py
class ComposioAdapter:
    """Adapter around Composio Python SDK.

    Notes on SDK methods used (from docs):
    - composio.tools.get(user_id=..., tools=[...]/toolkits=[...]/search=..., scopes=..., limit=...)
        Returns tools formatted for providers or agent frameworks; includes schema.
    - composio.tools.get_raw_composio_tools(...)
        Returns raw tool schemas including input_parameters.
    - composio.tools.execute(slug, arguments, user_id=..., connected_account_id=..., ...)
        Executes a tool and returns a dict like {data, successful, error}.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        provider: t.Any | None = None,
        file_download_dir: str | None = None,
        toolkit_versions: t.Any | None = None,
    ) -> None:
        """
        Initialize the ComposioAdapter.

        Args:
            api_key (str | None): Optional API key for Composio.
            provider (Any | None): Optional provider integration.
            file_download_dir (str | None): Directory for auto file handling.
            toolkit_versions (Any | None): Toolkit version overrides.

        Raises:
            ImportError: If composio SDK is not installed.
        """
        if not HAS_COMPOSIO:
            raise ImportError(
                "ComposioAdapter requires 'composio' package. Install with: "
                "pip install pyagenity[composio]"
            )

        self._composio = Composio(  # type: ignore[call-arg]
            api_key=api_key,
            provider=provider,
            file_download_dir=file_download_dir,
            toolkit_versions=toolkit_versions,
        )

    @staticmethod
    def is_available() -> bool:
        """
        Return True if composio SDK is importable.

        Returns:
            bool: True if composio SDK is available, False otherwise.
        """
        return HAS_COMPOSIO

    def list_tools_for_llm(
        self,
        *,
        user_id: str,
        tool_slugs: list[str] | None = None,
        toolkits: list[str] | None = None,
        search: str | None = None,
        scopes: list[str] | None = None,
        limit: int | None = None,
    ) -> list[dict[str, t.Any]]:
        """
        Return tools formatted for LLM function-calling.

        Args:
            user_id (str): User ID for tool discovery.
            tool_slugs (list[str] | None): Optional list of tool slugs.
            toolkits (list[str] | None): Optional list of toolkits.
            search (str | None): Optional search string.
            scopes (list[str] | None): Optional scopes.
            limit (int | None): Optional limit on number of tools.

        Returns:
            list[dict[str, Any]]: List of tools in function-calling format.
        """
        # Prefer the provider-wrapped format when available
        tools = self._composio.tools.get(
            user_id=user_id,
            tools=tool_slugs,  # type: ignore[arg-type]
            toolkits=toolkits,  # type: ignore[arg-type]
            search=search,
            scopes=scopes,
            limit=limit,
        )

        # The provider-wrapped output may already be in the desired structure.
        # We'll detect and pass-through; otherwise convert using raw schemas.
        formatted: list[dict[str, t.Any]] = []
        for t_obj in tools if isinstance(tools, list) else []:
            try:
                if (
                    isinstance(t_obj, dict)
                    and t_obj.get("type") == "function"
                    and "function" in t_obj
                ):
                    formatted.append(t_obj)
                else:
                    # Fallback: try to pull minimal fields
                    fn = t_obj.get("function", {}) if isinstance(t_obj, dict) else {}
                    if fn.get("name") and fn.get("parameters"):
                        formatted.append({"type": "function", "function": fn})
            except Exception as exc:
                logger.debug("Skipping non-conforming Composio tool wrapper: %s", exc)
                continue

        if formatted:
            return formatted

        # Fallback to raw schemas and convert manually
        formatted.extend(
            self.list_raw_tools_for_llm(
                tool_slugs=tool_slugs, toolkits=toolkits, search=search, scopes=scopes, limit=limit
            )
        )

        return formatted

    def list_raw_tools_for_llm(
        self,
        *,
        tool_slugs: list[str] | None = None,
        toolkits: list[str] | None = None,
        search: str | None = None,
        scopes: list[str] | None = None,
        limit: int | None = None,
    ) -> list[dict[str, t.Any]]:
        """
        Return raw Composio tool schemas mapped to function-calling format.

        Args:
            tool_slugs (list[str] | None): Optional list of tool slugs.
            toolkits (list[str] | None): Optional list of toolkits.
            search (str | None): Optional search string.
            scopes (list[str] | None): Optional scopes.
            limit (int | None): Optional limit on number of tools.

        Returns:
            list[dict[str, Any]]: List of raw tool schemas in function-calling format.
        """
        formatted: list[dict[str, t.Any]] = []
        raw_tools = self._composio.tools.get_raw_composio_tools(
            tools=tool_slugs, search=search, toolkits=toolkits, scopes=scopes, limit=limit
        )

        for tool in raw_tools:
            try:
                name = tool.slug  # type: ignore[attr-defined]
                description = getattr(tool, "description", "") or "Composio tool"
                params = getattr(tool, "input_parameters", None)
                if not params:
                    # Minimal shape if schema missing
                    params = {"type": "object", "properties": {}}
                formatted.append(
                    {
                        "type": "function",
                        "function": {
                            "name": name,
                            "description": description,
                            "parameters": params,
                        },
                    }
                )
            except Exception as e:
                logger.warning("Failed to map Composio tool schema: %s", e)
                continue
        return formatted

    def execute(
        self,
        *,
        slug: str,
        arguments: dict[str, t.Any],
        user_id: str | None = None,
        connected_account_id: str | None = None,
        custom_auth_params: dict[str, t.Any] | None = None,
        custom_connection_data: dict[str, t.Any] | None = None,
        text: str | None = None,
        version: str | None = None,
        toolkit_versions: t.Any | None = None,
        modifiers: t.Any | None = None,
    ) -> dict[str, t.Any]:
        """
        Execute a Composio tool and return a normalized response dict.

        Args:
            slug (str): Tool slug to execute.
            arguments (dict[str, Any]): Arguments for the tool.
            user_id (str | None): Optional user ID.
            connected_account_id (str | None): Optional connected account ID.
            custom_auth_params (dict[str, Any] | None): Optional custom auth params.
            custom_connection_data (dict[str, Any] | None): Optional custom connection data.
            text (str | None): Optional text input.
            version (str | None): Optional version.
            toolkit_versions (Any | None): Optional toolkit versions.
            modifiers (Any | None): Optional modifiers.

        Returns:
            dict[str, Any]: Normalized response dict with keys: successful, data, error.
        """
        resp = self._composio.tools.execute(
            slug=slug,
            arguments=arguments,
            user_id=user_id,
            connected_account_id=connected_account_id,
            custom_auth_params=custom_auth_params,
            custom_connection_data=custom_connection_data,
            text=text,
            version=version,
            toolkit_versions=toolkit_versions,
            modifiers=modifiers,
        )

        # The SDK returns a TypedDict-like object; ensure plain dict
        if hasattr(resp, "copy") and not isinstance(resp, dict):  # e.g., TypedDict proxy
            try:
                resp = dict(resp)  # type: ignore[assignment]
            except Exception as exc:
                logger.debug("Could not coerce Composio response to dict: %s", exc)

        # Normalize key presence
        successful = bool(resp.get("successful", False))  # type: ignore[arg-type]
        data = resp.get("data")
        error = resp.get("error")
        return {"successful": successful, "data": data, "error": error}
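The schema mapping in `list_raw_tools_for_llm` can be shown in isolation. A minimal sketch, where `RawTool` is a hypothetical stand-in for the objects returned by `composio.tools.get_raw_composio_tools` (same attribute names, nothing else assumed):

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class RawTool:
    """Hypothetical raw Composio tool object: slug, description, input_parameters."""

    slug: str
    description: str = ""
    input_parameters: dict[str, Any] | None = None


def to_function_schema(tool: RawTool) -> dict[str, Any]:
    """Mirror the adapter's mapping to OpenAI-style function-calling schemas."""
    # Fall back to a minimal empty-object schema when input_parameters is missing.
    params = tool.input_parameters or {"type": "object", "properties": {}}
    return {
        "type": "function",
        "function": {
            "name": tool.slug,
            "description": tool.description or "Composio tool",
            "parameters": params,
        },
    }


schema = to_function_schema(RawTool(slug="GMAIL_SEND_EMAIL"))
print(schema["function"]["name"])  # GMAIL_SEND_EMAIL
```

The resulting dicts can be passed directly as the `tools` argument of an OpenAI-compatible chat-completion call.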
Functions
__init__
__init__(*, api_key=None, provider=None, file_download_dir=None, toolkit_versions=None)

Initialize the ComposioAdapter.

Parameters:

Name Type Description Default
api_key str | None

Optional API key for Composio.

None
provider Any | None

Optional provider integration.

None
file_download_dir str | None

Directory for auto file handling.

None
toolkit_versions Any | None

Toolkit version overrides.

None

Raises:

Type Description
ImportError

If composio SDK is not installed.

Source code in pyagenity/adapters/tools/composio_adapter.py
def __init__(
    self,
    *,
    api_key: str | None = None,
    provider: t.Any | None = None,
    file_download_dir: str | None = None,
    toolkit_versions: t.Any | None = None,
) -> None:
    """
    Initialize the ComposioAdapter.

    Args:
        api_key (str | None): Optional API key for Composio.
        provider (Any | None): Optional provider integration.
        file_download_dir (str | None): Directory for auto file handling.
        toolkit_versions (Any | None): Toolkit version overrides.

    Raises:
        ImportError: If composio SDK is not installed.
    """
    if not HAS_COMPOSIO:
        raise ImportError(
            "ComposioAdapter requires 'composio' package. Install with: "
            "pip install pyagenity[composio]"
        )

    self._composio = Composio(  # type: ignore[call-arg]
        api_key=api_key,
        provider=provider,
        file_download_dir=file_download_dir,
        toolkit_versions=toolkit_versions,
    )
execute
execute(*, slug, arguments, user_id=None, connected_account_id=None, custom_auth_params=None, custom_connection_data=None, text=None, version=None, toolkit_versions=None, modifiers=None)

Execute a Composio tool and return a normalized response dict.

Parameters:

Name Type Description Default
slug str

Tool slug to execute.

required
arguments dict[str, Any]

Arguments for the tool.

required
user_id str | None

Optional user ID.

None
connected_account_id str | None

Optional connected account ID.

None
custom_auth_params dict[str, Any] | None

Optional custom auth params.

None
custom_connection_data dict[str, Any] | None

Optional custom connection data.

None
text str | None

Optional text input.

None
version str | None

Optional version.

None
toolkit_versions Any | None

Optional toolkit versions.

None
modifiers Any | None

Optional modifiers.

None

Returns:

Type Description
dict[str, Any]

dict[str, Any]: Normalized response dict with keys: successful, data, error.

Source code in pyagenity/adapters/tools/composio_adapter.py
def execute(
    self,
    *,
    slug: str,
    arguments: dict[str, t.Any],
    user_id: str | None = None,
    connected_account_id: str | None = None,
    custom_auth_params: dict[str, t.Any] | None = None,
    custom_connection_data: dict[str, t.Any] | None = None,
    text: str | None = None,
    version: str | None = None,
    toolkit_versions: t.Any | None = None,
    modifiers: t.Any | None = None,
) -> dict[str, t.Any]:
    """
    Execute a Composio tool and return a normalized response dict.

    Args:
        slug (str): Tool slug to execute.
        arguments (dict[str, Any]): Arguments for the tool.
        user_id (str | None): Optional user ID.
        connected_account_id (str | None): Optional connected account ID.
        custom_auth_params (dict[str, Any] | None): Optional custom auth params.
        custom_connection_data (dict[str, Any] | None): Optional custom connection data.
        text (str | None): Optional text input.
        version (str | None): Optional version.
        toolkit_versions (Any | None): Optional toolkit versions.
        modifiers (Any | None): Optional modifiers.

    Returns:
        dict[str, Any]: Normalized response dict with keys: successful, data, error.
    """
    resp = self._composio.tools.execute(
        slug=slug,
        arguments=arguments,
        user_id=user_id,
        connected_account_id=connected_account_id,
        custom_auth_params=custom_auth_params,
        custom_connection_data=custom_connection_data,
        text=text,
        version=version,
        toolkit_versions=toolkit_versions,
        modifiers=modifiers,
    )

    # The SDK returns a TypedDict-like object; ensure plain dict
    if hasattr(resp, "copy") and not isinstance(resp, dict):  # e.g., TypedDict proxy
        try:
            resp = dict(resp)  # type: ignore[assignment]
        except Exception as exc:
            logger.debug("Could not coerce Composio response to dict: %s", exc)

    # Normalize key presence
    successful = bool(resp.get("successful", False))  # type: ignore[arg-type]
    data = resp.get("data")
    error = resp.get("error")
    return {"successful": successful, "data": data, "error": error}
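The normalization step at the end of `execute` guarantees the three keys `successful`, `data`, and `error` regardless of what the SDK returned. A standalone sketch of that contract (no Composio dependency; the inputs are illustrative):

```python
from typing import Any, Mapping


def normalize(resp: Mapping[str, Any]) -> dict[str, Any]:
    """Mirror execute()'s normalization: always emit successful/data/error."""
    plain = dict(resp)  # coerce TypedDict-like mappings to a plain dict
    return {
        "successful": bool(plain.get("successful", False)),
        "data": plain.get("data"),
        "error": plain.get("error"),
    }


ok = normalize({"successful": True, "data": {"id": 1}})
failed = normalize({"error": "rate limited"})
print(ok)  # {'successful': True, 'data': {'id': 1}, 'error': None}
print(failed)  # {'successful': False, 'data': None, 'error': 'rate limited'}
```

Callers can therefore branch on `result["successful"]` without defensive key checks.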
is_available staticmethod
is_available()

Return True if composio SDK is importable.

Returns:

Name Type Description
bool bool

True if composio SDK is available, False otherwise.

Source code in pyagenity/adapters/tools/composio_adapter.py
@staticmethod
def is_available() -> bool:
    """
    Return True if composio SDK is importable.

    Returns:
        bool: True if composio SDK is available, False otherwise.
    """
    return HAS_COMPOSIO
list_raw_tools_for_llm
list_raw_tools_for_llm(*, tool_slugs=None, toolkits=None, search=None, scopes=None, limit=None)

Return raw Composio tool schemas mapped to function-calling format.

Parameters:

Name Type Description Default
tool_slugs list[str] | None

Optional list of tool slugs.

None
toolkits list[str] | None

Optional list of toolkits.

None
search str | None

Optional search string.

None
scopes list[str] | None

Optional scopes.

None
limit int | None

Optional limit on number of tools.

None

Returns:

Type Description
list[dict[str, Any]]

list[dict[str, Any]]: List of raw tool schemas in function-calling format.

Source code in pyagenity/adapters/tools/composio_adapter.py
def list_raw_tools_for_llm(
    self,
    *,
    tool_slugs: list[str] | None = None,
    toolkits: list[str] | None = None,
    search: str | None = None,
    scopes: list[str] | None = None,
    limit: int | None = None,
) -> list[dict[str, t.Any]]:
    """
    Return raw Composio tool schemas mapped to function-calling format.

    Args:
        tool_slugs (list[str] | None): Optional list of tool slugs.
        toolkits (list[str] | None): Optional list of toolkits.
        search (str | None): Optional search string.
        scopes (list[str] | None): Optional scopes.
        limit (int | None): Optional limit on number of tools.

    Returns:
        list[dict[str, Any]]: List of raw tool schemas in function-calling format.
    """
    formatted: list[dict[str, t.Any]] = []
    raw_tools = self._composio.tools.get_raw_composio_tools(
        tools=tool_slugs, search=search, toolkits=toolkits, scopes=scopes, limit=limit
    )

    for tool in raw_tools:
        try:
            name = tool.slug  # type: ignore[attr-defined]
            description = getattr(tool, "description", "") or "Composio tool"
            params = getattr(tool, "input_parameters", None)
            if not params:
                # Minimal shape if schema missing
                params = {"type": "object", "properties": {}}
            formatted.append(
                {
                    "type": "function",
                    "function": {
                        "name": name,
                        "description": description,
                        "parameters": params,
                    },
                }
            )
        except Exception as e:
            logger.warning("Failed to map Composio tool schema: %s", e)
            continue
    return formatted
list_tools_for_llm
list_tools_for_llm(*, user_id, tool_slugs=None, toolkits=None, search=None, scopes=None, limit=None)

Return tools formatted for LLM function-calling.

Parameters:

Name Type Description Default
user_id str

User ID for tool discovery.

required
tool_slugs list[str] | None

Optional list of tool slugs.

None
toolkits list[str] | None

Optional list of toolkits.

None
search str | None

Optional search string.

None
scopes list[str] | None

Optional scopes.

None
limit int | None

Optional limit on number of tools.

None

Returns:

Type Description
list[dict[str, Any]]

list[dict[str, Any]]: List of tools in function-calling format.

Source code in pyagenity/adapters/tools/composio_adapter.py
def list_tools_for_llm(
    self,
    *,
    user_id: str,
    tool_slugs: list[str] | None = None,
    toolkits: list[str] | None = None,
    search: str | None = None,
    scopes: list[str] | None = None,
    limit: int | None = None,
) -> list[dict[str, t.Any]]:
    """
    Return tools formatted for LLM function-calling.

    Args:
        user_id (str): User ID for tool discovery.
        tool_slugs (list[str] | None): Optional list of tool slugs.
        toolkits (list[str] | None): Optional list of toolkits.
        search (str | None): Optional search string.
        scopes (list[str] | None): Optional scopes.
        limit (int | None): Optional limit on number of tools.

    Returns:
        list[dict[str, Any]]: List of tools in function-calling format.
    """
    # Prefer the provider-wrapped format when available
    tools = self._composio.tools.get(
        user_id=user_id,
        tools=tool_slugs,  # type: ignore[arg-type]
        toolkits=toolkits,  # type: ignore[arg-type]
        search=search,
        scopes=scopes,
        limit=limit,
    )

    # The provider-wrapped output may already be in the desired structure.
    # We'll detect and pass-through; otherwise convert using raw schemas.
    formatted: list[dict[str, t.Any]] = []
    for t_obj in tools if isinstance(tools, list) else []:
        try:
            if (
                isinstance(t_obj, dict)
                and t_obj.get("type") == "function"
                and "function" in t_obj
            ):
                formatted.append(t_obj)
            else:
                # Fallback: try to pull minimal fields
                fn = t_obj.get("function", {}) if isinstance(t_obj, dict) else {}
                if fn.get("name") and fn.get("parameters"):
                    formatted.append({"type": "function", "function": fn})
        except Exception as exc:
            logger.debug("Skipping non-conforming Composio tool wrapper: %s", exc)
            continue

    if formatted:
        return formatted

    # Fallback to raw schemas and convert manually
    formatted.extend(
        self.list_raw_tools_for_llm(
            tool_slugs=tool_slugs, toolkits=toolkits, search=search, scopes=scopes, limit=limit
        )
    )

    return formatted
LangChainAdapter

Generic registry-based LangChain adapter.

Notes
  • Avoids importing heavy integrations until needed (lazy default autoload).
  • Normalizes schemas and execution results into simple dicts.
  • Allows arbitrary tool registration instead of hardcoding a tiny set.

Methods:

Name Description
__init__

Initialize LangChainAdapter.

execute

Execute a supported LangChain tool and normalize the response.

is_available

Return True if langchain-core is importable.

list_tools_for_llm

Return a list of function-calling formatted tool schemas.

register_tool

Register a tool instance and return the resolved name used for exposure.

register_tools

Register multiple tool instances.

Source code in pyagenity/adapters/tools/langchain_adapter.py
class LangChainAdapter:
    """
    Generic registry-based LangChain adapter.

    Notes:
        - Avoids importing heavy integrations until needed (lazy default autoload).
        - Normalizes schemas and execution results into simple dicts.
        - Allows arbitrary tool registration instead of hardcoding a tiny set.
    """

    def __init__(self, *, autoload_default_tools: bool = True) -> None:
        """
        Initialize LangChainAdapter.

        Args:
            autoload_default_tools (bool): Whether to autoload default tools if registry is empty.

        Raises:
            ImportError: If langchain-core is not installed.
        """
        if not HAS_LANGCHAIN:
            raise ImportError(
                "LangChainAdapter requires 'langchain-core' and optional integrations.\n"
                "Install with: pip install pyagenity[langchain]"
            )
        self._registry: dict[str, LangChainToolWrapper] = {}
        self._autoload = autoload_default_tools

    @staticmethod
    def is_available() -> bool:
        """
        Return True if langchain-core is importable.

        Returns:
            bool: True if langchain-core is available, False otherwise.
        """
        return HAS_LANGCHAIN

    # ------------------------
    # Discovery
    # ------------------------
    def list_tools_for_llm(self) -> list[dict[str, t.Any]]:
        """
        Return a list of function-calling formatted tool schemas.

        If registry is empty and autoload is enabled, attempt to autoload a
        couple of common tools for convenience (tavily_search, requests_get).

        Returns:
            list[dict[str, Any]]: List of tool schemas in function-calling format.
        """
        if not self._registry and self._autoload:
            self._try_autoload_defaults()

        return [wrapper.to_schema() for wrapper in self._registry.values()]

    # ------------------------
    # Execute
    # ------------------------
    def execute(self, *, name: str, arguments: dict[str, t.Any]) -> dict[str, t.Any]:
        """
        Execute a supported LangChain tool and normalize the response.

        Args:
            name (str): Name of the tool to execute.
            arguments (dict[str, Any]): Arguments for the tool.

        Returns:
            dict[str, Any]: Normalized response dict with keys: successful, data, error.
        """
        if name not in self._registry and self._autoload:
            # Late autoload attempt in case discovery wasn't called first
            self._try_autoload_defaults()

        wrapper = self._registry.get(name)
        if not wrapper:
            return {"successful": False, "data": None, "error": f"Unknown LangChain tool: {name}"}
        return wrapper.execute(arguments)

    # ------------------------
    # Internals
    # ------------------------
    def register_tool(
        self,
        tool: t.Any,
        *,
        name: str | None = None,
        description: str | None = None,
    ) -> str:
        """
        Register a tool instance and return the resolved name used for exposure.

        Args:
            tool (Any): Tool instance to register.
            name (str | None): Optional override for tool name.
            description (str | None): Optional override for tool description.

        Returns:
            str: The resolved name used for exposure.
        """
        wrapper = LangChainToolWrapper(tool, name=name, description=description)
        self._registry[wrapper.name] = wrapper
        return wrapper.name

    def register_tools(self, tools: list[t.Any]) -> list[str]:
        """
        Register multiple tool instances.

        Args:
            tools (list[Any]): List of tool instances to register.

        Returns:
            list[str]: List of resolved names for the registered tools.
        """
        names: list[str] = []
        for tool in tools:
            names.append(self.register_tool(tool))
        return names

    def _create_tavily_search_tool(self) -> t.Any:
        """
        Construct Tavily search tool lazily.

        Prefer the new dedicated integration `langchain_tavily.TavilySearch`.
        Fall back to the deprecated community tool if needed.

        Returns:
            Any: Tavily search tool instance.

        Raises:
            ImportError: If Tavily tool cannot be imported.
        """
        # Preferred: langchain-tavily
        try:
            mod = importlib.import_module("langchain_tavily")
            return mod.TavilySearch()  # type: ignore[attr-defined]
        except Exception as exc:
            logger.debug("Preferred langchain_tavily import failed: %s", exc)

        # Fallback: deprecated community tool (still functional for now)
        try:
            mod = importlib.import_module("langchain_community.tools.tavily_search")
            return mod.TavilySearchResults()
        except Exception as exc:  # ImportError or runtime
            raise ImportError(
                "Tavily tool requires 'langchain-tavily' (preferred) or"
                " 'langchain-community' with 'tavily-python'.\n"
                "Install with: pip install pyagenity[langchain]"
            ) from exc

    def _create_requests_get_tool(self) -> t.Any:
        """
        Construct RequestsGetTool lazily with a basic requests wrapper.

        Note: Requests tools require an explicit wrapper instance and, for safety,
        default to disallowing dangerous requests. Here we opt-in to allow GET
        requests by setting allow_dangerous_requests=True to make the tool usable
        in agent contexts. Consider tightening this in your application.

        Returns:
            Any: RequestsGetTool instance.

        Raises:
            ImportError: If RequestsGetTool cannot be imported.
        """
        try:
            req_tool_mod = importlib.import_module("langchain_community.tools.requests.tool")
            util_mod = importlib.import_module("langchain_community.utilities.requests")
            wrapper = util_mod.TextRequestsWrapper(headers={})  # type: ignore[attr-defined]
            return req_tool_mod.RequestsGetTool(
                requests_wrapper=wrapper,
                allow_dangerous_requests=True,
            )
        except Exception as exc:  # ImportError or runtime
            raise ImportError(
                "Requests tool requires 'langchain-community'.\n"
                "Install with: pip install pyagenity[langchain]"
            ) from exc

    def _try_autoload_defaults(self) -> None:
        """
        Best-effort autoload of a couple of common tools.

        This keeps prior behavior available while allowing users to register
        arbitrary tools. Failures are logged but non-fatal.

        Returns:
            None
        """
        # Tavily search
        try:
            tavily = self._create_tavily_search_tool()
            self.register_tool(tavily, name="tavily_search")
        except Exception as exc:
            logger.debug("Skipping Tavily autoload: %s", exc)

        # Requests GET
        try:
            rget = self._create_requests_get_tool()
            self.register_tool(rget, name="requests_get")
        except Exception as exc:
            logger.debug("Skipping requests_get autoload: %s", exc)
Functions
__init__
__init__(*, autoload_default_tools=True)

Initialize LangChainAdapter.

Parameters:

Name Type Description Default
autoload_default_tools bool

Whether to autoload default tools if registry is empty.

True

Raises:

Type Description
ImportError

If langchain-core is not installed.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def __init__(self, *, autoload_default_tools: bool = True) -> None:
    """
    Initialize LangChainAdapter.

    Args:
        autoload_default_tools (bool): Whether to autoload default tools if registry is empty.

    Raises:
        ImportError: If langchain-core is not installed.
    """
    if not HAS_LANGCHAIN:
        raise ImportError(
            "LangChainAdapter requires 'langchain-core' and optional integrations.\n"
            "Install with: pip install pyagenity[langchain]"
        )
    self._registry: dict[str, LangChainToolWrapper] = {}
    self._autoload = autoload_default_tools
execute
execute(*, name, arguments)

Execute a supported LangChain tool and normalize the response.

Parameters:

Name Type Description Default
name str

Name of the tool to execute.

required
arguments dict[str, Any]

Arguments for the tool.

required

Returns:

Type Description
dict[str, Any]

dict[str, Any]: Normalized response dict with keys: successful, data, error.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def execute(self, *, name: str, arguments: dict[str, t.Any]) -> dict[str, t.Any]:
    """
    Execute a supported LangChain tool and normalize the response.

    Args:
        name (str): Name of the tool to execute.
        arguments (dict[str, Any]): Arguments for the tool.

    Returns:
        dict[str, Any]: Normalized response dict with keys: successful, data, error.
    """
    if name not in self._registry and self._autoload:
        # Late autoload attempt in case discovery wasn't called first
        self._try_autoload_defaults()

    wrapper = self._registry.get(name)
    if not wrapper:
        return {"successful": False, "data": None, "error": f"Unknown LangChain tool: {name}"}
    return wrapper.execute(arguments)
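The normalized result contract can be illustrated without langchain installed; `normalize_tool_result` below is a hypothetical stand-in that mirrors the keys `execute` returns:

```python
def normalize_tool_result(successful, data=None, error=None):
    """Hypothetical stand-in mirroring the adapter's result contract."""
    return {"successful": bool(successful), "data": data, "error": error}


# An unknown tool name yields a failure dict rather than an exception.
unknown = normalize_tool_result(False, error="Unknown LangChain tool: web_search")
```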
is_available staticmethod
is_available()

Return True if langchain-core is importable.

Returns:

Name Type Description
bool bool

True if langchain-core is available, False otherwise.

Source code in pyagenity/adapters/tools/langchain_adapter.py
@staticmethod
def is_available() -> bool:
    """
    Return True if langchain-core is importable.

    Returns:
        bool: True if langchain-core is available, False otherwise.
    """
    return HAS_LANGCHAIN
list_tools_for_llm
list_tools_for_llm()

Return a list of function-calling formatted tool schemas.

If the registry is empty and autoload is enabled, the adapter attempts to autoload a couple of common tools for convenience (tavily_search, requests_get).

Returns:

Type Description
list[dict[str, Any]]

list[dict[str, Any]]: List of tool schemas in function-calling format.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def list_tools_for_llm(self) -> list[dict[str, t.Any]]:
    """
    Return a list of function-calling formatted tool schemas.

    If registry is empty and autoload is enabled, attempt to autoload a
    couple of common tools for convenience (tavily_search, requests_get).

    Returns:
        list[dict[str, Any]]: List of tool schemas in function-calling format.
    """
    if not self._registry and self._autoload:
        self._try_autoload_defaults()

    return [wrapper.to_schema() for wrapper in self._registry.values()]
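As a rough illustration of the schema shape each wrapper emits (the tool name and parameters below are made up), plus the kind of structural check a caller might apply:

```python
def is_function_schema(obj) -> bool:
    """Structural check for the function-calling schema shape."""
    return (
        isinstance(obj, dict)
        and obj.get("type") == "function"
        and isinstance(obj.get("function"), dict)
        and "name" in obj["function"]
    )


# Illustrative schema for a hypothetical registered tool.
tool_schema = {
    "type": "function",
    "function": {
        "name": "tavily_search",
        "description": "Search the web via Tavily.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    },
}
```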
register_tool
register_tool(tool, *, name=None, description=None)

Register a tool instance and return the resolved name used for exposure.

Parameters:

Name Type Description Default
tool Any

Tool instance to register.

required
name str | None

Optional override for tool name.

None
description str | None

Optional override for tool description.

None

Returns:

Name Type Description
str str

The resolved name used for exposure.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def register_tool(
    self,
    tool: t.Any,
    *,
    name: str | None = None,
    description: str | None = None,
) -> str:
    """
    Register a tool instance and return the resolved name used for exposure.

    Args:
        tool (Any): Tool instance to register.
        name (str | None): Optional override for tool name.
        description (str | None): Optional override for tool description.

    Returns:
        str: The resolved name used for exposure.
    """
    wrapper = LangChainToolWrapper(tool, name=name, description=description)
    self._registry[wrapper.name] = wrapper
    return wrapper.name
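Because registration is duck-typed, any object exposing `name`, `description`, and a `run` method can be passed; `EchoTool` below is purely illustrative, not part of the API:

```python
class EchoTool:
    """Illustrative duck-typed tool: just name, description, and run()."""

    name = "echo"
    description = "Echo the input back unchanged."

    def run(self, text: str) -> str:
        return text


# With langchain-core installed, registration would look like:
#   adapter = LangChainAdapter(autoload_default_tools=False)
#   resolved = adapter.register_tool(EchoTool())  # resolved == "echo"
```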
register_tools
register_tools(tools)

Register multiple tool instances.

Parameters:

Name Type Description Default
tools list[Any]

List of tool instances to register.

required

Returns:

Type Description
list[str]

list[str]: List of resolved names for the registered tools.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def register_tools(self, tools: list[t.Any]) -> list[str]:
    """
    Register multiple tool instances.

    Args:
        tools (list[Any]): List of tool instances to register.

    Returns:
        list[str]: List of resolved names for the registered tools.
    """
    names: list[str] = []
    for tool in tools:
        names.append(self.register_tool(tool))
    return names
Modules
composio_adapter

Composio adapter for PyAgenity.

This module provides a thin wrapper around the Composio Python SDK to:

- Fetch tools formatted for LLM function calling (matching ToolNode format)
- Execute Composio tools directly

The dependency is optional. Install with: pip install pyagenity[composio]

Usage outline

adapter = ComposioAdapter(api_key=os.environ["COMPOSIO_API_KEY"])  # optional key
result = adapter.execute(
    slug="GITHUB_LIST_STARGAZERS",
    arguments={"owner": "ComposioHQ", "repo": "composio"},
    user_id="user-123",
)

Classes:

Name Description
ComposioAdapter

Adapter around Composio Python SDK.

Attributes:

Name Type Description
HAS_COMPOSIO
logger
Attributes
HAS_COMPOSIO module-attribute
HAS_COMPOSIO = True
logger module-attribute
logger = getLogger(__name__)
Classes
ComposioAdapter

Adapter around Composio Python SDK.

Notes on SDK methods used (from docs):

- composio.tools.get(user_id=..., tools=[...]/toolkits=[...]/search=..., scopes=..., limit=...)
  Returns tools formatted for providers or agent frameworks; includes schema.
- composio.tools.get_raw_composio_tools(...)
  Returns raw tool schemas including input_parameters.
- composio.tools.execute(slug, arguments, user_id=..., connected_account_id=..., ...)
  Executes a tool and returns a dict like {data, successful, error}.

Methods:

Name Description
__init__

Initialize the ComposioAdapter.

execute

Execute a Composio tool and return a normalized response dict.

is_available

Return True if composio SDK is importable.

list_raw_tools_for_llm

Return raw Composio tool schemas mapped to function-calling format.

list_tools_for_llm

Return tools formatted for LLM function-calling.

Source code in pyagenity/adapters/tools/composio_adapter.py
class ComposioAdapter:
    """Adapter around Composio Python SDK.

    Notes on SDK methods used (from docs):
    - composio.tools.get(user_id=..., tools=[...]/toolkits=[...]/search=..., scopes=..., limit=...)
        Returns tools formatted for providers or agent frameworks; includes schema.
    - composio.tools.get_raw_composio_tools(...)
        Returns raw tool schemas including input_parameters.
    - composio.tools.execute(slug, arguments, user_id=..., connected_account_id=..., ...)
        Executes a tool and returns a dict like {data, successful, error}.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        provider: t.Any | None = None,
        file_download_dir: str | None = None,
        toolkit_versions: t.Any | None = None,
    ) -> None:
        """
        Initialize the ComposioAdapter.

        Args:
            api_key (str | None): Optional API key for Composio.
            provider (Any | None): Optional provider integration.
            file_download_dir (str | None): Directory for auto file handling.
            toolkit_versions (Any | None): Toolkit version overrides.

        Raises:
            ImportError: If composio SDK is not installed.
        """
        if not HAS_COMPOSIO:
            raise ImportError(
                "ComposioAdapter requires 'composio' package. Install with: "
                "pip install pyagenity[composio]"
            )

        self._composio = Composio(  # type: ignore[call-arg]
            api_key=api_key,
            provider=provider,
            file_download_dir=file_download_dir,
            toolkit_versions=toolkit_versions,
        )

    @staticmethod
    def is_available() -> bool:
        """
        Return True if composio SDK is importable.

        Returns:
            bool: True if composio SDK is available, False otherwise.
        """
        return HAS_COMPOSIO

    def list_tools_for_llm(
        self,
        *,
        user_id: str,
        tool_slugs: list[str] | None = None,
        toolkits: list[str] | None = None,
        search: str | None = None,
        scopes: list[str] | None = None,
        limit: int | None = None,
    ) -> list[dict[str, t.Any]]:
        """
        Return tools formatted for LLM function-calling.

        Args:
            user_id (str): User ID for tool discovery.
            tool_slugs (list[str] | None): Optional list of tool slugs.
            toolkits (list[str] | None): Optional list of toolkits.
            search (str | None): Optional search string.
            scopes (list[str] | None): Optional scopes.
            limit (int | None): Optional limit on number of tools.

        Returns:
            list[dict[str, Any]]: List of tools in function-calling format.
        """
        # Prefer the provider-wrapped format when available
        tools = self._composio.tools.get(
            user_id=user_id,
            tools=tool_slugs,  # type: ignore[arg-type]
            toolkits=toolkits,  # type: ignore[arg-type]
            search=search,
            scopes=scopes,
            limit=limit,
        )

        # The provider-wrapped output may already be in the desired structure.
        # We'll detect and pass-through; otherwise convert using raw schemas.
        formatted: list[dict[str, t.Any]] = []
        for t_obj in tools if isinstance(tools, list) else []:
            try:
                if (
                    isinstance(t_obj, dict)
                    and t_obj.get("type") == "function"
                    and "function" in t_obj
                ):
                    formatted.append(t_obj)
                else:
                    # Fallback: try to pull minimal fields
                    fn = t_obj.get("function", {}) if isinstance(t_obj, dict) else {}
                    if fn.get("name") and fn.get("parameters"):
                        formatted.append({"type": "function", "function": fn})
            except Exception as exc:
                logger.debug("Skipping non-conforming Composio tool wrapper: %s", exc)
                continue

        if formatted:
            return formatted

        # Fallback to raw schemas and convert manually
        formatted.extend(
            self.list_raw_tools_for_llm(
                tool_slugs=tool_slugs, toolkits=toolkits, search=search, scopes=scopes, limit=limit
            )
        )

        return formatted

    def list_raw_tools_for_llm(
        self,
        *,
        tool_slugs: list[str] | None = None,
        toolkits: list[str] | None = None,
        search: str | None = None,
        scopes: list[str] | None = None,
        limit: int | None = None,
    ) -> list[dict[str, t.Any]]:
        """
        Return raw Composio tool schemas mapped to function-calling format.

        Args:
            tool_slugs (list[str] | None): Optional list of tool slugs.
            toolkits (list[str] | None): Optional list of toolkits.
            search (str | None): Optional search string.
            scopes (list[str] | None): Optional scopes.
            limit (int | None): Optional limit on number of tools.

        Returns:
            list[dict[str, Any]]: List of raw tool schemas in function-calling format.
        """
        formatted: list[dict[str, t.Any]] = []
        raw_tools = self._composio.tools.get_raw_composio_tools(
            tools=tool_slugs, search=search, toolkits=toolkits, scopes=scopes, limit=limit
        )

        for tool in raw_tools:
            try:
                name = tool.slug  # type: ignore[attr-defined]
                description = getattr(tool, "description", "") or "Composio tool"
                params = getattr(tool, "input_parameters", None)
                if not params:
                    # Minimal shape if schema missing
                    params = {"type": "object", "properties": {}}
                formatted.append(
                    {
                        "type": "function",
                        "function": {
                            "name": name,
                            "description": description,
                            "parameters": params,
                        },
                    }
                )
            except Exception as e:
                logger.warning("Failed to map Composio tool schema: %s", e)
                continue
        return formatted

    def execute(
        self,
        *,
        slug: str,
        arguments: dict[str, t.Any],
        user_id: str | None = None,
        connected_account_id: str | None = None,
        custom_auth_params: dict[str, t.Any] | None = None,
        custom_connection_data: dict[str, t.Any] | None = None,
        text: str | None = None,
        version: str | None = None,
        toolkit_versions: t.Any | None = None,
        modifiers: t.Any | None = None,
    ) -> dict[str, t.Any]:
        """
        Execute a Composio tool and return a normalized response dict.

        Args:
            slug (str): Tool slug to execute.
            arguments (dict[str, Any]): Arguments for the tool.
            user_id (str | None): Optional user ID.
            connected_account_id (str | None): Optional connected account ID.
            custom_auth_params (dict[str, Any] | None): Optional custom auth params.
            custom_connection_data (dict[str, Any] | None): Optional custom connection data.
            text (str | None): Optional text input.
            version (str | None): Optional version.
            toolkit_versions (Any | None): Optional toolkit versions.
            modifiers (Any | None): Optional modifiers.

        Returns:
            dict[str, Any]: Normalized response dict with keys: successful, data, error.
        """
        resp = self._composio.tools.execute(
            slug=slug,
            arguments=arguments,
            user_id=user_id,
            connected_account_id=connected_account_id,
            custom_auth_params=custom_auth_params,
            custom_connection_data=custom_connection_data,
            text=text,
            version=version,
            toolkit_versions=toolkit_versions,
            modifiers=modifiers,
        )

        # The SDK returns a TypedDict-like object; ensure plain dict
        if hasattr(resp, "copy") and not isinstance(resp, dict):  # e.g., TypedDict proxy
            try:
                resp = dict(resp)  # type: ignore[assignment]
            except Exception as exc:
                logger.debug("Could not coerce Composio response to dict: %s", exc)

        # Normalize key presence
        successful = bool(resp.get("successful", False))  # type: ignore[arg-type]
        data = resp.get("data")
        error = resp.get("error")
        return {"successful": successful, "data": data, "error": error}
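The raw-schema mapping performed by `list_raw_tools_for_llm` can be sketched against a stand-in tool object; `SimpleNamespace` mimics the attribute access on the SDK's tool objects, so this is an assumption-laden sketch rather than the SDK itself:

```python
from types import SimpleNamespace


def to_function_schema(tool) -> dict:
    """Map a raw tool object to function-calling format (mirrors the adapter)."""
    # Fall back to a minimal empty-object schema when parameters are missing.
    params = getattr(tool, "input_parameters", None) or {"type": "object", "properties": {}}
    return {
        "type": "function",
        "function": {
            "name": tool.slug,
            "description": getattr(tool, "description", "") or "Composio tool",
            "parameters": params,
        },
    }


# Stand-in for an SDK tool object with a missing schema and description.
raw = SimpleNamespace(slug="GITHUB_LIST_STARGAZERS", description=None, input_parameters=None)
schema = to_function_schema(raw)
```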
Functions
__init__
__init__(*, api_key=None, provider=None, file_download_dir=None, toolkit_versions=None)

Initialize the ComposioAdapter.

Parameters:

Name Type Description Default
api_key str | None

Optional API key for Composio.

None
provider Any | None

Optional provider integration.

None
file_download_dir str | None

Directory for auto file handling.

None
toolkit_versions Any | None

Toolkit version overrides.

None

Raises:

Type Description
ImportError

If composio SDK is not installed.

Source code in pyagenity/adapters/tools/composio_adapter.py
def __init__(
    self,
    *,
    api_key: str | None = None,
    provider: t.Any | None = None,
    file_download_dir: str | None = None,
    toolkit_versions: t.Any | None = None,
) -> None:
    """
    Initialize the ComposioAdapter.

    Args:
        api_key (str | None): Optional API key for Composio.
        provider (Any | None): Optional provider integration.
        file_download_dir (str | None): Directory for auto file handling.
        toolkit_versions (Any | None): Toolkit version overrides.

    Raises:
        ImportError: If composio SDK is not installed.
    """
    if not HAS_COMPOSIO:
        raise ImportError(
            "ComposioAdapter requires 'composio' package. Install with: "
            "pip install pyagenity[composio]"
        )

    self._composio = Composio(  # type: ignore[call-arg]
        api_key=api_key,
        provider=provider,
        file_download_dir=file_download_dir,
        toolkit_versions=toolkit_versions,
    )
execute
execute(*, slug, arguments, user_id=None, connected_account_id=None, custom_auth_params=None, custom_connection_data=None, text=None, version=None, toolkit_versions=None, modifiers=None)

Execute a Composio tool and return a normalized response dict.

Parameters:

Name Type Description Default
slug str

Tool slug to execute.

required
arguments dict[str, Any]

Arguments for the tool.

required
user_id str | None

Optional user ID.

None
connected_account_id str | None

Optional connected account ID.

None
custom_auth_params dict[str, Any] | None

Optional custom auth params.

None
custom_connection_data dict[str, Any] | None

Optional custom connection data.

None
text str | None

Optional text input.

None
version str | None

Optional version.

None
toolkit_versions Any | None

Optional toolkit versions.

None
modifiers Any | None

Optional modifiers.

None

Returns:

Type Description
dict[str, Any]

dict[str, Any]: Normalized response dict with keys: successful, data, error.

Source code in pyagenity/adapters/tools/composio_adapter.py
def execute(
    self,
    *,
    slug: str,
    arguments: dict[str, t.Any],
    user_id: str | None = None,
    connected_account_id: str | None = None,
    custom_auth_params: dict[str, t.Any] | None = None,
    custom_connection_data: dict[str, t.Any] | None = None,
    text: str | None = None,
    version: str | None = None,
    toolkit_versions: t.Any | None = None,
    modifiers: t.Any | None = None,
) -> dict[str, t.Any]:
    """
    Execute a Composio tool and return a normalized response dict.

    Args:
        slug (str): Tool slug to execute.
        arguments (dict[str, Any]): Arguments for the tool.
        user_id (str | None): Optional user ID.
        connected_account_id (str | None): Optional connected account ID.
        custom_auth_params (dict[str, Any] | None): Optional custom auth params.
        custom_connection_data (dict[str, Any] | None): Optional custom connection data.
        text (str | None): Optional text input.
        version (str | None): Optional version.
        toolkit_versions (Any | None): Optional toolkit versions.
        modifiers (Any | None): Optional modifiers.

    Returns:
        dict[str, Any]: Normalized response dict with keys: successful, data, error.
    """
    resp = self._composio.tools.execute(
        slug=slug,
        arguments=arguments,
        user_id=user_id,
        connected_account_id=connected_account_id,
        custom_auth_params=custom_auth_params,
        custom_connection_data=custom_connection_data,
        text=text,
        version=version,
        toolkit_versions=toolkit_versions,
        modifiers=modifiers,
    )

    # The SDK returns a TypedDict-like object; ensure plain dict
    if hasattr(resp, "copy") and not isinstance(resp, dict):  # e.g., TypedDict proxy
        try:
            resp = dict(resp)  # type: ignore[assignment]
        except Exception as exc:
            logger.debug("Could not coerce Composio response to dict: %s", exc)

    # Normalize key presence
    successful = bool(resp.get("successful", False))  # type: ignore[arg-type]
    data = resp.get("data")
    error = resp.get("error")
    return {"successful": successful, "data": data, "error": error}
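The tail-end normalization (coerce a mapping-like SDK response to a plain dict, then pin the three keys) can be sketched in isolation; `normalize_composio_response` is a hypothetical helper mirroring the logic above:

```python
def normalize_composio_response(resp) -> dict:
    """Coerce a mapping-like response to a plain dict and normalize keys."""
    if not isinstance(resp, dict):
        try:
            resp = dict(resp)
        except Exception:
            resp = {}
    return {
        "successful": bool(resp.get("successful", False)),
        "data": resp.get("data"),
        "error": resp.get("error"),
    }
```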
is_available staticmethod
is_available()

Return True if composio SDK is importable.

Returns:

Name Type Description
bool bool

True if composio SDK is available, False otherwise.

Source code in pyagenity/adapters/tools/composio_adapter.py
@staticmethod
def is_available() -> bool:
    """
    Return True if composio SDK is importable.

    Returns:
        bool: True if composio SDK is available, False otherwise.
    """
    return HAS_COMPOSIO
list_raw_tools_for_llm
list_raw_tools_for_llm(*, tool_slugs=None, toolkits=None, search=None, scopes=None, limit=None)

Return raw Composio tool schemas mapped to function-calling format.

Parameters:

Name Type Description Default
tool_slugs list[str] | None

Optional list of tool slugs.

None
toolkits list[str] | None

Optional list of toolkits.

None
search str | None

Optional search string.

None
scopes list[str] | None

Optional scopes.

None
limit int | None

Optional limit on number of tools.

None

Returns:

Type Description
list[dict[str, Any]]

list[dict[str, Any]]: List of raw tool schemas in function-calling format.

Source code in pyagenity/adapters/tools/composio_adapter.py
def list_raw_tools_for_llm(
    self,
    *,
    tool_slugs: list[str] | None = None,
    toolkits: list[str] | None = None,
    search: str | None = None,
    scopes: list[str] | None = None,
    limit: int | None = None,
) -> list[dict[str, t.Any]]:
    """
    Return raw Composio tool schemas mapped to function-calling format.

    Args:
        tool_slugs (list[str] | None): Optional list of tool slugs.
        toolkits (list[str] | None): Optional list of toolkits.
        search (str | None): Optional search string.
        scopes (list[str] | None): Optional scopes.
        limit (int | None): Optional limit on number of tools.

    Returns:
        list[dict[str, Any]]: List of raw tool schemas in function-calling format.
    """
    formatted: list[dict[str, t.Any]] = []
    raw_tools = self._composio.tools.get_raw_composio_tools(
        tools=tool_slugs, search=search, toolkits=toolkits, scopes=scopes, limit=limit
    )

    for tool in raw_tools:
        try:
            name = tool.slug  # type: ignore[attr-defined]
            description = getattr(tool, "description", "") or "Composio tool"
            params = getattr(tool, "input_parameters", None)
            if not params:
                # Minimal shape if schema missing
                params = {"type": "object", "properties": {}}
            formatted.append(
                {
                    "type": "function",
                    "function": {
                        "name": name,
                        "description": description,
                        "parameters": params,
                    },
                }
            )
        except Exception as e:
            logger.warning("Failed to map Composio tool schema: %s", e)
            continue
    return formatted
list_tools_for_llm
list_tools_for_llm(*, user_id, tool_slugs=None, toolkits=None, search=None, scopes=None, limit=None)

Return tools formatted for LLM function-calling.

Parameters:

Name Type Description Default
user_id str

User ID for tool discovery.

required
tool_slugs list[str] | None

Optional list of tool slugs.

None
toolkits list[str] | None

Optional list of toolkits.

None
search str | None

Optional search string.

None
scopes list[str] | None

Optional scopes.

None
limit int | None

Optional limit on number of tools.

None

Returns:

Type Description
list[dict[str, Any]]

list[dict[str, Any]]: List of tools in function-calling format.

Source code in pyagenity/adapters/tools/composio_adapter.py
def list_tools_for_llm(
    self,
    *,
    user_id: str,
    tool_slugs: list[str] | None = None,
    toolkits: list[str] | None = None,
    search: str | None = None,
    scopes: list[str] | None = None,
    limit: int | None = None,
) -> list[dict[str, t.Any]]:
    """
    Return tools formatted for LLM function-calling.

    Args:
        user_id (str): User ID for tool discovery.
        tool_slugs (list[str] | None): Optional list of tool slugs.
        toolkits (list[str] | None): Optional list of toolkits.
        search (str | None): Optional search string.
        scopes (list[str] | None): Optional scopes.
        limit (int | None): Optional limit on number of tools.

    Returns:
        list[dict[str, Any]]: List of tools in function-calling format.
    """
    # Prefer the provider-wrapped format when available
    tools = self._composio.tools.get(
        user_id=user_id,
        tools=tool_slugs,  # type: ignore[arg-type]
        toolkits=toolkits,  # type: ignore[arg-type]
        search=search,
        scopes=scopes,
        limit=limit,
    )

    # The provider-wrapped output may already be in the desired structure.
    # We'll detect and pass-through; otherwise convert using raw schemas.
    formatted: list[dict[str, t.Any]] = []
    for t_obj in tools if isinstance(tools, list) else []:
        try:
            if (
                isinstance(t_obj, dict)
                and t_obj.get("type") == "function"
                and "function" in t_obj
            ):
                formatted.append(t_obj)
            else:
                # Fallback: try to pull minimal fields
                fn = t_obj.get("function", {}) if isinstance(t_obj, dict) else {}
                if fn.get("name") and fn.get("parameters"):
                    formatted.append({"type": "function", "function": fn})
        except Exception as exc:
            logger.debug("Skipping non-conforming Composio tool wrapper: %s", exc)
            continue

    if formatted:
        return formatted

    # Fallback to raw schemas and convert manually
    formatted.extend(
        self.list_raw_tools_for_llm(
            tool_slugs=tool_slugs, toolkits=toolkits, search=search, scopes=scopes, limit=limit
        )
    )

    return formatted
langchain_adapter

LangChain adapter for PyAgenity (generic wrapper, registry-based).

This adapter mirrors the spirit of Google's ADK LangChain wrapper: you register any LangChain tool (BaseTool/StructuredTool), or a duck-typed object that exposes a run/_run method, and it is exposed to PyAgenity in the uniform function-calling schema that ToolNode expects.

Key points:

- Register arbitrary tools at runtime via register_tool / register_tools.
- Tool schemas are derived from tool.args (when available) or inferred from the tool's pydantic args_schema; otherwise, we fall back to a minimal best-effort schema inferred from the wrapped function signature.
- Execution prefers invoke (Runnable interface) and falls back to run/_run or to calling a wrapped function with kwargs.

Optional install

pip install pyagenity[langchain]

Backward-compat convenience:

- For continuity with prior versions, the adapter can auto-register two common tools (tavily_search and requests_get) if autoload_default_tools is True and no user-registered tools exist. You can disable this by passing autoload_default_tools=False to the constructor.

Classes:

Name Description
LangChainAdapter

Generic registry-based LangChain adapter.

LangChainToolWrapper

Wrap a LangChain tool or a duck-typed tool into a uniform interface.

Attributes:

Name Type Description
HAS_LANGCHAIN
logger
Attributes
HAS_LANGCHAIN module-attribute
HAS_LANGCHAIN = find_spec('langchain_core') is not None
logger module-attribute
logger = getLogger(__name__)
Classes
LangChainAdapter

Generic registry-based LangChain adapter.

Notes:

- Avoids importing heavy integrations until needed (lazy default autoload).
- Normalizes schemas and execution results into simple dicts.
- Allows arbitrary tool registration instead of hardcoding a tiny set.

Methods:

Name Description
__init__

Initialize LangChainAdapter.

execute

Execute a supported LangChain tool and normalize the response.

is_available

Return True if langchain-core is importable.

list_tools_for_llm

Return a list of function-calling formatted tool schemas.

register_tool

Register a tool instance and return the resolved name used for exposure.

register_tools

Register multiple tool instances.

Source code in pyagenity/adapters/tools/langchain_adapter.py
class LangChainAdapter:
    """
    Generic registry-based LangChain adapter.

    Notes:
        - Avoids importing heavy integrations until needed (lazy default autoload).
        - Normalizes schemas and execution results into simple dicts.
        - Allows arbitrary tool registration instead of hardcoding a tiny set.
    """

    def __init__(self, *, autoload_default_tools: bool = True) -> None:
        """
        Initialize LangChainAdapter.

        Args:
            autoload_default_tools (bool): Whether to autoload default tools if registry is empty.

        Raises:
            ImportError: If langchain-core is not installed.
        """
        if not HAS_LANGCHAIN:
            raise ImportError(
                "LangChainAdapter requires 'langchain-core' and optional integrations.\n"
                "Install with: pip install pyagenity[langchain]"
            )
        self._registry: dict[str, LangChainToolWrapper] = {}
        self._autoload = autoload_default_tools

    @staticmethod
    def is_available() -> bool:
        """
        Return True if langchain-core is importable.

        Returns:
            bool: True if langchain-core is available, False otherwise.
        """
        return HAS_LANGCHAIN

    # ------------------------
    # Discovery
    # ------------------------
    def list_tools_for_llm(self) -> list[dict[str, t.Any]]:
        """
        Return a list of function-calling formatted tool schemas.

        If registry is empty and autoload is enabled, attempt to autoload a
        couple of common tools for convenience (tavily_search, requests_get).

        Returns:
            list[dict[str, Any]]: List of tool schemas in function-calling format.
        """
        if not self._registry and self._autoload:
            self._try_autoload_defaults()

        return [wrapper.to_schema() for wrapper in self._registry.values()]

    # ------------------------
    # Execute
    # ------------------------
    def execute(self, *, name: str, arguments: dict[str, t.Any]) -> dict[str, t.Any]:
        """
        Execute a supported LangChain tool and normalize the response.

        Args:
            name (str): Name of the tool to execute.
            arguments (dict[str, Any]): Arguments for the tool.

        Returns:
            dict[str, Any]: Normalized response dict with keys: successful, data, error.
        """
        if name not in self._registry and self._autoload:
            # Late autoload attempt in case discovery wasn't called first
            self._try_autoload_defaults()

        wrapper = self._registry.get(name)
        if not wrapper:
            return {"successful": False, "data": None, "error": f"Unknown LangChain tool: {name}"}
        return wrapper.execute(arguments)

    # ------------------------
    # Internals
    # ------------------------
    def register_tool(
        self,
        tool: t.Any,
        *,
        name: str | None = None,
        description: str | None = None,
    ) -> str:
        """
        Register a tool instance and return the resolved name used for exposure.

        Args:
            tool (Any): Tool instance to register.
            name (str | None): Optional override for tool name.
            description (str | None): Optional override for tool description.

        Returns:
            str: The resolved name used for exposure.
        """
        wrapper = LangChainToolWrapper(tool, name=name, description=description)
        self._registry[wrapper.name] = wrapper
        return wrapper.name

    def register_tools(self, tools: list[t.Any]) -> list[str]:
        """
        Register multiple tool instances.

        Args:
            tools (list[Any]): List of tool instances to register.

        Returns:
            list[str]: List of resolved names for the registered tools.
        """
        names: list[str] = []
        for tool in tools:
            names.append(self.register_tool(tool))
        return names

    def _create_tavily_search_tool(self) -> t.Any:
        """
        Construct Tavily search tool lazily.

        Prefer the new dedicated integration `langchain_tavily.TavilySearch`.
        Fall back to the deprecated community tool if needed.

        Returns:
            Any: Tavily search tool instance.

        Raises:
            ImportError: If Tavily tool cannot be imported.
        """
        # Preferred: langchain-tavily
        try:
            mod = importlib.import_module("langchain_tavily")
            return mod.TavilySearch()  # type: ignore[attr-defined]
        except Exception as exc:
            logger.debug("Preferred langchain_tavily import failed: %s", exc)

        # Fallback: deprecated community tool (still functional for now)
        try:
            mod = importlib.import_module("langchain_community.tools.tavily_search")
            return mod.TavilySearchResults()
        except Exception as exc:  # ImportError or runtime
            raise ImportError(
                "Tavily tool requires 'langchain-tavily' (preferred) or"
                " 'langchain-community' with 'tavily-python'.\n"
                "Install with: pip install pyagenity[langchain]"
            ) from exc

    def _create_requests_get_tool(self) -> t.Any:
        """
        Construct RequestsGetTool lazily with a basic requests wrapper.

        Note: Requests tools require an explicit wrapper instance and, for safety,
        default to disallowing dangerous requests. Here we opt-in to allow GET
        requests by setting allow_dangerous_requests=True to make the tool usable
        in agent contexts. Consider tightening this in your application.

        Returns:
            Any: RequestsGetTool instance.

        Raises:
            ImportError: If RequestsGetTool cannot be imported.
        """
        try:
            req_tool_mod = importlib.import_module("langchain_community.tools.requests.tool")
            util_mod = importlib.import_module("langchain_community.utilities.requests")
            wrapper = util_mod.TextRequestsWrapper(headers={})  # type: ignore[attr-defined]
            return req_tool_mod.RequestsGetTool(
                requests_wrapper=wrapper,
                allow_dangerous_requests=True,
            )
        except Exception as exc:  # ImportError or runtime
            raise ImportError(
                "Requests tool requires 'langchain-community'.\n"
                "Install with: pip install pyagenity[langchain]"
            ) from exc

    def _try_autoload_defaults(self) -> None:
        """
        Best-effort autoload of a couple of common tools.

        This keeps prior behavior available while allowing users to register
        arbitrary tools. Failures are logged but non-fatal.

        Returns:
            None
        """
        # Tavily search
        try:
            tavily = self._create_tavily_search_tool()
            self.register_tool(tavily, name="tavily_search")
        except Exception as exc:
            logger.debug("Skipping Tavily autoload: %s", exc)

        # Requests GET
        try:
            rget = self._create_requests_get_tool()
            self.register_tool(rget, name="requests_get")
        except Exception as exc:
            logger.debug("Skipping requests_get autoload: %s", exc)
Functions
__init__
__init__(*, autoload_default_tools=True)

Initialize LangChainAdapter.

Parameters:

Name Type Description Default
autoload_default_tools bool

Whether to autoload default tools if registry is empty.

True

Raises:

Type Description
ImportError

If langchain-core is not installed.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def __init__(self, *, autoload_default_tools: bool = True) -> None:
    """
    Initialize LangChainAdapter.

    Args:
        autoload_default_tools (bool): Whether to autoload default tools if registry is empty.

    Raises:
        ImportError: If langchain-core is not installed.
    """
    if not HAS_LANGCHAIN:
        raise ImportError(
            "LangChainAdapter requires 'langchain-core' and optional integrations.\n"
            "Install with: pip install pyagenity[langchain]"
        )
    self._registry: dict[str, LangChainToolWrapper] = {}
    self._autoload = autoload_default_tools
execute
execute(*, name, arguments)

Execute a supported LangChain tool and normalize the response.

Parameters:

Name Type Description Default
name str

Name of the tool to execute.

required
arguments dict[str, Any]

Arguments for the tool.

required

Returns:

Type Description
dict[str, Any]

dict[str, Any]: Normalized response dict with keys: successful, data, error.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def execute(self, *, name: str, arguments: dict[str, t.Any]) -> dict[str, t.Any]:
    """
    Execute a supported LangChain tool and normalize the response.

    Args:
        name (str): Name of the tool to execute.
        arguments (dict[str, Any]): Arguments for the tool.

    Returns:
        dict[str, Any]: Normalized response dict with keys: successful, data, error.
    """
    if name not in self._registry and self._autoload:
        # Late autoload attempt in case discovery wasn't called first
        self._try_autoload_defaults()

    wrapper = self._registry.get(name)
    if not wrapper:
        return {"successful": False, "data": None, "error": f"Unknown LangChain tool: {name}"}
    return wrapper.execute(arguments)
is_available staticmethod
is_available()

Return True if langchain-core is importable.

Returns:

Name Type Description
bool bool

True if langchain-core is available, False otherwise.

Source code in pyagenity/adapters/tools/langchain_adapter.py
@staticmethod
def is_available() -> bool:
    """
    Return True if langchain-core is importable.

    Returns:
        bool: True if langchain-core is available, False otherwise.
    """
    return HAS_LANGCHAIN
list_tools_for_llm
list_tools_for_llm()

Return a list of function-calling formatted tool schemas.

If registry is empty and autoload is enabled, attempt to autoload a couple of common tools for convenience (tavily_search, requests_get).

Returns:

Type Description
list[dict[str, Any]]

list[dict[str, Any]]: List of tool schemas in function-calling format.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def list_tools_for_llm(self) -> list[dict[str, t.Any]]:
    """
    Return a list of function-calling formatted tool schemas.

    If registry is empty and autoload is enabled, attempt to autoload a
    couple of common tools for convenience (tavily_search, requests_get).

    Returns:
        list[dict[str, Any]]: List of tool schemas in function-calling format.
    """
    if not self._registry and self._autoload:
        self._try_autoload_defaults()

    return [wrapper.to_schema() for wrapper in self._registry.values()]
register_tool
register_tool(tool, *, name=None, description=None)

Register a tool instance and return the resolved name used for exposure.

Parameters:

Name Type Description Default
tool Any

Tool instance to register.

required
name str | None

Optional override for tool name.

None
description str | None

Optional override for tool description.

None

Returns:

Name Type Description
str str

The resolved name used for exposure.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def register_tool(
    self,
    tool: t.Any,
    *,
    name: str | None = None,
    description: str | None = None,
) -> str:
    """
    Register a tool instance and return the resolved name used for exposure.

    Args:
        tool (Any): Tool instance to register.
        name (str | None): Optional override for tool name.
        description (str | None): Optional override for tool description.

    Returns:
        str: The resolved name used for exposure.
    """
    wrapper = LangChainToolWrapper(tool, name=name, description=description)
    self._registry[wrapper.name] = wrapper
    return wrapper.name
register_tools
register_tools(tools)

Register multiple tool instances.

Parameters:

Name Type Description Default
tools list[Any]

List of tool instances to register.

required

Returns:

Type Description
list[str]

list[str]: List of resolved names for the registered tools.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def register_tools(self, tools: list[t.Any]) -> list[str]:
    """
    Register multiple tool instances.

    Args:
        tools (list[Any]): List of tool instances to register.

    Returns:
        list[str]: List of resolved names for the registered tools.
    """
    names: list[str] = []
    for tool in tools:
        names.append(self.register_tool(tool))
    return names
LangChainToolWrapper

Wrap a LangChain tool or a duck-typed tool into a uniform interface.

Responsibilities
  • Resolve execution entrypoint (invoke/run/_run/callable func)
  • Provide a function-calling schema {name, description, parameters}
  • Execute with dict arguments and return a JSON-serializable result
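The function-calling schema this wrapper produces follows the common OpenAI-style tool format. A hand-written instance for a hypothetical search tool (field values are illustrative) looks like:

```python
# Hypothetical schema for a "tavily_search"-style tool; values are illustrative.
tool_schema = {
    "type": "function",
    "function": {
        "name": "tavily_search",
        "description": "Search the web and return top results.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    },
}

print(tool_schema["function"]["name"])  # tavily_search
```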

Methods:

Name Description
__init__

Initialize LangChainToolWrapper.

execute

Execute the wrapped tool with the provided arguments.

to_schema

Return the function-calling schema for the wrapped tool.

Attributes:

Name Type Description
description
name
Source code in pyagenity/adapters/tools/langchain_adapter.py
class LangChainToolWrapper:
    """
    Wrap a LangChain tool or a duck-typed tool into a uniform interface.

    Responsibilities:
        - Resolve execution entrypoint (invoke/run/_run/callable func)
        - Provide a function-calling schema {name, description, parameters}
        - Execute with dict arguments and return a JSON-serializable result
    """

    def __init__(
        self,
        tool: t.Any,
        *,
        name: str | None = None,
        description: str | None = None,
    ) -> None:
        """
        Initialize LangChainToolWrapper.

        Args:
            tool (Any): The LangChain tool or duck-typed object to wrap.
            name (str | None): Optional override for tool name.
            description (str | None): Optional override for tool description.
        """
        self._tool = tool
        self.name = name or getattr(tool, "name", None) or self._default_name(tool)
        self.description = (
            description
            or getattr(tool, "description", None)
            or f"LangChain tool wrapper for {type(tool).__name__}"
        )
        self._callable = self._resolve_callable(tool)

    @staticmethod
    def _default_name(tool: t.Any) -> str:
        # Prefer class name in snake_case-ish
        cls = type(tool).__name__
        return cls[0].lower() + "".join((c if c.islower() else f"_{c.lower()}") for c in cls[1:])

    @staticmethod
    def _resolve_callable(tool: t.Any) -> t.Callable[..., t.Any] | None:
        # Try StructuredTool.func or coroutine
        try:
            # Avoid importing StructuredTool; duck-type attributes
            if getattr(tool, "func", None) is not None:
                return t.cast(t.Callable[..., t.Any], tool.func)
            if getattr(tool, "coroutine", None) is not None:
                return t.cast(t.Callable[..., t.Any], tool.coroutine)
        except Exception as exc:  # pragma: no cover - defensive
            logger.debug("Ignoring tool callable resolution error: %s", exc)
        # Fallback to run/_run methods as callables
        if hasattr(tool, "_run"):
            return tool._run  # type: ignore[attr-defined]
        if hasattr(tool, "run"):
            return tool.run  # type: ignore[attr-defined]
        # Nothing callable to directly use; rely on invoke/run on execution
        return None

    def _json_schema_from_args_schema(self) -> dict[str, t.Any] | None:
        # LangChain BaseTool typically provides .args (already JSON schema)
        schema = getattr(self._tool, "args", None)
        if isinstance(schema, dict) and schema.get("type") == "object":
            return schema

        # Try args_schema (pydantic v1 or v2)
        args_schema = getattr(self._tool, "args_schema", None)
        if args_schema is None:
            return None
        try:
            # pydantic v2
            if hasattr(args_schema, "model_json_schema"):
                js = args_schema.model_json_schema()  # type: ignore[attr-defined]
            else:  # pydantic v1
                js = args_schema.schema()  # type: ignore[attr-defined]
            # Convert typical pydantic schema to a plain "type: object" with properties
            # Look for properties directly
            props = js.get("properties") or {}
            required = js.get("required") or []
            return {"type": "object", "properties": props, "required": required}
        except Exception:  # pragma: no cover - be tolerant
            return None

    def _infer_schema_from_signature(self) -> dict[str, t.Any]:
        func = self._callable or getattr(self._tool, "invoke", None)
        if func is None or not callable(func):  # last resort empty schema
            return {"type": "object", "properties": {}}

        try:
            sig = inspect.signature(func)
            properties: dict[str, dict[str, t.Any]] = {}
            required: list[str] = []
            for name, param in sig.parameters.items():
                if name in {"self", "run_manager", "config", "callbacks"}:
                    continue
                ann = param.annotation
                json_type: str | None = None
                if ann is not inspect._empty:  # type: ignore[attr-defined]
                    json_type = self._map_annotation_to_json_type(ann)
                prop: dict[str, t.Any] = {}
                if json_type:
                    prop["type"] = json_type
                if param.default is inspect._empty:  # type: ignore[attr-defined]
                    required.append(name)
                properties[name] = prop
            schema: dict[str, t.Any] = {"type": "object", "properties": properties}
            if required:
                schema["required"] = required
            return schema
        except Exception:
            return {"type": "object", "properties": {}}

    @staticmethod
    def _map_annotation_to_json_type(ann: t.Any) -> str | None:
        try:
            origin = t.get_origin(ann) or ann
            mapping = {
                str: "string",
                int: "integer",
                float: "number",
                bool: "boolean",
                list: "array",
                tuple: "array",
                set: "array",
                dict: "object",
            }
            # Typed containers map to base Python containers in get_origin
            return mapping.get(origin)
        except Exception:
            return None

    def to_schema(self) -> dict[str, t.Any]:
        """
        Return the function-calling schema for the wrapped tool.

        Returns:
            dict[str, Any]: Function-calling schema with name, description, parameters.
        """
        schema = self._json_schema_from_args_schema() or self._infer_schema_from_signature()
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": schema,
            },
        }

    def execute(self, arguments: dict[str, t.Any]) -> dict[str, t.Any]:
        """
        Execute the wrapped tool with the provided arguments.

        Args:
            arguments (dict[str, Any]): Arguments to pass to the tool.

        Returns:
            dict[str, Any]: Normalized response dict with keys: successful, data, error.
        """
        try:
            tool = self._tool
            if hasattr(tool, "invoke"):
                result = tool.invoke(arguments)  # type: ignore[misc]
            elif hasattr(tool, "run"):
                result = tool.run(arguments)  # type: ignore[misc]
            elif hasattr(tool, "_run"):
                result = tool._run(arguments)  # type: ignore[attr-defined]
            elif callable(self._callable):
                result = self._callable(**arguments)  # type: ignore[call-arg]
            else:
                raise AttributeError("Tool does not support invoke/run/_run/callable")

            data: t.Any = result
            if not isinstance(result, str | int | float | bool | type(None) | dict | list):
                try:
                    json.dumps(result)
                except Exception:
                    data = str(result)
            return {"successful": True, "data": data, "error": None}
        except Exception as exc:
            logger.error("LangChain wrapped tool '%s' failed: %s", self.name, exc)
            return {"successful": False, "data": None, "error": str(exc)}
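The signature-based fallback in _infer_schema_from_signature can be approximated in isolation. This standard-library-only sketch simplifies the real method (it omits the run_manager/config/callbacks filtering and pydantic paths):

```python
import inspect
import typing as t

# Same Python-type-to-JSON-type mapping used by the annotation mapper.
_JSON_TYPES = {str: "string", int: "integer", float: "number",
               bool: "boolean", list: "array", dict: "object"}


def infer_schema(func) -> dict:
    """Best-effort JSON schema from a function signature (simplified sketch)."""
    props: dict[str, dict] = {}
    required: list[str] = []
    for name, param in inspect.signature(func).parameters.items():
        prop: dict = {}
        ann = param.annotation
        # Typed containers (list[str], dict[str, int], ...) map via get_origin.
        origin = t.get_origin(ann) or ann
        if origin in _JSON_TYPES:
            prop["type"] = _JSON_TYPES[origin]
        if param.default is inspect.Parameter.empty:
            required.append(name)
        props[name] = prop
    schema: dict = {"type": "object", "properties": props}
    if required:
        schema["required"] = required
    return schema


def search(query: str, max_results: int = 5) -> list:
    ...


print(infer_schema(search))
```

Parameters without a default become required; parameters without a mappable annotation get an empty property, just as in the wrapper.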
Attributes
description instance-attribute
description = description or getattr(tool, 'description', None) or f'LangChain tool wrapper for {type(tool).__name__}'
name instance-attribute
name = name or getattr(tool, 'name', None) or _default_name(tool)
Functions
__init__
__init__(tool, *, name=None, description=None)

Initialize LangChainToolWrapper.

Parameters:

Name Type Description Default
tool Any

The LangChain tool or duck-typed object to wrap.

required
name str | None

Optional override for tool name.

None
description str | None

Optional override for tool description.

None
Source code in pyagenity/adapters/tools/langchain_adapter.py
def __init__(
    self,
    tool: t.Any,
    *,
    name: str | None = None,
    description: str | None = None,
) -> None:
    """
    Initialize LangChainToolWrapper.

    Args:
        tool (Any): The LangChain tool or duck-typed object to wrap.
        name (str | None): Optional override for tool name.
        description (str | None): Optional override for tool description.
    """
    self._tool = tool
    self.name = name or getattr(tool, "name", None) or self._default_name(tool)
    self.description = (
        description
        or getattr(tool, "description", None)
        or f"LangChain tool wrapper for {type(tool).__name__}"
    )
    self._callable = self._resolve_callable(tool)
execute
execute(arguments)

Execute the wrapped tool with the provided arguments.

Parameters:

Name Type Description Default
arguments dict[str, Any]

Arguments to pass to the tool.

required

Returns:

Type Description
dict[str, Any]

dict[str, Any]: Normalized response dict with keys: successful, data, error.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def execute(self, arguments: dict[str, t.Any]) -> dict[str, t.Any]:
    """
    Execute the wrapped tool with the provided arguments.

    Args:
        arguments (dict[str, Any]): Arguments to pass to the tool.

    Returns:
        dict[str, Any]: Normalized response dict with keys: successful, data, error.
    """
    try:
        tool = self._tool
        if hasattr(tool, "invoke"):
            result = tool.invoke(arguments)  # type: ignore[misc]
        elif hasattr(tool, "run"):
            result = tool.run(arguments)  # type: ignore[misc]
        elif hasattr(tool, "_run"):
            result = tool._run(arguments)  # type: ignore[attr-defined]
        elif callable(self._callable):
            result = self._callable(**arguments)  # type: ignore[call-arg]
        else:
            raise AttributeError("Tool does not support invoke/run/_run/callable")

        data: t.Any = result
        if not isinstance(result, str | int | float | bool | type(None) | dict | list):
            try:
                json.dumps(result)
            except Exception:
                data = str(result)
        return {"successful": True, "data": data, "error": None}
    except Exception as exc:
        logger.error("LangChain wrapped tool '%s' failed: %s", self.name, exc)
        return {"successful": False, "data": None, "error": str(exc)}
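The successful/data/error envelope in the happy path — where non-JSON-serializable results are stringified — can be reproduced in a few lines (normalize is an illustrative helper, not part of the adapter):

```python
import json


def normalize(result) -> dict:
    """Coerce a tool result into the successful/data/error envelope."""
    data = result
    if not isinstance(result, (str, int, float, bool, type(None), dict, list)):
        try:
            json.dumps(result)  # keep anything JSON can already serialize
        except TypeError:
            data = str(result)  # last resort: stringify opaque objects
    return {"successful": True, "data": data, "error": None}


print(normalize({"hits": 3}))            # dict passes through unchanged
print(normalize(complex(1, 2))["data"])  # (1+2j) — stringified fallback
```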
to_schema
to_schema()

Return the function-calling schema for the wrapped tool.

Returns:

Type Description
dict[str, Any]

dict[str, Any]: Function-calling schema with name, description, parameters.

Source code in pyagenity/adapters/tools/langchain_adapter.py
def to_schema(self) -> dict[str, t.Any]:
    """
    Return the function-calling schema for the wrapped tool.

    Returns:
        dict[str, Any]: Function-calling schema with name, description, parameters.
    """
    schema = self._json_schema_from_args_schema() or self._infer_schema_from_signature()
    return {
        "type": "function",
        "function": {
            "name": self.name,
            "description": self.description,
            "parameters": schema,
        },
    }

checkpointer

Checkpointer adapters for agent state persistence in PyAgenity.

This module exposes unified checkpointing interfaces for agent graphs, supporting in-memory and Postgres-backed persistence. PgCheckpointer is only exported if its dependencies (asyncpg, redis) are available.

Exports

BaseCheckpointer: Abstract base class for checkpointing implementations. InMemoryCheckpointer: In-memory checkpointing for development/testing. PgCheckpointer: Postgres+Redis checkpointing (optional, requires extras).

Usage

PgCheckpointer requires: pip install pyagenity[pg_checkpoint]

Modules:

Name Description
base_checkpointer
in_memory_checkpointer
pg_checkpointer

Classes:

Name Description
BaseCheckpointer

Abstract base class for checkpointing agent state, messages, and threads.

InMemoryCheckpointer

In-memory implementation of BaseCheckpointer.

PgCheckpointer

Implements a checkpointer using PostgreSQL and Redis for persistent and cached state management.

Attributes

__all__ module-attribute
__all__ = ['BaseCheckpointer', 'InMemoryCheckpointer', 'PgCheckpointer']

Classes

BaseCheckpointer

Bases: ABC

Abstract base class for checkpointing agent state, messages, and threads.

This class defines the contract for all checkpointer implementations, supporting both async and sync methods. Subclasses should implement async methods for optimal performance. Sync methods are provided for compatibility.

Usage
  • Async-first design: subclasses should implement async def methods.
  • If a subclass provides only a sync def, it will be executed in a worker thread automatically using asyncio.run.
  • Callers always use the async APIs (await cp.put_state(...), etc.).
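The sync-over-async delegation described above can be sketched with asyncio.run. The class names here are illustrative, not the real base class, which routes through its own run_coroutine helper:

```python
import asyncio
from abc import ABC, abstractmethod


class AsyncFirstBase(ABC):
    """Async-first contract: the sync facade runs the async implementation."""

    def setup(self):
        # Sync callers get the same behavior; the coroutine runs to completion.
        return asyncio.run(self.asetup())

    @abstractmethod
    async def asetup(self):
        raise NotImplementedError


class DemoCheckpointer(AsyncFirstBase):
    async def asetup(self) -> str:
        return "ready"


print(DemoCheckpointer().setup())  # ready
```

Subclasses implement only the async method; the sync variant comes for free from the base class.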

Class Type Parameters:

Name Bound or Constraints Description Default
StateT AgentState

Type of agent state (must inherit from AgentState).

required

Methods:

Name Description
aclean_thread

Clean/delete thread asynchronously.

aclear_state

Clear agent state asynchronously.

adelete_message

Delete a specific message asynchronously.

aget_message

Retrieve a specific message asynchronously.

aget_state

Retrieve agent state asynchronously.

aget_state_cache

Retrieve agent state from cache asynchronously.

aget_thread

Retrieve thread info asynchronously.

alist_messages

List messages asynchronously with optional filtering.

alist_threads

List threads asynchronously with optional filtering.

aput_messages

Store messages asynchronously.

aput_state

Store agent state asynchronously.

aput_state_cache

Store agent state in cache asynchronously.

aput_thread

Store thread info asynchronously.

arelease

Release resources asynchronously.

asetup

Asynchronous setup method for checkpointer.

clean_thread

Clean/delete thread synchronously.

clear_state

Clear agent state synchronously.

delete_message

Delete a specific message synchronously.

get_message

Retrieve a specific message synchronously.

get_state

Retrieve agent state synchronously.

get_state_cache

Retrieve agent state from cache synchronously.

get_thread

Retrieve thread info synchronously.

list_messages

List messages synchronously with optional filtering.

list_threads

List threads synchronously with optional filtering.

put_messages

Store messages synchronously.

put_state

Store agent state synchronously.

put_state_cache

Store agent state in cache synchronously.

put_thread

Store thread info synchronously.

release

Release resources synchronously.

setup

Synchronous setup method for checkpointer.

Source code in pyagenity/checkpointer/base_checkpointer.py
class BaseCheckpointer[StateT: AgentState](ABC):
    """
    Abstract base class for checkpointing agent state, messages, and threads.

    This class defines the contract for all checkpointer implementations, supporting both
    async and sync methods.
    Subclasses should implement async methods for optimal performance.
    Sync methods are provided for compatibility.

    Usage:
        - Async-first design: subclasses should implement `async def` methods.
        - Sync wrappers (`put_state`, `get_state`, etc.) delegate to the async
            implementations via `run_coroutine`.
        - Callers always use the async APIs (`await cp.put_state(...)`, etc.).

    Type Args:
        StateT: Type of agent state (must inherit from AgentState).
    """

    ###########################
    #### SETUP ################
    ###########################
    def setup(self) -> Any:
        """
        Synchronous setup method for checkpointer.

        Returns:
            Any: Implementation-defined setup result.
        """
        return run_coroutine(self.asetup())

    @abstractmethod
    async def asetup(self) -> Any:
        """
        Asynchronous setup method for checkpointer.

        Returns:
            Any: Implementation-defined setup result.
        """
        raise NotImplementedError

    # -------------------------
    # State methods Async
    # -------------------------
    @abstractmethod
    async def aput_state(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store agent state asynchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to store.

        Returns:
            StateT: The stored state object.
        """
        raise NotImplementedError

    @abstractmethod
    async def aget_state(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve agent state asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Retrieved state or None.
        """
        raise NotImplementedError

    @abstractmethod
    async def aclear_state(self, config: dict[str, Any]) -> Any:
        """
        Clear agent state asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any: Implementation-defined result.
        """
        raise NotImplementedError

    @abstractmethod
    async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
        """
        Store agent state in cache asynchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to cache.

        Returns:
            Any | None: Implementation-defined result.
        """
        raise NotImplementedError

    @abstractmethod
    async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve agent state from cache asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Cached state or None.
        """
        raise NotImplementedError

    # -------------------------
    # State methods Sync
    # -------------------------
    def put_state(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store agent state synchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to store.

        Returns:
            StateT: The stored state object.
        """
        return run_coroutine(self.aput_state(config, state))

    def get_state(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve agent state synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Retrieved state or None.
        """
        return run_coroutine(self.aget_state(config))

    def clear_state(self, config: dict[str, Any]) -> Any:
        """
        Clear agent state synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any: Implementation-defined result.
        """
        return run_coroutine(self.aclear_state(config))

    def put_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
        """
        Store agent state in cache synchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to cache.

        Returns:
            Any | None: Implementation-defined result.
        """
        return run_coroutine(self.aput_state_cache(config, state))

    def get_state_cache(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve agent state from cache synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Cached state or None.
        """
        return run_coroutine(self.aget_state_cache(config))

    # -------------------------
    # Message methods async
    # -------------------------
    @abstractmethod
    async def aput_messages(
        self,
        config: dict[str, Any],
        messages: list[Message],
        metadata: dict[str, Any] | None = None,
    ) -> Any:
        """
        Store messages asynchronously.

        Args:
            config (dict): Configuration dictionary.
            messages (list[Message]): List of messages to store.
            metadata (dict, optional): Additional metadata.

        Returns:
            Any: Implementation-defined result.
        """
        raise NotImplementedError

    @abstractmethod
    async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
        """
        Retrieve a specific message asynchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Message: Retrieved message object.
        """
        raise NotImplementedError

    @abstractmethod
    async def alist_messages(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[Message]:
        """
        List messages asynchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[Message]: List of message objects.
        """
        raise NotImplementedError

    @abstractmethod
    async def adelete_message(self, config: dict[str, Any], message_id: str | int) -> Any | None:
        """
        Delete a specific message asynchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Any | None: Implementation-defined result.
        """
        raise NotImplementedError

    # -------------------------
    # Message methods sync
    # -------------------------
    def put_messages(
        self,
        config: dict[str, Any],
        messages: list[Message],
        metadata: dict[str, Any] | None = None,
    ) -> Any:
        """
        Store messages synchronously.

        Args:
            config (dict): Configuration dictionary.
            messages (list[Message]): List of messages to store.
            metadata (dict, optional): Additional metadata.

        Returns:
            Any: Implementation-defined result.
        """
        return run_coroutine(self.aput_messages(config, messages, metadata))

    def get_message(self, config: dict[str, Any], message_id: str | int) -> Message:
        """
        Retrieve a specific message synchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Message: Retrieved message object.
        """
        return run_coroutine(self.aget_message(config, message_id))

    def list_messages(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[Message]:
        """
        List messages synchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[Message]: List of message objects.
        """
        return run_coroutine(self.alist_messages(config, search, offset, limit))

    def delete_message(self, config: dict[str, Any], message_id: str | int) -> Any | None:
        """
        Delete a specific message synchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Any | None: Implementation-defined result.
        """
        return run_coroutine(self.adelete_message(config, message_id))

    # -------------------------
    # Thread methods async
    # -------------------------
    @abstractmethod
    async def aput_thread(
        self,
        config: dict[str, Any],
        thread_info: ThreadInfo,
    ) -> Any | None:
        """
        Store thread info asynchronously.

        Args:
            config (dict): Configuration dictionary.
            thread_info (ThreadInfo): Thread information object.

        Returns:
            Any | None: Implementation-defined result.
        """
        raise NotImplementedError

    @abstractmethod
    async def aget_thread(
        self,
        config: dict[str, Any],
    ) -> ThreadInfo | None:
        """
        Retrieve thread info asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            ThreadInfo | None: Thread information object or None.
        """
        raise NotImplementedError

    @abstractmethod
    async def alist_threads(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[ThreadInfo]:
        """
        List threads asynchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[ThreadInfo]: List of thread information objects.
        """
        raise NotImplementedError

    @abstractmethod
    async def aclean_thread(self, config: dict[str, Any]) -> Any | None:
        """
        Clean/delete thread asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any | None: Implementation-defined result.
        """
        raise NotImplementedError

    # -------------------------
    # Thread methods sync
    # -------------------------
    def put_thread(self, config: dict[str, Any], thread_info: ThreadInfo) -> Any | None:
        """
        Store thread info synchronously.

        Args:
            config (dict): Configuration dictionary.
            thread_info (ThreadInfo): Thread information object.

        Returns:
            Any | None: Implementation-defined result.
        """
        return run_coroutine(self.aput_thread(config, thread_info))

    def get_thread(self, config: dict[str, Any]) -> ThreadInfo | None:
        """
        Retrieve thread info synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            ThreadInfo | None: Thread information object or None.
        """
        return run_coroutine(self.aget_thread(config))

    def list_threads(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[ThreadInfo]:
        """
        List threads synchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[ThreadInfo]: List of thread information objects.
        """
        return run_coroutine(self.alist_threads(config, search, offset, limit))

    def clean_thread(self, config: dict[str, Any]) -> Any | None:
        """
        Clean/delete thread synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any | None: Implementation-defined result.
        """
        return run_coroutine(self.aclean_thread(config))

    # -------------------------
    # Clean Resources
    # -------------------------
    def release(self) -> Any | None:
        """
        Release resources synchronously.

        Returns:
            Any | None: Implementation-defined result.
        """
        return run_coroutine(self.arelease())

    @abstractmethod
    async def arelease(self) -> Any | None:
        """
        Release resources asynchronously.

        Returns:
            Any | None: Implementation-defined result.
        """
        raise NotImplementedError
Functions
aclean_thread abstractmethod async
aclean_thread(config)

Clean/delete thread asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aclean_thread(self, config: dict[str, Any]) -> Any | None:
    """
    Clean/delete thread asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any | None: Implementation-defined result.
    """
    raise NotImplementedError
aclear_state abstractmethod async
aclear_state(config)

Clear agent state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aclear_state(self, config: dict[str, Any]) -> Any:
    """
    Clear agent state asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any: Implementation-defined result.
    """
    raise NotImplementedError
adelete_message abstractmethod async
adelete_message(config, message_id)

Delete a specific message asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def adelete_message(self, config: dict[str, Any], message_id: str | int) -> Any | None:
    """
    Delete a specific message asynchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Any | None: Implementation-defined result.
    """
    raise NotImplementedError
aget_message abstractmethod async
aget_message(config, message_id)

Retrieve a specific message asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
Message Message

Retrieved message object.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a specific message asynchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.
    """
    raise NotImplementedError
aget_state abstractmethod async
aget_state(config)

Retrieve agent state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Retrieved state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aget_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.
    """
    raise NotImplementedError
aget_state_cache abstractmethod async
aget_state_cache(config)

Retrieve agent state from cache asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Cached state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state from cache asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Cached state or None.
    """
    raise NotImplementedError
aget_thread abstractmethod async
aget_thread(config)

Retrieve thread info asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
ThreadInfo | None

ThreadInfo | None: Thread information object or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aget_thread(
    self,
    config: dict[str, Any],
) -> ThreadInfo | None:
    """
    Retrieve thread info asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.
    """
    raise NotImplementedError
alist_messages abstractmethod async
alist_messages(config, search=None, offset=None, limit=None)

List messages asynchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[Message]

list[Message]: List of message objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def alist_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages asynchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.
    """
    raise NotImplementedError
alist_threads abstractmethod async
alist_threads(config, search=None, offset=None, limit=None)

List threads asynchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[ThreadInfo]

list[ThreadInfo]: List of thread information objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def alist_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List threads asynchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.
    """
    raise NotImplementedError
aput_messages abstractmethod async
aput_messages(config, messages, metadata=None)

Store messages asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
messages list[Message]

List of messages to store.

required
metadata dict

Additional metadata.

None

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aput_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> Any:
    """
    Store messages asynchronously.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        Any: Implementation-defined result.
    """
    raise NotImplementedError
aput_state abstractmethod async
aput_state(config, state)

Store agent state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aput_state(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store agent state asynchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.
    """
    raise NotImplementedError
aput_state_cache abstractmethod async
aput_state_cache(config, state)

Store agent state in cache asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
    """
    Store agent state in cache asynchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        Any | None: Implementation-defined result.
    """
    raise NotImplementedError
aput_thread abstractmethod async
aput_thread(config, thread_info)

Store thread info asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aput_thread(
    self,
    config: dict[str, Any],
    thread_info: ThreadInfo,
) -> Any | None:
    """
    Store thread info asynchronously.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        Any | None: Implementation-defined result.
    """
    raise NotImplementedError
arelease abstractmethod async
arelease()

Release resources asynchronously.

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def arelease(self) -> Any | None:
    """
    Release resources asynchronously.

    Returns:
        Any | None: Implementation-defined result.
    """
    raise NotImplementedError
asetup abstractmethod async
asetup()

Asynchronous setup method for checkpointer.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def asetup(self) -> Any:
    """
    Asynchronous setup method for checkpointer.

    Returns:
        Any: Implementation-defined setup result.
    """
    raise NotImplementedError
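Since the setup result is implementation-defined, a typical `asetup` creates its backing structures and is safe to call more than once. The sketch below is a hypothetical, dict-backed illustration of that idempotent pattern; the `SchemaSetup` class and its table names are assumptions, not PyAgenity APIs.

```python
import asyncio


class SchemaSetup:
    """Illustrative asetup: create backing structures once, safe to repeat."""

    def __init__(self) -> None:
        self._ready = False
        self._tables: dict[str, dict] = {}

    async def asetup(self) -> dict[str, dict]:
        if self._ready:
            # Idempotent: repeated setup is a no-op.
            return self._tables
        for name in ("states", "messages", "threads"):
            self._tables[name] = {}
        self._ready = True
        return self._tables


s = SchemaSetup()
tables = asyncio.run(s.asetup())
print(sorted(tables))  # ['messages', 'states', 'threads']
```

A real backend would run `CREATE TABLE IF NOT EXISTS` statements or verify indexes here, but the shape is the same: do the work once, then return quickly.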
clean_thread
clean_thread(config)

Clean/delete thread synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def clean_thread(self, config: dict[str, Any]) -> Any | None:
    """
    Clean/delete thread synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aclean_thread(config))
clear_state
clear_state(config)

Clear agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def clear_state(self, config: dict[str, Any]) -> Any:
    """
    Clear agent state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any: Implementation-defined result.
    """
    return run_coroutine(self.aclear_state(config))
delete_message
delete_message(config, message_id)

Delete a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def delete_message(self, config: dict[str, Any], message_id: str | int) -> Any | None:
    """
    Delete a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.adelete_message(config, message_id))
get_message
get_message(config, message_id)

Retrieve a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
Message Message

Retrieved message object.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.
    """
    return run_coroutine(self.aget_message(config, message_id))
get_state
get_state(config)

Retrieve agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Retrieved state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.
    """
    return run_coroutine(self.aget_state(config))
get_state_cache
get_state_cache(config)

Retrieve agent state from cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Cached state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state from cache synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Cached state or None.
    """
    return run_coroutine(self.aget_state_cache(config))
get_thread
get_thread(config)

Retrieve thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
ThreadInfo | None

ThreadInfo | None: Thread information object or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_thread(self, config: dict[str, Any]) -> ThreadInfo | None:
    """
    Retrieve thread info synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.
    """
    return run_coroutine(self.aget_thread(config))
list_messages
list_messages(config, search=None, offset=None, limit=None)

List messages synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[Message]

list[Message]: List of message objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
def list_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.
    """
    return run_coroutine(self.alist_messages(config, search, offset, limit))
list_threads
list_threads(config, search=None, offset=None, limit=None)

List threads synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[ThreadInfo]

list[ThreadInfo]: List of thread information objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
def list_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List threads synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.
    """
    return run_coroutine(self.alist_threads(config, search, offset, limit))
put_messages
put_messages(config, messages, metadata=None)

Store messages synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
messages list[Message]

List of messages to store.

required
metadata dict

Additional metadata.

None

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> Any:
    """
    Store messages synchronously.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        Any: Implementation-defined result.
    """
    return run_coroutine(self.aput_messages(config, messages, metadata))
put_state
put_state(config, state)

Store agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_state(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store agent state synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.
    """
    return run_coroutine(self.aput_state(config, state))
put_state_cache
put_state_cache(config, state)

Store agent state in cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
    """
    Store agent state in cache synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aput_state_cache(config, state))
put_thread
put_thread(config, thread_info)

Store thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_thread(self, config: dict[str, Any], thread_info: ThreadInfo) -> Any | None:
    """
    Store thread info synchronously.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aput_thread(config, thread_info))
release
release()

Release resources synchronously.

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def release(self) -> Any | None:
    """
    Release resources synchronously.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.arelease())
setup
setup()

Synchronous setup method for checkpointer.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def setup(self) -> Any:
    """
    Synchronous setup method for checkpointer.

    Returns:
        Any: Implementation-defined setup result.
    """
    return run_coroutine(self.asetup())
InMemoryCheckpointer

Bases: BaseCheckpointer[StateT]

In-memory implementation of BaseCheckpointer.

Stores all agent state, messages, and thread info in memory using Python dictionaries. Data is lost when the process ends, so this checkpointer is intended for testing and ephemeral use cases. The design is async-first, using asyncio locks for safe concurrent access.
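The storage pattern is straightforward: plain dictionaries keyed by a string derived from `config["thread_id"]`, with each dictionary guarded by an `asyncio.Lock`. A toy reconstruction of that pattern (a hypothetical `MiniCheckpointer`, not the real class):

```python
import asyncio
from collections import defaultdict

class MiniCheckpointer:
    """Toy sketch of the in-memory pattern: dicts keyed by thread_id,
    guarded by an asyncio lock. Not the real InMemoryCheckpointer."""

    def __init__(self):
        self._states: dict[str, dict] = {}
        self._messages: dict[str, list] = defaultdict(list)
        self._lock = asyncio.Lock()

    async def aput_state(self, config, state):
        key = str(config.get("thread_id", ""))
        async with self._lock:
            self._states[key] = state
            return state

    async def aget_state(self, config):
        key = str(config.get("thread_id", ""))
        async with self._lock:
            return self._states.get(key)

async def demo():
    cp = MiniCheckpointer()
    await cp.aput_state({"thread_id": "t1"}, {"step": 1})
    return await cp.aget_state({"thread_id": "t1"})

print(asyncio.run(demo()))  # {'step': 1}
```

Because everything lives in process memory, restarting the process discards all checkpoints; persistence requires one of the durable checkpointer backends.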

Attributes:

Name Type Description
_states dict

Stores agent states by thread key.

_state_cache dict

Stores cached agent states by thread key.

_messages dict

Stores messages by thread key.

_message_metadata dict

Stores message metadata by thread key.

_threads dict

Stores thread info by thread key.

_state_lock Lock

Lock for state operations.

_messages_lock Lock

Lock for message operations.

_threads_lock Lock

Lock for thread operations.

Methods:

Name Description
__init__

Initialize all in-memory storage and locks.

aclean_thread

Clean/delete thread asynchronously.

aclear_state

Clear state asynchronously.

adelete_message

Delete a specific message asynchronously.

aget_message

Retrieve a specific message asynchronously.

aget_state

Retrieve state asynchronously.

aget_state_cache

Retrieve state cache asynchronously.

aget_thread

Retrieve thread info asynchronously.

alist_messages

List messages asynchronously with optional filtering.

alist_threads

List all threads asynchronously with optional filtering.

aput_messages

Store messages asynchronously.

aput_state

Store state asynchronously.

aput_state_cache

Store state cache asynchronously.

aput_thread

Store thread info asynchronously.

arelease

Release resources asynchronously.

asetup

Asynchronous setup method. No setup required for in-memory checkpointer.

clean_thread

Clean/delete thread synchronously.

clear_state

Clear state synchronously.

delete_message

Delete a specific message synchronously.

get_message

Retrieve a specific message synchronously.

get_state

Retrieve state synchronously.

get_state_cache

Retrieve state cache synchronously.

get_thread

Retrieve thread info synchronously.

list_messages

List messages synchronously with optional filtering.

list_threads

List all threads synchronously with optional filtering.

put_messages

Store messages synchronously.

put_state

Store state synchronously.

put_state_cache

Store state cache synchronously.

put_thread

Store thread info synchronously.

release

Release resources synchronously.

setup

Synchronous setup method. No setup required for in-memory checkpointer.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
class InMemoryCheckpointer[StateT: AgentState](BaseCheckpointer[StateT]):
    """
    In-memory implementation of BaseCheckpointer.

    Stores all agent state, messages, and thread info in memory using Python dictionaries.
    Data is lost when the process ends. Designed for testing and ephemeral use cases.
    Async-first design using asyncio locks for concurrent access.

    Attributes:
        _states (dict): Stores agent states by thread key.
        _state_cache (dict): Stores cached agent states by thread key.
        _messages (dict): Stores messages by thread key.
        _message_metadata (dict): Stores message metadata by thread key.
        _threads (dict): Stores thread info by thread key.
        _state_lock (asyncio.Lock): Lock for state operations.
        _messages_lock (asyncio.Lock): Lock for message operations.
        _threads_lock (asyncio.Lock): Lock for thread operations.
    """

    def __init__(self):
        """
        Initialize all in-memory storage and locks.
        """
        # State storage
        self._states: dict[str, StateT] = {}
        self._state_cache: dict[str, StateT] = {}

        # Message storage - organized by config key
        self._messages: dict[str, list[Message]] = defaultdict(list)
        self._message_metadata: dict[str, dict[str, Any]] = {}

        # Thread storage
        self._threads: dict[str, dict[str, Any]] = {}

        # Async locks for concurrent access
        self._state_lock = asyncio.Lock()
        self._messages_lock = asyncio.Lock()
        self._threads_lock = asyncio.Lock()

    def setup(self) -> Any:
        """
        Synchronous setup method. No setup required for in-memory checkpointer.
        """
        logger.debug("InMemoryCheckpointer setup not required")

    async def asetup(self) -> Any:
        """
        Asynchronous setup method. No setup required for in-memory checkpointer.
        """
        logger.debug("InMemoryCheckpointer async setup not required")

    def _get_config_key(self, config: dict[str, Any]) -> str:
        """
        Generate a string key from config dict for storage indexing.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            str: Key for indexing storage.
        """
        """Generate a string key from config dict for storage indexing."""
        # Sort keys for consistent hashing
        thread_id = config.get("thread_id", "")
        return str(thread_id)

    # -------------------------
    # State methods Async
    # -------------------------
    async def aput_state(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store state asynchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to store.

        Returns:
            StateT: The stored state object.
        """
        """Store state asynchronously."""
        key = self._get_config_key(config)
        async with self._state_lock:
            self._states[key] = state
            logger.debug(f"Stored state for key: {key}")
            return state

    async def aget_state(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve state asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Retrieved state or None.
        """
        """Retrieve state asynchronously."""
        key = self._get_config_key(config)
        async with self._state_lock:
            state = self._states.get(key)
            logger.debug(f"Retrieved state for key: {key}, found: {state is not None}")
            return state

    async def aclear_state(self, config: dict[str, Any]) -> bool:
        """
        Clear state asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            bool: True if cleared.
        """
        """Clear state asynchronously."""
        key = self._get_config_key(config)
        async with self._state_lock:
            if key in self._states:
                del self._states[key]
                logger.debug(f"Cleared state for key: {key}")
            return True

    async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store state cache asynchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to cache.

        Returns:
            StateT: The cached state object.
        """
        """Store state cache asynchronously."""
        key = self._get_config_key(config)
        async with self._state_lock:
            self._state_cache[key] = state
            logger.debug(f"Stored state cache for key: {key}")
            return state

    async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve state cache asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Cached state or None.
        """
        """Retrieve state cache asynchronously."""
        key = self._get_config_key(config)
        async with self._state_lock:
            cache = self._state_cache.get(key)
            logger.debug(f"Retrieved state cache for key: {key}, found: {cache is not None}")
            return cache

    # -------------------------
    # State methods Sync
    # -------------------------
    def put_state(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store state synchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to store.

        Returns:
            StateT: The stored state object.
        """
        """Store state synchronously."""
        key = self._get_config_key(config)
        # For sync methods, we'll use a simple approach without locks
        # In a real async-first system, sync methods might not be used
        self._states[key] = state
        logger.debug(f"Stored state for key: {key}")
        return state

    def get_state(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve state synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Retrieved state or None.
        """
        """Retrieve state synchronously."""
        key = self._get_config_key(config)
        state = self._states.get(key)
        logger.debug(f"Retrieved state for key: {key}, found: {state is not None}")
        return state

    def clear_state(self, config: dict[str, Any]) -> bool:
        """
        Clear state synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            bool: True if cleared.
        """
        """Clear state synchronously."""
        key = self._get_config_key(config)
        if key in self._states:
            del self._states[key]
            logger.debug(f"Cleared state for key: {key}")
        return True

    def put_state_cache(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store state cache synchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to cache.

        Returns:
            StateT: The cached state object.
        """
        """Store state cache synchronously."""
        key = self._get_config_key(config)
        self._state_cache[key] = state
        logger.debug(f"Stored state cache for key: {key}")
        return state

    def get_state_cache(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve state cache synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Cached state or None.
        """
        """Retrieve state cache synchronously."""
        key = self._get_config_key(config)
        cache = self._state_cache.get(key)
        logger.debug(f"Retrieved state cache for key: {key}, found: {cache is not None}")
        return cache

    # -------------------------
    # Message methods async
    # -------------------------
    async def aput_messages(
        self,
        config: dict[str, Any],
        messages: list[Message],
        metadata: dict[str, Any] | None = None,
    ) -> bool:
        """
        Store messages asynchronously.

        Args:
            config (dict): Configuration dictionary.
            messages (list[Message]): List of messages to store.
            metadata (dict, optional): Additional metadata.

        Returns:
            bool: True if stored.
        """
        key = self._get_config_key(config)
        async with self._messages_lock:
            self._messages[key].extend(messages)
            if metadata:
                self._message_metadata[key] = metadata
            logger.debug(f"Stored {len(messages)} messages for key: {key}")
            return True

    async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
        """
        Retrieve a specific message asynchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Message: Retrieved message object.

        Raises:
            IndexError: If message not found.
        """
        """Retrieve a specific message asynchronously."""
        key = self._get_config_key(config)
        async with self._messages_lock:
            messages = self._messages.get(key, [])
            for msg in messages:
                if msg.message_id == message_id:
                    return msg
            raise IndexError(f"Message with ID {message_id} not found for config key: {key}")

    async def alist_messages(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[Message]:
        """
        List messages asynchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[Message]: List of message objects.
        """
        key = self._get_config_key(config)
        async with self._messages_lock:
            messages = self._messages.get(key, [])

            # Apply search filter if provided
            if search:
                # Simple string search in message content
                messages = [
                    msg
                    for msg in messages
                    if hasattr(msg, "content") and search.lower() in str(msg.content).lower()
                ]

            # Apply offset and limit
            start = offset or 0
            end = (start + limit) if limit else None
            return messages[start:end]

    async def adelete_message(self, config: dict[str, Any], message_id: str | int) -> bool:
        """
        Delete a specific message asynchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            bool: True if deleted.

        Raises:
            IndexError: If message not found.
        """
        """Delete a specific message asynchronously."""
        key = self._get_config_key(config)
        async with self._messages_lock:
            messages = self._messages.get(key, [])
            for msg in messages:
                if msg.message_id == message_id:
                    messages.remove(msg)
                    logger.debug(f"Deleted message with ID {message_id} for key: {key}")
                    return True
            raise IndexError(f"Message with ID {message_id} not found for config key: {key}")

    # -------------------------
    # Message methods sync
    # -------------------------
    def put_messages(
        self,
        config: dict[str, Any],
        messages: list[Message],
        metadata: dict[str, Any] | None = None,
    ) -> bool:
        """
        Store messages synchronously.

        Args:
            config (dict): Configuration dictionary.
            messages (list[Message]): List of messages to store.
            metadata (dict, optional): Additional metadata.

        Returns:
            bool: True if stored.
        """
        key = self._get_config_key(config)
        self._messages[key].extend(messages)
        if metadata:
            self._message_metadata[key] = metadata

        logger.debug(f"Stored {len(messages)} messages for key: {key}")
        return True

    def get_message(self, config: dict[str, Any], message_id: str | int) -> Message:
        """
        Retrieve a specific message synchronously.

        Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.

    Raises:
        IndexError: If message not found.
    """
        key = self._get_config_key(config)
        messages = self._messages.get(key, [])
        for msg in messages:
            if msg.message_id == message_id:
                return msg
        raise IndexError(f"Message with ID {message_id} not found for config key: {key}")

    def list_messages(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[Message]:
        """
        List messages synchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[Message]: List of message objects.
        """
        key = self._get_config_key(config)
        messages = self._messages.get(key, [])

        # Apply search filter if provided
        if search:
            messages = [
                msg
                for msg in messages
                if hasattr(msg, "content") and search.lower() in str(msg.content).lower()
            ]

        # Apply offset and limit
        start = offset or 0
        end = (start + limit) if limit else None
        return messages[start:end]

    def delete_message(self, config: dict[str, Any], message_id: str | int) -> bool:
        """
        Delete a specific message synchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            bool: True if deleted.

        Raises:
            IndexError: If message not found.
        """
        """Delete a specific message synchronously."""
        key = self._get_config_key(config)
        messages = self._messages.get(key, [])
        for msg in messages:
            if msg.message_id == message_id:
                messages.remove(msg)
                logger.debug(f"Deleted message with ID {message_id} for key: {key}")
                return True
        raise IndexError(f"Message with ID {message_id} not found for config key: {key}")

    # -------------------------
    # Thread methods async
    # -------------------------
    async def aput_thread(
        self,
        config: dict[str, Any],
        thread_info: ThreadInfo,
    ) -> bool:
        """
        Store thread info asynchronously.

        Args:
            config (dict): Configuration dictionary.
            thread_info (ThreadInfo): Thread information object.

        Returns:
            bool: True if stored.
        """
        key = self._get_config_key(config)
        async with self._threads_lock:
            self._threads[key] = thread_info.model_dump()
            logger.debug(f"Stored thread info for key: {key}")
            return True

    async def aget_thread(
        self,
        config: dict[str, Any],
    ) -> ThreadInfo | None:
        """
        Retrieve thread info asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            ThreadInfo | None: Thread information object or None.
        """
        key = self._get_config_key(config)
        async with self._threads_lock:
            thread = self._threads.get(key)
            logger.debug(f"Retrieved thread for key: {key}, found: {thread is not None}")
            return ThreadInfo.model_validate(thread) if thread else None

    async def alist_threads(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[ThreadInfo]:
        """
        List all threads asynchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[ThreadInfo]: List of thread information objects.
        """
        async with self._threads_lock:
            threads = list(self._threads.values())

            # Apply search filter if provided
            if search:
                threads = [
                    thread
                    for thread in threads
                    if any(search.lower() in str(value).lower() for value in thread.values())
                ]

            # Apply offset and limit
            start = offset or 0
            end = (start + limit) if limit else None
            return [ThreadInfo.model_validate(thread) for thread in threads[start:end]]

    async def aclean_thread(self, config: dict[str, Any]) -> bool:
        """
        Clean/delete thread asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            bool: True if cleaned.
        """
        """Clean/delete thread asynchronously."""
        key = self._get_config_key(config)
        async with self._threads_lock:
            if key in self._threads:
                del self._threads[key]
                logger.debug(f"Cleaned thread for key: {key}")
                return True
        return False

    # -------------------------
    # Thread methods sync
    # -------------------------
    def put_thread(self, config: dict[str, Any], thread_info: ThreadInfo) -> bool:
        """
        Store thread info synchronously.

        Args:
            config (dict): Configuration dictionary.
            thread_info (ThreadInfo): Thread information object.

        Returns:
            bool: True if stored.
        """
        """Store thread info synchronously."""
        key = self._get_config_key(config)
        self._threads[key] = thread_info.model_dump()
        logger.debug(f"Stored thread info for key: {key}")
        return True

    def get_thread(self, config: dict[str, Any]) -> ThreadInfo | None:
        """
        Retrieve thread info synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            ThreadInfo | None: Thread information object or None.
        """
        """Retrieve thread info synchronously."""
        key = self._get_config_key(config)
        thread = self._threads.get(key)
        logger.debug(f"Retrieved thread for key: {key}, found: {thread is not None}")
        return ThreadInfo.model_validate(thread) if thread else None

    def list_threads(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[ThreadInfo]:
        """
        List all threads synchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[ThreadInfo]: List of thread information objects.
        """
        threads = list(self._threads.values())

        # Apply search filter if provided
        if search:
            threads = [
                thread
                for thread in threads
                if any(search.lower() in str(value).lower() for value in thread.values())
            ]

        # Apply offset and limit
        start = offset or 0
        end = (start + limit) if limit else None
        return [ThreadInfo.model_validate(thread) for thread in threads[start:end]]

    def clean_thread(self, config: dict[str, Any]) -> bool:
        """
        Clean/delete thread synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            bool: True if cleaned.
        """
        """Clean/delete thread synchronously."""
        key = self._get_config_key(config)
        if key in self._threads:
            del self._threads[key]
            logger.debug(f"Cleaned thread for key: {key}")
            return True
        return False

    # -------------------------
    # Clean Resources
    # -------------------------
    async def arelease(self) -> bool:
        """
        Release resources asynchronously.

        Returns:
            bool: True if released.
        """
        """Release resources asynchronously."""
        async with self._state_lock, self._messages_lock, self._threads_lock:
            self._states.clear()
            self._state_cache.clear()
            self._messages.clear()
            self._message_metadata.clear()
            self._threads.clear()
            logger.info("Released all in-memory resources")
            return True

    def release(self) -> bool:
        """
        Release resources synchronously.

        Returns:
            bool: True if released.
        """
        """Release resources synchronously."""
        self._states.clear()
        self._state_cache.clear()
        self._messages.clear()
        self._message_metadata.clear()
        self._threads.clear()
        logger.info("Released all in-memory resources")
        return True
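The pagination applied by `list_threads` and `list_messages` above is plain slice arithmetic. One consequence of the `if limit` truthiness check is that `limit=0` behaves like "no limit". A self-contained sketch of that slicing logic:

```python
def paginate(items, offset=None, limit=None):
    # Mirrors the slicing in list_messages/list_threads:
    # start defaults to 0; end stays open-ended when limit is falsy.
    start = offset or 0
    end = (start + limit) if limit else None
    return items[start:end]


items = list(range(10))
print(paginate(items, offset=2, limit=3))  # [2, 3, 4]
print(paginate(items, limit=0))            # all 10 items: limit=0 means "no limit"
```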
Functions
__init__
__init__()

Initialize all in-memory storage and locks.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def __init__(self):
    """
    Initialize all in-memory storage and locks.
    """
    # State storage
    self._states: dict[str, StateT] = {}
    self._state_cache: dict[str, StateT] = {}

    # Message storage - organized by config key
    self._messages: dict[str, list[Message]] = defaultdict(list)
    self._message_metadata: dict[str, dict[str, Any]] = {}

    # Thread storage
    self._threads: dict[str, dict[str, Any]] = {}

    # Async locks for concurrent access
    self._state_lock = asyncio.Lock()
    self._messages_lock = asyncio.Lock()
    self._threads_lock = asyncio.Lock()
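Because `_messages` is a `collections.defaultdict(list)`, write paths such as `aput_messages` can `extend` a thread's message list without first checking whether the key exists. A short illustration:

```python
from collections import defaultdict

messages = defaultdict(list)
# No key initialization needed: first access creates the empty list.
messages["thread-1"].extend(["hello", "world"])
messages["thread-1"].extend(["again"])
print(messages["thread-1"])  # ['hello', 'world', 'again']
# Plain .get() still works for read paths that must not create keys:
print(messages.get("thread-2", []))  # []
```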
aclean_thread async
aclean_thread(config)

Clean/delete thread asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
bool bool

True if cleaned.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aclean_thread(self, config: dict[str, Any]) -> bool:
    """
    Clean/delete thread asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        bool: True if cleaned.
    """
    """Clean/delete thread asynchronously."""
    key = self._get_config_key(config)
    async with self._threads_lock:
        if key in self._threads:
            del self._threads[key]
            logger.debug(f"Cleaned thread for key: {key}")
            return True
    return False
aclear_state async
aclear_state(config)

Clear state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
bool bool

True if cleared.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aclear_state(self, config: dict[str, Any]) -> bool:
    """
    Clear state asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        bool: True if cleared.
    """
    """Clear state asynchronously."""
    key = self._get_config_key(config)
    async with self._state_lock:
        if key in self._states:
            del self._states[key]
            logger.debug(f"Cleared state for key: {key}")
        return True
adelete_message async
adelete_message(config, message_id)

Delete a specific message asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
bool bool

True if deleted.

Raises:

Type Description
IndexError

If message not found.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def adelete_message(self, config: dict[str, Any], message_id: str | int) -> bool:
    """
    Delete a specific message asynchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        bool: True if deleted.

    Raises:
        IndexError: If message not found.
    """
    """Delete a specific message asynchronously."""
    key = self._get_config_key(config)
    async with self._messages_lock:
        messages = self._messages.get(key, [])
        for msg in messages:
            if msg.message_id == message_id:
                messages.remove(msg)
                logger.debug(f"Deleted message with ID {message_id} for key: {key}")
                return True
        raise IndexError(f"Message with ID {message_id} not found for config key: {key}")
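Note that deletion is not idempotent here: unlike `clean_thread`, which returns `False` for a missing key, `adelete_message` raises `IndexError` when no message matches. A minimal sketch of that remove-or-raise pattern (using hypothetical dict messages in place of `Message` objects):

```python
def delete_message(messages, message_id):
    # Mirrors adelete_message: remove the first matching message,
    # or raise IndexError if no message has that ID.
    for msg in messages:
        if msg["id"] == message_id:
            messages.remove(msg)
            return True
    raise IndexError(f"Message with ID {message_id} not found")


msgs = [{"id": 1}, {"id": 2}]
print(delete_message(msgs, 1))  # True
try:
    delete_message(msgs, 99)
except IndexError as e:
    print("missing:", e)
```

Returning immediately after `list.remove` keeps the loop safe; mutating the list and then continuing to iterate would be a bug.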
aget_message async
aget_message(config, message_id)

Retrieve a specific message asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
Message Message

Retrieved message object.

Raises:

Type Description
IndexError

If message not found.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a specific message asynchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.

    Raises:
        IndexError: If message not found.
    """
    """Retrieve a specific message asynchronously."""
    key = self._get_config_key(config)
    async with self._messages_lock:
        messages = self._messages.get(key, [])
        for msg in messages:
            if msg.message_id == message_id:
                return msg
        raise IndexError(f"Message with ID {message_id} not found for config key: {key}")
aget_state async
aget_state(config)

Retrieve state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Retrieved state or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aget_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve state asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.
    """
    """Retrieve state asynchronously."""
    key = self._get_config_key(config)
    async with self._state_lock:
        state = self._states.get(key)
        logger.debug(f"Retrieved state for key: {key}, found: {state is not None}")
        return state
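Every async state operation above acquires `self._state_lock`, so concurrent coroutines never interleave a read with a half-finished write. A self-contained sketch of that lock-guarded dict pattern (a plain string key stands in for the internal `_get_config_key` derivation, whose scheme is not shown here):

```python
import asyncio


class MiniStateStore:
    # Sketch of the lock-guarded state dict behind aput_state/aget_state.
    def __init__(self):
        self._states = {}
        self._state_lock = asyncio.Lock()

    async def aput_state(self, key, state):
        async with self._state_lock:
            self._states[key] = state
            return state

    async def aget_state(self, key):
        async with self._state_lock:
            return self._states.get(key)


async def main():
    store = MiniStateStore()
    await store.aput_state("thread-1", {"step": 1})
    print(await store.aget_state("thread-1"))  # {'step': 1}
    print(await store.aget_state("missing"))   # None


asyncio.run(main())
```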
aget_state_cache async
aget_state_cache(config)

Retrieve state cache asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Cached state or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve state cache asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Cached state or None.
    """
    """Retrieve state cache asynchronously."""
    key = self._get_config_key(config)
    async with self._state_lock:
        cache = self._state_cache.get(key)
        logger.debug(f"Retrieved state cache for key: {key}, found: {cache is not None}")
        return cache
aget_thread async
aget_thread(config)

Retrieve thread info asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
ThreadInfo | None

ThreadInfo | None: Thread information object or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aget_thread(
    self,
    config: dict[str, Any],
) -> ThreadInfo | None:
    """
    Retrieve thread info asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.
    """
    key = self._get_config_key(config)
    async with self._threads_lock:
        thread = self._threads.get(key)
        logger.debug(f"Retrieved thread for key: {key}, found: {thread is not None}")
        return ThreadInfo.model_validate(thread) if thread else None
alist_messages async
alist_messages(config, search=None, offset=None, limit=None)

List messages asynchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[Message]

list[Message]: List of message objects.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def alist_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages asynchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.
    """
    key = self._get_config_key(config)
    async with self._messages_lock:
        messages = self._messages.get(key, [])

        # Apply search filter if provided
        if search:
            # Simple string search in message content
            messages = [
                msg
                for msg in messages
                if hasattr(msg, "content") and search.lower() in str(msg.content).lower()
            ]

        # Apply offset and limit
        start = offset or 0
        end = (start + limit) if limit else None
        return messages[start:end]
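The search filter above is a case-insensitive substring match against the stringified `content`, with a `hasattr` guard for objects lacking that attribute. A self-contained sketch (using a hypothetical `Msg` class in place of PyAgenity's `Message`):

```python
class Msg:
    def __init__(self, content):
        self.content = content


def filter_messages(messages, search):
    # Case-insensitive substring match on the stringified content,
    # skipping objects without a `content` attribute.
    return [
        m for m in messages
        if hasattr(m, "content") and search.lower() in str(m.content).lower()
    ]


msgs = [Msg("Hello World"), Msg("goodbye"), Msg(42)]
print([m.content for m in filter_messages(msgs, "HELLO")])  # ['Hello World']
print([m.content for m in filter_messages(msgs, "4")])      # [42]
```

The `str(...)` call is what lets non-string content (numbers, structured payloads) still participate in search.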
alist_threads async
alist_threads(config, search=None, offset=None, limit=None)

List all threads asynchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[ThreadInfo]

list[ThreadInfo]: List of thread information objects.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def alist_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List all threads asynchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.
    """
    async with self._threads_lock:
        threads = list(self._threads.values())

        # Apply search filter if provided
        if search:
            threads = [
                thread
                for thread in threads
                if any(search.lower() in str(value).lower() for value in thread.values())
            ]

        # Apply offset and limit
        start = offset or 0
        end = (start + limit) if limit else None
        return [ThreadInfo.model_validate(thread) for thread in threads[start:end]]
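Threads are stored as plain dicts (via `thread_info.model_dump()`), so the search above matches a thread if the query appears in *any* stringified field value, not just the name. A self-contained sketch:

```python
def filter_threads(threads, search):
    # A thread matches when the search string appears, case-insensitively,
    # in any of its stringified values.
    return [
        t for t in threads
        if any(search.lower() in str(v).lower() for v in t.values())
    ]


threads = [
    {"thread_id": "t1", "name": "Billing agent"},
    {"thread_id": "t2", "name": "Support bot"},
]
print(filter_threads(threads, "billing"))  # [{'thread_id': 't1', 'name': 'Billing agent'}]
print(filter_threads(threads, "t2"))       # matches on thread_id, not name
```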
aput_messages async
aput_messages(config, messages, metadata=None)

Store messages asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
messages list[Message]

List of messages to store.

required
metadata dict

Additional metadata.

None

Returns:

Name Type Description
bool bool

True if stored.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aput_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> bool:
    """
    Store messages asynchronously.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        bool: True if stored.
    """
    key = self._get_config_key(config)
    async with self._messages_lock:
        self._messages[key].extend(messages)
        if metadata:
            self._message_metadata[key] = metadata
        logger.debug(f"Stored {len(messages)} messages for key: {key}")
        return True
aput_state async
aput_state(config, state)

Store state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aput_state(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store state asynchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.
    """
    """Store state asynchronously."""
    key = self._get_config_key(config)
    async with self._state_lock:
        self._states[key] = state
        logger.debug(f"Stored state for key: {key}")
        return state
aput_state_cache async
aput_state_cache(config, state)

Store state cache asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Name Type Description
StateT StateT

The cached state object.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store state cache asynchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        StateT: The cached state object.
    """
    """Store state cache asynchronously."""
    key = self._get_config_key(config)
    async with self._state_lock:
        self._state_cache[key] = state
        logger.debug(f"Stored state cache for key: {key}")
        return state
aput_thread async
aput_thread(config, thread_info)

Store thread info asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Name Type Description
bool bool

True if stored.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aput_thread(
    self,
    config: dict[str, Any],
    thread_info: ThreadInfo,
) -> bool:
    """
    Store thread info asynchronously.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        bool: True if stored.
    """
    key = self._get_config_key(config)
    async with self._threads_lock:
        self._threads[key] = thread_info.model_dump()
        logger.debug(f"Stored thread info for key: {key}")
        return True
arelease async
arelease()

Release resources asynchronously.

Returns:

Name Type Description
bool bool

True if released.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def arelease(self) -> bool:
    """
    Release resources asynchronously.

    Returns:
        bool: True if released.
    """
    """Release resources asynchronously."""
    async with self._state_lock, self._messages_lock, self._threads_lock:
        self._states.clear()
        self._state_cache.clear()
        self._messages.clear()
        self._message_metadata.clear()
        self._threads.clear()
        logger.info("Released all in-memory resources")
        return True
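`arelease` acquires all three locks in a single `async with` statement (acquired left to right, released in reverse), so no concurrent operation can observe a half-cleared checkpointer. A self-contained sketch of that all-locks-then-clear pattern:

```python
import asyncio


async def release_all(state_lock, messages_lock, threads_lock, stores):
    # Mirrors arelease: hold every lock while clearing every store,
    # so readers never see some stores cleared and others not.
    async with state_lock, messages_lock, threads_lock:
        for store in stores:
            store.clear()
        return True


async def main():
    stores = [{"a": 1}, {"b": 2}, {"c": 3}]
    ok = await release_all(asyncio.Lock(), asyncio.Lock(), asyncio.Lock(), stores)
    print(ok, stores)  # True [{}, {}, {}]


asyncio.run(main())
```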
asetup async
asetup()

Asynchronous setup method. No setup required for in-memory checkpointer.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def asetup(self) -> Any:
    """
    Asynchronous setup method. No setup required for in-memory checkpointer.
    """
    logger.debug("InMemoryCheckpointer async setup not required")
clean_thread
clean_thread(config)

Clean/delete thread synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
bool bool

True if cleaned.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def clean_thread(self, config: dict[str, Any]) -> bool:
    """
    Clean/delete thread synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        bool: True if cleaned.
    """
    """Clean/delete thread synchronously."""
    key = self._get_config_key(config)
    if key in self._threads:
        del self._threads[key]
        logger.debug(f"Cleaned thread for key: {key}")
        return True
    return False
clear_state
clear_state(config)

Clear state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
bool bool

True if cleared.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def clear_state(self, config: dict[str, Any]) -> bool:
    """
    Clear state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        bool: True if cleared.
    """
    """Clear state synchronously."""
    key = self._get_config_key(config)
    if key in self._states:
        del self._states[key]
        logger.debug(f"Cleared state for key: {key}")
    return True
delete_message
delete_message(config, message_id)

Delete a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
bool bool

True if deleted.

Raises:

Type Description
IndexError

If message not found.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def delete_message(self, config: dict[str, Any], message_id: str | int) -> bool:
    """
    Delete a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        bool: True if deleted.

    Raises:
        IndexError: If message not found.
    """
    """Delete a specific message synchronously."""
    key = self._get_config_key(config)
    messages = self._messages.get(key, [])
    for msg in messages:
        if msg.message_id == message_id:
            messages.remove(msg)
            logger.debug(f"Deleted message with ID {message_id} for key: {key}")
            return True
    raise IndexError(f"Message with ID {message_id} not found for config key: {key}")
get_message
get_message(config, message_id)

Retrieve a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
Message Message

Retrieved message object.

Raises:

Type Description
IndexError

If message not found.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def get_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Message: Latest message object.

    Raises:
        IndexError: If no messages found.
    """
    """Retrieve the latest message synchronously."""
    key = self._get_config_key(config)
    messages = self._messages.get(key, [])
    for msg in messages:
        if msg.message_id == message_id:
            return msg
    raise IndexError(f"Message with ID {message_id} not found for config key: {key}")
get_state
get_state(config)

Retrieve state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Retrieved state or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def get_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.
    """
    """Retrieve state synchronously."""
    key = self._get_config_key(config)
    state = self._states.get(key)
    logger.debug(f"Retrieved state for key: {key}, found: {state is not None}")
    return state
get_state_cache
get_state_cache(config)

Retrieve state cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Cached state or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def get_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve state cache synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Cached state or None.
    """
    """Retrieve state cache synchronously."""
    key = self._get_config_key(config)
    cache = self._state_cache.get(key)
    logger.debug(f"Retrieved state cache for key: {key}, found: {cache is not None}")
    return cache
get_thread
get_thread(config)

Retrieve thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
ThreadInfo | None

ThreadInfo | None: Thread information object or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def get_thread(self, config: dict[str, Any]) -> ThreadInfo | None:
    """
    Retrieve thread info synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.
    """
    """Retrieve thread info synchronously."""
    key = self._get_config_key(config)
    thread = self._threads.get(key)
    logger.debug(f"Retrieved thread for key: {key}, found: {thread is not None}")
    return ThreadInfo.model_validate(thread) if thread else None
list_messages
list_messages(config, search=None, offset=None, limit=None)

List messages synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[Message]

list[Message]: List of message objects.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def list_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.
    """
    key = self._get_config_key(config)
    messages = self._messages.get(key, [])

    # Apply search filter if provided
    if search:
        messages = [
            msg
            for msg in messages
            if hasattr(msg, "content") and search.lower() in str(msg.content).lower()
        ]

    # Apply offset and limit
    start = offset or 0
    end = (start + limit) if limit else None
    return messages[start:end]
list_threads
list_threads(config, search=None, offset=None, limit=None)

List all threads synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[ThreadInfo]

list[ThreadInfo]: List of thread information objects.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def list_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List all threads synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.
    """
    threads = list(self._threads.values())

    # Apply search filter if provided
    if search:
        threads = [
            thread
            for thread in threads
            if any(search.lower() in str(value).lower() for value in thread.values())
        ]

    # Apply offset and limit
    start = offset or 0
    end = (start + limit) if limit else None
    return [ThreadInfo.model_validate(thread) for thread in threads[start:end]]
put_messages
put_messages(config, messages, metadata=None)

Store messages synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
messages list[Message]

List of messages to store.

required
metadata dict

Additional metadata.

None

Returns:

Name Type Description
bool bool

True if stored.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def put_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> bool:
    """
    Store messages synchronously.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        bool: True if stored.
    """
    key = self._get_config_key(config)
    self._messages[key].extend(messages)
    if metadata:
        self._message_metadata[key] = metadata

    logger.debug(f"Stored {len(messages)} messages for key: {key}")
    return True
put_state
put_state(config, state)

Store state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def put_state(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store state synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.
    """
    """Store state synchronously."""
    key = self._get_config_key(config)
    # For sync methods, we'll use a simple approach without locks
    # In a real async-first system, sync methods might not be used
    self._states[key] = state
    logger.debug(f"Stored state for key: {key}")
    return state
put_state_cache
put_state_cache(config, state)

Store state cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Name Type Description
StateT StateT

The cached state object.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def put_state_cache(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store state cache synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        StateT: The cached state object.
    """
    """Store state cache synchronously."""
    key = self._get_config_key(config)
    self._state_cache[key] = state
    logger.debug(f"Stored state cache for key: {key}")
    return state
put_thread
put_thread(config, thread_info)

Store thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Name Type Description
bool bool

True if stored.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def put_thread(self, config: dict[str, Any], thread_info: ThreadInfo) -> bool:
    """
    Store thread info synchronously.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        bool: True if stored.
    """
    """Store thread info synchronously."""
    key = self._get_config_key(config)
    self._threads[key] = thread_info.model_dump()
    logger.debug(f"Stored thread info for key: {key}")
    return True
release
release()

Release resources synchronously.

Returns:

Name Type Description
bool bool

True if released.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def release(self) -> bool:
    """
    Release resources synchronously.

    Returns:
        bool: True if released.
    """
    """Release resources synchronously."""
    self._states.clear()
    self._state_cache.clear()
    self._messages.clear()
    self._message_metadata.clear()
    self._threads.clear()
    logger.info("Released all in-memory resources")
    return True
setup
setup()

Synchronous setup method. No setup required for in-memory checkpointer.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def setup(self) -> Any:
    """
    Synchronous setup method. No setup required for in-memory checkpointer.
    """
    logger.debug("InMemoryCheckpointer setup not required")
PgCheckpointer

Bases: BaseCheckpointer[StateT]

Implements a checkpointer using PostgreSQL and Redis for persistent and cached state management.

This class provides asynchronous and synchronous methods for storing, retrieving, and managing agent states, messages, and threads. PostgreSQL is used for durable storage, while Redis provides fast caching with TTL.

Features
  • Async-first design with sync fallbacks
  • Configurable ID types (string, int, bigint)
  • Connection pooling for both PostgreSQL and Redis
  • Proper error handling and resource management
  • Schema migration support
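The read path described above (Redis first, PostgreSQL on a cache miss) is the classic cache-aside pattern with TTL. A dependency-free sketch of that pattern, using plain dicts in place of Redis and PostgreSQL:

```python
import time
from typing import Any


class CacheAsideStore:
    """Durable dict standing in for PostgreSQL; TTL dict standing in for Redis."""

    def __init__(self, cache_ttl: float = 300.0) -> None:
        self._durable: dict[str, Any] = {}
        self._cache: dict[str, tuple[float, Any]] = {}  # key -> (expiry, value)
        self._ttl = cache_ttl

    def put_state(self, key: str, state: Any) -> None:
        self._durable[key] = state  # durable write first
        self._cache[key] = (time.monotonic() + self._ttl, state)

    def get_state(self, key: str) -> Any:
        entry = self._cache.get(key)
        if entry and entry[0] > time.monotonic():  # cache hit, not expired
            return entry[1]
        value = self._durable[key]  # miss: fall back to the durable store
        self._cache[key] = (time.monotonic() + self._ttl, value)  # repopulate cache
        return value


store = CacheAsideStore(cache_ttl=0.0)  # TTL of 0 forces the durable-store fallback
store.put_state("t1", {"step": 1})
value = store.get_state("t1")
```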

Parameters:

Name Type Description Default
postgres_dsn
str

PostgreSQL connection string.

None
pg_pool
Any

Existing asyncpg Pool instance.

None
pool_config
dict

Configuration for new pg pool creation.

None
redis_url
str

Redis connection URL.

None
redis
Any

Existing Redis instance.

None
redis_pool
Any

Existing Redis ConnectionPool.

None
redis_pool_config
dict

Configuration for new redis pool creation.

None
**kwargs

Additional configuration options:
  - user_id_type: Type for user_id fields ('string', 'int', 'bigint')
  - cache_ttl: Redis cache TTL in seconds
  - release_resources: Whether to release resources on cleanup

{}

Raises:

Type Description
ImportError

If required dependencies are missing.

ValueError

If required connection details are missing.
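Because schema and table names are interpolated into SQL rather than passed as query parameters, PgCheckpointer validates them against a plain-identifier pattern before use. A standalone sketch of that check, mirroring the regex the class uses:

```python
import re

# Only plain SQL identifiers: a letter or underscore, then letters, digits, underscores.
_IDENT_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")


def qualified_table(schema: str, table: str) -> str:
    """Quote a schema-qualified table name, rejecting non-identifier input."""
    for name in (schema, table):
        if not _IDENT_RE.match(name):
            raise ValueError(f"Invalid identifier: {name!r}")
    return f'"{schema}"."{table}"'


qualified = qualified_table("public", "threads")
```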

Methods:

Name Description
__init__

Initializes PgCheckpointer with PostgreSQL and Redis connections.

aclean_thread

Clean/delete a thread and all associated data.

aclear_state

Clear state from PostgreSQL and Redis cache.

adelete_message

Delete a message by ID.

aget_message

Retrieve a single message by ID.

aget_state

Retrieve state from PostgreSQL.

aget_state_cache

Get state from Redis cache, fallback to PostgreSQL if miss.

aget_thread

Get thread information.

alist_messages

List messages for a thread with optional search and pagination.

alist_threads

List threads for a user with optional search and pagination.

aput_messages

Store messages in PostgreSQL.

aput_state

Store state in PostgreSQL and optionally cache in Redis.

aput_state_cache

Cache state in Redis with TTL.

aput_thread

Create or update thread information.

arelease

Clean up connections and resources.

asetup

Asynchronous setup method. Initializes database schema.

clean_thread

Clean/delete thread synchronously.

clear_state

Clear agent state synchronously.

delete_message

Delete a specific message synchronously.

get_message

Retrieve a specific message synchronously.

get_state

Retrieve agent state synchronously.

get_state_cache

Retrieve agent state from cache synchronously.

get_thread

Retrieve thread info synchronously.

list_messages

List messages synchronously with optional filtering.

list_threads

List threads synchronously with optional filtering.

put_messages

Store messages synchronously.

put_state

Store agent state synchronously.

put_state_cache

Store agent state in cache synchronously.

put_thread

Store thread info synchronously.

release

Release resources synchronously.

setup

Synchronous setup method for checkpointer.

Attributes:

Name Type Description
cache_ttl
id_type
redis
release_resources
schema
user_id_type
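The internal connection-retry helper retries transient failures with exponential backoff (2**attempt seconds) before giving up. A simplified, self-contained sketch of that behavior; the `base_delay` knob is added here for illustration and is not part of the real helper:

```python
import asyncio


async def retry_on_connection_error(operation, *args, max_retries: int = 3,
                                    base_delay: float = 1.0, **kwargs):
    """Retry an async operation on ConnectionError with exponential backoff."""
    last_exc = None
    for attempt in range(max_retries):
        try:
            return await operation(*args, **kwargs)
        except ConnectionError as exc:  # only connection errors are retried
            last_exc = exc
            if attempt < max_retries - 1:
                await asyncio.sleep(base_delay * (2 ** attempt))  # 1x, 2x, 4x...
    raise last_exc


calls = {"n": 0}


async def flaky() -> str:
    # Fails twice with a transient error, then succeeds on the third call.
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient")
    return "ok"


result = asyncio.run(retry_on_connection_error(flaky, base_delay=0.01))
```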
Source code in pyagenity/checkpointer/pg_checkpointer.py
class PgCheckpointer(BaseCheckpointer[StateT]):
    """
    Implements a checkpointer using PostgreSQL and Redis for persistent and cached state management.

    This class provides asynchronous and synchronous methods for storing, retrieving, and managing
    agent states, messages, and threads. PostgreSQL is used for durable storage, while Redis
    provides fast caching with TTL.

    Features:
        - Async-first design with sync fallbacks
        - Configurable ID types (string, int, bigint)
        - Connection pooling for both PostgreSQL and Redis
        - Proper error handling and resource management
        - Schema migration support

    Args:
        postgres_dsn (str, optional): PostgreSQL connection string.
        pg_pool (Any, optional): Existing asyncpg Pool instance.
        pool_config (dict, optional): Configuration for new pg pool creation.
        redis_url (str, optional): Redis connection URL.
        redis (Any, optional): Existing Redis instance.
        redis_pool (Any, optional): Existing Redis ConnectionPool.
        redis_pool_config (dict, optional): Configuration for new redis pool creation.
        **kwargs: Additional configuration options:
            - user_id_type: Type for user_id fields ('string', 'int', 'bigint')
            - cache_ttl: Redis cache TTL in seconds
            - release_resources: Whether to release resources on cleanup

    Raises:
        ImportError: If required dependencies are missing.
        ValueError: If required connection details are missing.
    """

    def __init__(
        self,
        # postgres connection details
        postgres_dsn: str | None = None,
        pg_pool: Any | None = None,
        pool_config: dict | None = None,
        # redis connection details
        redis_url: str | None = None,
        redis: Any | None = None,
        redis_pool: Any | None = None,
        redis_pool_config: dict | None = None,
        # database schema
        schema: str = "public",
        # other configurations - combine to reduce args
        **kwargs,
    ):
        """
        Initializes PgCheckpointer with PostgreSQL and Redis connections.

        Args:
            postgres_dsn (str, optional): PostgreSQL connection string.
            pg_pool (Any, optional): Existing asyncpg Pool instance.
            pool_config (dict, optional): Configuration for new pg pool creation.
            redis_url (str, optional): Redis connection URL.
            redis (Any, optional): Existing Redis instance.
            redis_pool (Any, optional): Existing Redis ConnectionPool.
            redis_pool_config (dict, optional): Configuration for new redis pool creation.
            schema (str, optional): PostgreSQL schema name. Defaults to "public".
            **kwargs: Additional configuration options.

        Raises:
            ImportError: If required dependencies are missing.
            ValueError: If required connection details are missing.
        """
        # Check for required dependencies
        if not HAS_ASYNCPG:
            raise ImportError(
                "PgCheckpointer requires 'asyncpg' package. "
                "Install with: pip install pyagenity[pg_checkpoint]"
            )

        if not HAS_REDIS:
            raise ImportError(
                "PgCheckpointer requires 'redis' package. "
                "Install with: pip install pyagenity[pg_checkpoint]"
            )

        self.user_id_type = kwargs.get("user_id_type", "string")
        # allow explicit override via kwargs, fallback to InjectQ, then default
        self.id_type = kwargs.get(
            "id_type", InjectQ.get_instance().try_get("generated_id_type", "string")
        )
        self.cache_ttl = kwargs.get("cache_ttl", DEFAULT_CACHE_TTL)
        self.release_resources = kwargs.get("release_resources", False)

        # Validate schema name to prevent SQL injection
        if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", schema):
            raise ValueError(
                f"Invalid schema name: {schema}. Schema must match pattern ^[a-zA-Z_][a-zA-Z0-9_]*$"
            )
        self.schema = schema

        self._schema_initialized = False
        self._loop: asyncio.AbstractEventLoop | None = None

        # Store pool configuration for lazy initialization
        self._pg_pool_config = {
            "pg_pool": pg_pool,
            "postgres_dsn": postgres_dsn,
            "pool_config": pool_config or {},
        }

        # Initialize pool immediately if provided, otherwise defer
        if pg_pool is not None:
            self._pg_pool = pg_pool
        else:
            self._pg_pool = None

        # Now check and initialize connections
        if not pg_pool and not postgres_dsn:
            raise ValueError("Either postgres_dsn or pg_pool must be provided.")

        if not redis and not redis_url and not redis_pool:
            raise ValueError("Either redis_url, redis_pool or redis instance must be provided.")

        # Initialize Redis connection (synchronous)
        self.redis = self._create_redis_pool(redis, redis_pool, redis_url, redis_pool_config or {})

    def _create_redis_pool(
        self,
        redis: Any | None,
        redis_pool: Any | None,
        redis_url: str | None,
        redis_pool_config: dict,
    ) -> Any:
        """
        Create or use an existing Redis connection.

        Args:
            redis (Any, optional): Existing Redis instance.
            redis_pool (Any, optional): Existing Redis ConnectionPool.
            redis_url (str, optional): Redis connection URL.
            redis_pool_config (dict): Configuration for new redis pool creation.

        Returns:
            Redis: Redis connection instance.

        Raises:
            ValueError: If redis_url is not provided when creating a new connection.
        """
        if redis:
            return redis

        if redis_pool:
            return Redis(connection_pool=redis_pool)  # type: ignore

        # Creating a new connection requires redis_url; resources created here
        # are released during cleanup
        if not redis_url:
            raise ValueError("redis_url must be provided when creating new Redis connection")

        self.release_resources = True
        return Redis(
            connection_pool=ConnectionPool.from_url(  # type: ignore
                redis_url,
                **redis_pool_config,
            )
        )

    def _get_table_name(self, table: str) -> str:
        """
        Get the schema-qualified table name.

        Args:
            table (str): The base table name (e.g., 'threads', 'states', 'messages')

        Returns:
            str: The schema-qualified table name (e.g., '"public"."threads"')
        """
        # Validate table name to prevent SQL injection
        if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", table):
            raise ValueError(
                f"Invalid table name: {table}. Table must match pattern ^[a-zA-Z_][a-zA-Z0-9_]*$"
            )
        return f'"{self.schema}"."{table}"'

    def _create_pg_pool(self, pg_pool: Any, postgres_dsn: str | None, pool_config: dict) -> Any:
        """
        Create or use an existing PostgreSQL connection pool.

        Args:
            pg_pool (Any, optional): Existing asyncpg Pool instance.
            postgres_dsn (str, optional): PostgreSQL connection string.
            pool_config (dict): Configuration for new pg pool creation.

        Returns:
            Pool: PostgreSQL connection pool.
        """
        if pg_pool:
            return pg_pool
        # Creating a new pool requires postgres_dsn; resources created here
        # are released during cleanup
        self.release_resources = True
        return asyncpg.create_pool(dsn=postgres_dsn, **pool_config)  # type: ignore

    async def _get_pg_pool(self) -> Any:
        """
        Get PostgreSQL pool, creating it if necessary.

        Returns:
            Pool: PostgreSQL connection pool.
        """
        """Get PostgreSQL pool, creating it if necessary."""
        if self._pg_pool is None:
            config = self._pg_pool_config
            self._pg_pool = await self._create_pg_pool(
                config["pg_pool"], config["postgres_dsn"], config["pool_config"]
            )
        return self._pg_pool

    def _get_sql_type(self, type_name: str) -> str:
        """
        Get SQL type for given configuration type.

        Args:
            type_name (str): Type name ('string', 'int', 'bigint').

        Returns:
            str: Corresponding SQL type.
        """
        """Get SQL type for given configuration type."""
        return ID_TYPE_MAP.get(type_name, "VARCHAR(255)")

    def _get_json_serializer(self):
        """Get optimal JSON serializer based on FAST_JSON env var."""
        if os.environ.get("FAST_JSON", "0") == "1":
            try:
                import orjson

                return orjson.dumps
            except ImportError:
                try:
                    import msgspec  # type: ignore

                    return msgspec.json.encode
                except ImportError:
                    pass
        return json.dumps

    def _get_current_schema_version(self) -> int:
        """Return current expected schema version."""
        return 1  # increment when schema changes

    def _build_create_tables_sql(self) -> list[str]:
        """
        Build SQL statements for table creation with dynamic ID types.

        Returns:
            list[str]: List of SQL statements for table creation.
        """
        """Build SQL statements for table creation with dynamic ID types."""
        thread_id_type = self._get_sql_type(self.id_type)
        user_id_type = self._get_sql_type(self.user_id_type)
        message_id_type = self._get_sql_type(self.id_type)

        # For AUTO INCREMENT types, we need to handle primary key differently
        thread_pk = (
            "thread_id SERIAL PRIMARY KEY"
            if self.id_type == "int"
            else f"thread_id {thread_id_type} PRIMARY KEY"
        )
        message_pk = (
            "message_id SERIAL PRIMARY KEY"
            if self.id_type == "int"
            else f"message_id {message_id_type} PRIMARY KEY"
        )

        return [
            # Schema version tracking table
            f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("schema_version")} (
                version INT PRIMARY KEY,
                applied_at TIMESTAMPTZ DEFAULT NOW()
            )
            """,
            # Create message role enum (safe for older Postgres versions)
            (
                "DO $$\n"
                "BEGIN\n"
                "    CREATE TYPE message_role AS ENUM ('user', 'assistant', 'system', 'tool');\n"
                "EXCEPTION\n"
                "    WHEN duplicate_object THEN NULL;\n"
                "END$$;"
            ),
            # Create threads table
            f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("threads")} (
                {thread_pk},
                thread_name VARCHAR(255),
                user_id {user_id_type} NOT NULL,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW(),
                meta JSONB DEFAULT '{{}}'::jsonb
            )
            """,
            # Create states table
            f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("states")} (
                state_id SERIAL PRIMARY KEY,
                thread_id {thread_id_type} NOT NULL
                    REFERENCES {self._get_table_name("threads")}(thread_id)
                    ON DELETE CASCADE,
                state_data JSONB NOT NULL,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW(),
                meta JSONB DEFAULT '{{}}'::jsonb
            )
            """,
            # Create messages table
            f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("messages")} (
                {message_pk},
                thread_id {thread_id_type} NOT NULL
                    REFERENCES {self._get_table_name("threads")}(thread_id)
                    ON DELETE CASCADE,
                role message_role NOT NULL,
                content TEXT NOT NULL,
                tool_calls JSONB,
                tool_call_id VARCHAR(255),
                reasoning TEXT,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW(),
                total_tokens INT DEFAULT 0,
                usages JSONB DEFAULT '{{}}'::jsonb,
                meta JSONB DEFAULT '{{}}'::jsonb
            )
            """,
            # Create indexes
            f"CREATE INDEX IF NOT EXISTS idx_threads_user_id ON "
            f"{self._get_table_name('threads')}(user_id)",
            f"CREATE INDEX IF NOT EXISTS idx_states_thread_id ON "
            f"{self._get_table_name('states')}(thread_id)",
            f"CREATE INDEX IF NOT EXISTS idx_messages_thread_id ON "
            f"{self._get_table_name('messages')}(thread_id)",
        ]

    async def _check_and_apply_schema_version(self, conn) -> None:
        """Check current version and update if needed."""
        try:
            # Check if schema version exists
            row = await conn.fetchrow(
                f"SELECT version FROM {self._get_table_name('schema_version')} "  # noqa: S608
                f"ORDER BY version DESC LIMIT 1"
            )
            current_version = row["version"] if row else 0
            target_version = self._get_current_schema_version()

            if current_version < target_version:
                logger.info(
                    "Upgrading schema from version %d to %d", current_version, target_version
                )
                # Insert new version
                await conn.execute(
                    f"INSERT INTO {self._get_table_name('schema_version')} (version) VALUES ($1)",  # noqa: S608
                    target_version,
                )
        except Exception as e:
            logger.debug("Schema version check failed (expected on first run): %s", e)
            # Insert initial version
            with suppress(Exception):
                await conn.execute(
                    f"INSERT INTO {self._get_table_name('schema_version')} (version) VALUES ($1)",  # noqa: S608
                    self._get_current_schema_version(),
                )

    async def _initialize_schema(self) -> None:
        """
        Initialize database schema if not already done.

        Returns:
            None
        """
        """Initialize database schema if not already done."""
        if self._schema_initialized:
            return

        logger.debug(
            "Initializing database schema with types: id_type=%s, user_id_type=%s",
            self.id_type,
            self.user_id_type,
        )

        async with (await self._get_pg_pool()).acquire() as conn:
            try:
                sql_statements = self._build_create_tables_sql()
                for sql in sql_statements:
                    logger.debug("Executing SQL: %s", sql.strip())
                    await conn.execute(sql)

                # Check and apply schema version tracking
                await self._check_and_apply_schema_version(conn)

                self._schema_initialized = True
                logger.debug("Database schema initialized successfully")
            except Exception as e:
                logger.error("Failed to initialize database schema: %s", e)
                raise

    ###########################
    #### SETUP METHODS ########
    ###########################

    async def asetup(self) -> Any:
        """
        Asynchronous setup method. Initializes database schema.

        Returns:
            Any: True if setup completed.
        """
        """Async setup method - initializes database schema."""
        logger.info(
            "Setting up PgCheckpointer (async)",
            extra={
                "id_type": self.id_type,
                "user_id_type": self.user_id_type,
                "schema": self.schema,
            },
        )
        await self._initialize_schema()
        logger.info("PgCheckpointer setup completed")
        return True

    ###########################
    #### HELPER METHODS #######
    ###########################

    def _validate_config(self, config: dict[str, Any]) -> tuple[str | int, str | int]:
        """
        Extract and validate thread_id and user_id from config.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            tuple: (thread_id, user_id)

        Raises:
            ValueError: If required fields are missing.
        """
        """Extract and validate thread_id and user_id from config."""
        thread_id = config.get("thread_id")
        user_id = config.get("user_id")
        if not user_id:
            raise ValueError("user_id must be provided in config")

        if not thread_id:
            raise ValueError("Both thread_id must be provided in config")

        return thread_id, user_id

    def _get_thread_key(
        self,
        thread_id: str | int,
        user_id: str | int,
    ) -> str:
        """
        Get Redis cache key for thread state.

        Args:
            thread_id (str|int): Thread identifier.
            user_id (str|int): User identifier.

        Returns:
            str: Redis cache key.
        """
        return f"state_cache:{thread_id}:{user_id}"

    def _serialize_state(self, state: StateT) -> str:
        """
        Serialize state to JSON string for storage.

        Args:
            state (StateT): State object.

        Returns:
            str: JSON string.
        """
        """Serialize state to JSON string for storage."""

        def enum_handler(obj):
            if isinstance(obj, Enum):
                return obj.value
            return str(obj)

        return json.dumps(state.model_dump(), default=enum_handler)

    def _serialize_state_fast(self, state: StateT) -> str:
        """
        Serialize state using fast JSON serializer if available.

        Args:
            state (StateT): State object.

        Returns:
            str: JSON string.
        """
        serializer = self._get_json_serializer()

        def enum_handler(obj):
            if isinstance(obj, Enum):
                return obj.value
            return str(obj)

        data = state.model_dump()

        # Use fast serializer if available, otherwise fall back to json.dumps with enum handling
        if serializer is json.dumps:
            return json.dumps(data, default=enum_handler)

        # Fast serializers (orjson, msgspec) may not accept a default handler,
        # so the dumped data is serialized as-is
        result = serializer(data)
        # Ensure we return a string (orjson returns bytes)
        return result.decode("utf-8") if isinstance(result, bytes) else str(result)

    def _deserialize_state(
        self,
        data: Any,
        state_class: type[StateT],
    ) -> StateT:
        """
        Deserialize JSON/JSONB back to state object.

        Args:
            data (Any): JSON string or dict/list.
            state_class (type): State class type.

        Returns:
            StateT: Deserialized state object.

        Raises:
            Exception: If deserialization fails.
        """
        try:
            if isinstance(data, bytes | bytearray):
                data = data.decode()
            if isinstance(data, str):
                return state_class.model_validate(json.loads(data))
            # Assume it's already a dict/list
            return state_class.model_validate(data)
        except Exception:
            # Last-resort: coerce to string and attempt parse, else raise
            if isinstance(data, str):
                return state_class.model_validate(json.loads(data))
            raise

    async def _retry_on_connection_error(
        self,
        operation,
        *args,
        max_retries=3,
        **kwargs,
    ):
        """
        Retry database operations on connection errors.

        Args:
            operation: Callable operation.
            *args: Arguments.
            max_retries (int): Maximum retries.
            **kwargs: Keyword arguments.

        Returns:
            Any: Result of operation or None.

        Raises:
            Exception: If all retries fail.
        """
        last_exception = None

        # Define exception types to catch (only if asyncpg is available)
        exceptions_to_catch: list[type[Exception]] = [ConnectionError]
        if HAS_ASYNCPG and asyncpg:
            exceptions_to_catch.extend([asyncpg.PostgresConnectionError, asyncpg.InterfaceError])

        exception_tuple = tuple(exceptions_to_catch)

        for attempt in range(max_retries):
            try:
                return await operation(*args, **kwargs)
            except exception_tuple as e:
                last_exception = e
                if attempt < max_retries - 1:
                    wait_time = 2**attempt  # exponential backoff
                    logger.warning(
                        "Database connection error on attempt %d/%d, retrying in %ds: %s",
                        attempt + 1,
                        max_retries,
                        wait_time,
                        e,
                    )
                    await asyncio.sleep(wait_time)
                    continue

                logger.error("Failed after %d attempts: %s", max_retries, e)
                break
            except Exception as e:
                # Don't retry on non-connection errors
                logger.error("Non-retryable error: %s", e)
                raise

        if last_exception:
            raise last_exception
        return None

    ###########################
    #### STATE METHODS ########
    ###########################

    async def aput_state(
        self,
        config: dict[str, Any],
        state: StateT,
    ) -> StateT:
        """
        Store state in PostgreSQL. (Redis caching is handled separately by `aput_state_cache`.)

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to store.

        Returns:
            StateT: The stored state object.

        Raises:
            StorageError: If storing fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        logger.debug("Storing state for thread_id=%s, user_id=%s", thread_id, user_id)
        metrics.counter("pg_checkpointer.save_state.attempts").inc()

        with metrics.timer("pg_checkpointer.save_state.duration"):
            try:
                # Ensure thread exists first
                await self._ensure_thread_exists(thread_id, user_id, config)

                # Store in PostgreSQL with retry logic
                state_json = self._serialize_state_fast(state)

                async def _store_state():
                    async with (await self._get_pg_pool()).acquire() as conn:
                        await conn.execute(
                            f"""
                            INSERT INTO {self._get_table_name("states")}
                                (thread_id, state_data, meta)
                            VALUES ($1, $2, $3)
                            ON CONFLICT DO NOTHING
                            """,  # noqa: S608
                            thread_id,
                            state_json,
                            json.dumps(config.get("meta", {})),
                        )

                await self._retry_on_connection_error(_store_state, max_retries=3)
                logger.debug("State stored successfully for thread_id=%s", thread_id)
                metrics.counter("pg_checkpointer.save_state.success").inc()
                return state

            except Exception as e:
                metrics.counter("pg_checkpointer.save_state.error").inc()
                logger.error("Failed to store state for thread_id=%s: %s", thread_id, e)
                if asyncpg and hasattr(asyncpg, "ConnectionDoesNotExistError"):
                    connection_errors = (
                        asyncpg.ConnectionDoesNotExistError,
                        asyncpg.InterfaceError,
                    )
                    if isinstance(e, connection_errors):
                        raise TransientStorageError(f"Connection issue storing state: {e}") from e
                raise StorageError(f"Failed to store state: {e}") from e

    async def aget_state(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve state from PostgreSQL.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Retrieved state or None.

        Raises:
            Exception: If retrieval fails.
        """
        """Retrieve state from PostgreSQL."""
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)
        state_class = config.get("state_class", AgentState)

        logger.debug("Retrieving state for thread_id=%s, user_id=%s", thread_id, user_id)

        try:

            async def _get_state():
                async with (await self._get_pg_pool()).acquire() as conn:
                    return await conn.fetchrow(
                        f"""
                        SELECT state_data FROM {self._get_table_name("states")}
                        WHERE thread_id = $1
                        ORDER BY created_at DESC
                        LIMIT 1
                        """,  # noqa: S608
                        thread_id,
                    )

            row = await self._retry_on_connection_error(_get_state, max_retries=3)

            if row:
                logger.debug("State found for thread_id=%s", thread_id)
                return self._deserialize_state(row["state_data"], state_class)

            logger.debug("No state found for thread_id=%s", thread_id)
            return None

        except Exception as e:
            logger.error("Failed to retrieve state for thread_id=%s: %s", thread_id, e)
            raise

    async def aclear_state(self, config: dict[str, Any]) -> Any:
        """
        Clear state from PostgreSQL and Redis cache.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any: None

        Raises:
            Exception: If clearing fails.
        """
        """Clear state from PostgreSQL and Redis cache."""
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        logger.debug("Clearing state for thread_id=%s, user_id=%s", thread_id, user_id)

        try:
            # Clear from PostgreSQL with retry logic
            async def _clear_state():
                async with (await self._get_pg_pool()).acquire() as conn:
                    await conn.execute(
                        f"DELETE FROM {self._get_table_name('states')} WHERE thread_id = $1",  # noqa: S608
                        thread_id,
                    )

            await self._retry_on_connection_error(_clear_state, max_retries=3)

            # Clear from Redis cache
            cache_key = self._get_thread_key(thread_id, user_id)
            await self.redis.delete(cache_key)

            logger.debug("State cleared for thread_id=%s", thread_id)

        except Exception as e:
            logger.error("Failed to clear state for thread_id=%s: %s", thread_id, e)
            raise

    async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
        """
        Cache state in Redis with TTL.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to cache.

        Returns:
            Any | None: True if cached, None if failed.
        """
        """Cache state in Redis with TTL."""
        # No DB access, but keep consistent
        thread_id, user_id = self._validate_config(config)

        logger.debug("Caching state for thread_id=%s, user_id=%s", thread_id, user_id)

        try:
            cache_key = self._get_thread_key(thread_id, user_id)
            state_json = self._serialize_state(state)
            await self.redis.setex(cache_key, self.cache_ttl, state_json)
            logger.debug("State cached with key=%s, ttl=%d", cache_key, self.cache_ttl)
            return True

        except Exception as e:
            logger.error("Failed to cache state for thread_id=%s: %s", thread_id, e)
            # Don't raise - caching is optional
            return None
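The `setex`-with-TTL pattern used above can be illustrated with an in-memory stand-in for the Redis client (hypothetical; the real class delegates to an actual Redis connection).

```python
import time


class TTLCache:
    """In-memory sketch of Redis SETEX/GET semantics: values expire after a TTL."""

    def __init__(self):
        self._store = {}

    def setex(self, key, ttl_seconds, value):
        # Store the value alongside its absolute expiry time.
        self._store[key] = (value, time.monotonic() + ttl_seconds)

    def get(self, key):
        hit = self._store.get(key)
        if hit is None:
            return None
        value, expires_at = hit
        if time.monotonic() >= expires_at:
            del self._store[key]  # lazily evict expired entries
            return None
        return value


cache = TTLCache()
cache.setex("state:user1:thread9", 3600, '{"step": 1}')
cached = cache.get("state:user1:thread9")
```

As in `aput_state_cache` above, a cache write failure should be non-fatal: the database remains the source of truth.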

    async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
        """
        Get state from Redis cache, fallback to PostgreSQL if miss.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: State object or None.
        """
        """Get state from Redis cache, fallback to PostgreSQL if miss."""
        # Schema might be needed if we fall back to DB
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)
        state_class = config.get("state_class", AgentState)

        logger.debug("Getting cached state for thread_id=%s, user_id=%s", thread_id, user_id)

        try:
            # Try Redis first
            cache_key = self._get_thread_key(thread_id, user_id)
            cached_data = await self.redis.get(cache_key)

            if cached_data:
                logger.debug("Cache hit for thread_id=%s", thread_id)
                return self._deserialize_state(cached_data.decode(), state_class)

            # Cache miss - fallback to PostgreSQL
            logger.debug("Cache miss for thread_id=%s, falling back to PostgreSQL", thread_id)
            state = await self.aget_state(config)

            # Cache the result for next time
            if state:
                await self.aput_state_cache(config, state)

            return state

        except Exception as e:
            logger.error("Failed to get cached state for thread_id=%s: %s", thread_id, e)
            # Fallback to PostgreSQL on error
            return await self.aget_state(config)

    async def _ensure_thread_exists(
        self,
        thread_id: str | int,
        user_id: str | int,
        config: dict[str, Any],
    ) -> None:
        """
        Ensure thread exists in database, create if not.

        Args:
            thread_id (str|int): Thread identifier.
            user_id (str|int): User identifier.
            config (dict): Configuration dictionary.

        Returns:
            None

        Raises:
            Exception: If creation fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        try:

            async def _check_and_create_thread():
                async with (await self._get_pg_pool()).acquire() as conn:
                    exists = await conn.fetchval(
                        f"SELECT 1 FROM {self._get_table_name('threads')} "  # noqa: S608
                        f"WHERE thread_id = $1 AND user_id = $2",
                        thread_id,
                        user_id,
                    )

                    if not exists:
                        thread_name = config.get("thread_name", f"Thread {thread_id}")
                        meta = json.dumps(config.get("thread_meta", {}))
                        await conn.execute(
                            f"""
                            INSERT INTO {self._get_table_name("threads")}
                                (thread_id, thread_name, user_id, meta)
                            VALUES ($1, $2, $3, $4)
                            ON CONFLICT DO NOTHING
                            """,  # noqa: S608
                            thread_id,
                            thread_name,
                            user_id,
                            meta,
                        )
                        logger.debug("Created thread: thread_id=%s, user_id=%s", thread_id, user_id)

            await self._retry_on_connection_error(_check_and_create_thread, max_retries=3)

        except Exception as e:
            logger.error("Failed to ensure thread exists: %s", e)
            raise

    ###########################
    #### MESSAGE METHODS ######
    ###########################

    async def aput_messages(
        self,
        config: dict[str, Any],
        messages: list[Message],
        metadata: dict[str, Any] | None = None,
    ) -> Any:
        """
        Store messages in PostgreSQL.

        Args:
            config (dict): Configuration dictionary.
            messages (list[Message]): List of messages to store.
            metadata (dict, optional): Additional metadata.

        Returns:
            Any: None

        Raises:
            Exception: If storing fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        if not messages:
            logger.debug("No messages to store for thread_id=%s", thread_id)
            return

        logger.debug("Storing %d messages for thread_id=%s", len(messages), thread_id)

        try:
            # Ensure thread exists
            await self._ensure_thread_exists(thread_id, user_id, config)

            # Store messages in batch with retry logic
            async def _store_messages():
                async with (await self._get_pg_pool()).acquire() as conn, conn.transaction():
                    for message in messages:
                        await conn.execute(
                            f"""
                                INSERT INTO {self._get_table_name("messages")} (
                                    message_id, thread_id, role, content, tool_calls,
                                    tool_call_id, reasoning, total_tokens, usages, meta
                                )
                                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                                ON CONFLICT (message_id) DO UPDATE SET
                                    content = EXCLUDED.content,
                                    reasoning = EXCLUDED.reasoning,
                                    usages = EXCLUDED.usages,
                                    updated_at = NOW()
                                """,  # noqa: S608
                            message.message_id,
                            thread_id,
                            message.role,
                            json.dumps(
                                [block.model_dump(mode="json") for block in message.content]
                            ),
                            json.dumps(message.tools_calls) if message.tools_calls else None,
                            getattr(message, "tool_call_id", None),
                            message.reasoning,
                            message.usages.total_tokens if message.usages else 0,
                            json.dumps(message.usages.model_dump()) if message.usages else None,
                            json.dumps({**(metadata or {}), **(message.metadata or {})}),
                        )

            await self._retry_on_connection_error(_store_messages, max_retries=3)
            logger.debug("Stored %d messages for thread_id=%s", len(messages), thread_id)

        except Exception as e:
            logger.error("Failed to store messages for thread_id=%s: %s", thread_id, e)
            raise

    async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
        """
        Retrieve a single message by ID.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Message: Retrieved message object.

        Raises:
            Exception: If retrieval fails.
        """
        """Retrieve a single message by ID."""
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id = config.get("thread_id")

        logger.debug("Retrieving message_id=%s for thread_id=%s", message_id, thread_id)

        try:

            async def _get_message():
                async with (await self._get_pg_pool()).acquire() as conn:
                    query = f"""
                        SELECT message_id, thread_id, role, content, tool_calls,
                               tool_call_id, reasoning, created_at, total_tokens,
                               usages, meta
                        FROM {self._get_table_name("messages")}
                        WHERE message_id = $1
                    """  # noqa: S608
                    if thread_id:
                        query += " AND thread_id = $2"
                        return await conn.fetchrow(query, message_id, thread_id)
                    return await conn.fetchrow(query, message_id)

            row = await self._retry_on_connection_error(_get_message, max_retries=3)

            if not row:
                raise ValueError(f"Message not found: {message_id}")

            return self._row_to_message(row)

        except Exception as e:
            logger.error("Failed to retrieve message_id=%s: %s", message_id, e)
            raise

    async def alist_messages(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[Message]:
        """
        List messages for a thread with optional search and pagination.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[Message]: List of message objects.

        Raises:
            Exception: If listing fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id = config.get("thread_id")

        if not thread_id:
            raise ValueError("thread_id must be provided in config")

        logger.debug("Listing messages for thread_id=%s", thread_id)

        try:

            async def _list_messages():
                async with (await self._get_pg_pool()).acquire() as conn:
                    # Build query with optional search
                    query = f"""
                        SELECT message_id, thread_id, role, content, tool_calls,
                               tool_call_id, reasoning, created_at, total_tokens,
                               usages, meta
                        FROM {self._get_table_name("messages")}
                        WHERE thread_id = $1
                    """  # noqa: S608
                    params = [thread_id]
                    param_count = 1

                    if search:
                        param_count += 1
                        query += f" AND content ILIKE ${param_count}"
                        params.append(f"%{search}%")

                    query += " ORDER BY created_at ASC"

                    if limit:
                        param_count += 1
                        query += f" LIMIT ${param_count}"
                        params.append(limit)

                    if offset:
                        param_count += 1
                        query += f" OFFSET ${param_count}"
                        params.append(offset)

                    return await conn.fetch(query, *params)

            rows = await self._retry_on_connection_error(_list_messages, max_retries=3) or []
            messages = [self._row_to_message(row) for row in rows]

            logger.debug("Found %d messages for thread_id=%s", len(messages), thread_id)
            return messages

        except Exception as e:
            logger.error("Failed to list messages for thread_id=%s: %s", thread_id, e)
            raise
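The incremental `$N`-placeholder query building in `alist_messages` (asyncpg-style positional parameters) can be isolated into a small helper. A sketch, with an illustrative `messages` table name:

```python
def build_list_query(thread_id, search=None, limit=None, offset=None):
    """Build a parameterized asyncpg-style query with optional search/pagination."""
    query = "SELECT * FROM messages WHERE thread_id = $1"
    params = [thread_id]

    if search:
        params.append(f"%{search}%")
        query += f" AND content ILIKE ${len(params)}"  # next positional slot

    query += " ORDER BY created_at ASC"

    if limit:
        params.append(limit)
        query += f" LIMIT ${len(params)}"

    if offset:
        params.append(offset)
        query += f" OFFSET ${len(params)}"

    return query, params


q, p = build_list_query("t1", search="hello", limit=10)
```

Keeping values in `params` rather than interpolating them into the SQL string is what makes the optional clauses safe against injection; only the trusted placeholder numbers are formatted in.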

    async def adelete_message(
        self,
        config: dict[str, Any],
        message_id: str | int,
    ) -> Any | None:
        """
        Delete a message by ID.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Any | None: None

        Raises:
            Exception: If deletion fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id = config.get("thread_id")

        logger.debug("Deleting message_id=%s for thread_id=%s", message_id, thread_id)

        try:

            async def _delete_message():
                async with (await self._get_pg_pool()).acquire() as conn:
                    if thread_id:
                        await conn.execute(
                            f"DELETE FROM {self._get_table_name('messages')} "  # noqa: S608
                            f"WHERE message_id = $1 AND thread_id = $2",
                            message_id,
                            thread_id,
                        )
                    else:
                        await conn.execute(
                            f"DELETE FROM {self._get_table_name('messages')} WHERE message_id = $1",  # noqa: S608
                            message_id,
                        )

            await self._retry_on_connection_error(_delete_message, max_retries=3)
            logger.debug("Deleted message_id=%s", message_id)
            return None

        except Exception as e:
            logger.error("Failed to delete message_id=%s: %s", message_id, e)
            raise

    def _row_to_message(self, row) -> Message:  # noqa: PLR0912, PLR0915
        """
        Convert database row to Message object with robust JSON handling.

        Args:
            row: Database row.

        Returns:
            Message: Message object.
        """
        from pyagenity.utils.message import TokenUsages

        # Handle usages JSONB
        usages = None
        usages_raw = row["usages"]
        if usages_raw:
            try:
                usages_dict = (
                    json.loads(usages_raw)
                    if isinstance(usages_raw, str | bytes | bytearray)
                    else usages_raw
                )
                usages = TokenUsages(**usages_dict)
            except Exception:
                usages = None

        # Handle tool_calls JSONB
        tool_calls_raw = row["tool_calls"]
        if tool_calls_raw:
            try:
                tool_calls = (
                    json.loads(tool_calls_raw)
                    if isinstance(tool_calls_raw, str | bytes | bytearray)
                    else tool_calls_raw
                )
            except Exception:
                tool_calls = None
        else:
            tool_calls = None

        # Handle meta JSONB
        meta_raw = row["meta"]
        if meta_raw:
            try:
                metadata = (
                    json.loads(meta_raw)
                    if isinstance(meta_raw, str | bytes | bytearray)
                    else meta_raw
                )
            except Exception:
                metadata = {}
        else:
            metadata = {}

        # Handle content TEXT/JSONB -> list of blocks
        content_raw = row["content"]
        content_value: list[Any] = []
        if content_raw is None:
            content_value = []
        elif isinstance(content_raw, bytes | bytearray):
            try:
                parsed = json.loads(content_raw.decode())
                if isinstance(parsed, list):
                    content_value = parsed
                elif isinstance(parsed, dict):
                    content_value = [parsed]
                else:
                    content_value = [{"type": "text", "text": str(parsed), "annotations": []}]
            except Exception:
                content_value = [
                    {"type": "text", "text": content_raw.decode(errors="ignore"), "annotations": []}
                ]
        elif isinstance(content_raw, str):
            # Try JSON parse first
            try:
                parsed = json.loads(content_raw)
                if isinstance(parsed, list):
                    content_value = parsed
                elif isinstance(parsed, dict):
                    content_value = [parsed]
                else:
                    content_value = [{"type": "text", "text": content_raw, "annotations": []}]
            except Exception:
                content_value = [{"type": "text", "text": content_raw, "annotations": []}]
        elif isinstance(content_raw, list):
            content_value = content_raw
        elif isinstance(content_raw, dict):
            content_value = [content_raw]
        else:
            content_value = [{"type": "text", "text": str(content_raw), "annotations": []}]

        return Message(
            message_id=row["message_id"],
            role=row["role"],
            content=content_value,
            tools_calls=tool_calls,
            reasoning=row["reasoning"],
            timestamp=row["created_at"],
            metadata=metadata,
            usages=usages,
        )
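The content-normalization logic in `_row_to_message` (coercing a stored column that may arrive as bytes, string, dict, or list into a list of content blocks) can be sketched as a standalone function:

```python
import json


def normalize_content(raw):
    """Coerce a stored content value into a list of content-block dicts."""
    if raw is None:
        return []
    if isinstance(raw, (bytes, bytearray)):
        raw = raw.decode(errors="ignore")
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)  # stored JSON: fall through to shape checks
        except ValueError:
            # Plain text: wrap it in a single text block.
            return [{"type": "text", "text": raw, "annotations": []}]
    if isinstance(raw, list):
        return raw
    if isinstance(raw, dict):
        return [raw]
    # Scalar (e.g. a bare number parsed from JSON): stringify into a text block.
    return [{"type": "text", "text": str(raw), "annotations": []}]


blocks = normalize_content('[{"type": "text", "text": "hi", "annotations": []}]')
plain = normalize_content("just text")
```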

    ###########################
    #### THREAD METHODS #######
    ###########################

    async def aput_thread(
        self,
        config: dict[str, Any],
        thread_info: ThreadInfo,
    ) -> Any | None:
        """
        Create or update thread information.

        Args:
            config (dict): Configuration dictionary.
            thread_info (ThreadInfo): Thread information object.

        Returns:
            Any | None: None

        Raises:
            Exception: If storing fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        logger.debug("Storing thread info for thread_id=%s, user_id=%s", thread_id, user_id)

        try:
            thread_name = thread_info.thread_name or f"Thread {thread_id}"
            meta = thread_info.metadata or {}
            user_id = thread_info.user_id or user_id
            meta.update(
                {
                    "run_id": thread_info.run_id,
                }
            )

            async def _put_thread():
                async with (await self._get_pg_pool()).acquire() as conn:
                    await conn.execute(
                        f"""
                        INSERT INTO {self._get_table_name("threads")}
                            (thread_id, thread_name, user_id, meta)
                        VALUES ($1, $2, $3, $4)
                        ON CONFLICT (thread_id) DO UPDATE SET
                            thread_name = EXCLUDED.thread_name,
                            meta = EXCLUDED.meta,
                            updated_at = NOW()
                        """,  # noqa: S608
                        thread_id,
                        thread_name,
                        user_id,
                        json.dumps(meta),
                    )

            await self._retry_on_connection_error(_put_thread, max_retries=3)
            logger.debug("Thread info stored for thread_id=%s", thread_id)

        except Exception as e:
            logger.error("Failed to store thread info for thread_id=%s: %s", thread_id, e)
            raise

    async def aget_thread(
        self,
        config: dict[str, Any],
    ) -> ThreadInfo | None:
        """
        Get thread information.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            ThreadInfo | None: Thread information object or None.

        Raises:
            Exception: If retrieval fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        logger.debug("Retrieving thread info for thread_id=%s, user_id=%s", thread_id, user_id)

        try:

            async def _get_thread():
                async with (await self._get_pg_pool()).acquire() as conn:
                    return await conn.fetchrow(
                        f"""
                        SELECT thread_id, thread_name, user_id, created_at, updated_at, meta
                        FROM {self._get_table_name("threads")}
                        WHERE thread_id = $1 AND user_id = $2
                        """,  # noqa: S608
                        thread_id,
                        user_id,
                    )

            row = await self._retry_on_connection_error(_get_thread, max_retries=3)

            if row:
                meta_dict = {}
                if row["meta"]:
                    meta_dict = (
                        json.loads(row["meta"]) if isinstance(row["meta"], str) else row["meta"]
                    )
                return ThreadInfo(
                    thread_id=thread_id,
                    thread_name=row["thread_name"] if row else None,
                    user_id=user_id,
                    metadata=meta_dict,
                    run_id=meta_dict.get("run_id"),
                )

            logger.debug("Thread not found for thread_id=%s, user_id=%s", thread_id, user_id)
            return None

        except Exception as e:
            logger.error("Failed to retrieve thread info for thread_id=%s: %s", thread_id, e)
            raise

    async def alist_threads(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[ThreadInfo]:
        """
        List threads for a user with optional search and pagination.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[ThreadInfo]: List of thread information objects.

        Raises:
            Exception: If listing fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        user_id = config.get("user_id")

        if not user_id:
            raise ValueError("user_id must be provided in config")

        logger.debug("Listing threads for user_id=%s", user_id)

        try:

            async def _list_threads():
                async with (await self._get_pg_pool()).acquire() as conn:
                    # Build query with optional search
                    query = f"""
                        SELECT thread_id, thread_name, user_id, created_at, updated_at, meta
                        FROM {self._get_table_name("threads")}
                        WHERE user_id = $1
                    """  # noqa: S608
                    params = [user_id]
                    param_count = 1

                    if search:
                        param_count += 1
                        query += f" AND thread_name ILIKE ${param_count}"
                        params.append(f"%{search}%")

                    query += " ORDER BY updated_at DESC"

                    if limit:
                        param_count += 1
                        query += f" LIMIT ${param_count}"
                        params.append(limit)

                    if offset:
                        param_count += 1
                        query += f" OFFSET ${param_count}"
                        params.append(offset)

                    return await conn.fetch(query, *params)

            rows = await self._retry_on_connection_error(_list_threads, max_retries=3) or []

            threads = []
            for row in rows:
                meta_dict = {}
                if row["meta"]:
                    meta_dict = (
                        json.loads(row["meta"]) if isinstance(row["meta"], str) else row["meta"]
                    )
                threads.append(
                    ThreadInfo(
                        thread_id=row["thread_id"],
                        thread_name=row["thread_name"],
                        user_id=row["user_id"],
                        metadata=meta_dict,
                        run_id=meta_dict.get("run_id"),
                        updated_at=row["updated_at"],
                    )
                )
            logger.debug("Found %d threads for user_id=%s", len(threads), user_id)
            return threads

        except Exception as e:
            logger.error("Failed to list threads for user_id=%s: %s", user_id, e)
            raise

    async def aclean_thread(self, config: dict[str, Any]) -> Any | None:
        """
        Clean/delete a thread and all associated data.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any | None: None

        Raises:
            Exception: If cleaning fails.
        """
        """Clean/delete a thread and all associated data."""
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        logger.debug("Cleaning thread thread_id=%s, user_id=%s", thread_id, user_id)

        try:
            # Delete thread (cascade will handle messages and states) with retry logic
            async def _clean_thread():
                async with (await self._get_pg_pool()).acquire() as conn:
                    await conn.execute(
                        f"DELETE FROM {self._get_table_name('threads')} "  # noqa: S608
                        f"WHERE thread_id = $1 AND user_id = $2",
                        thread_id,
                        user_id,
                    )

            await self._retry_on_connection_error(_clean_thread, max_retries=3)

            # Clean from Redis cache
            cache_key = self._get_thread_key(thread_id, user_id)
            await self.redis.delete(cache_key)

            logger.debug("Thread cleaned: thread_id=%s, user_id=%s", thread_id, user_id)

        except Exception as e:
            logger.error("Failed to clean thread thread_id=%s: %s", thread_id, e)
            raise

    ###########################
    #### RESOURCE CLEANUP #####
    ###########################

    async def arelease(self) -> Any | None:
        """
        Clean up connections and resources.

        Returns:
            Any | None: None
        """
        """Clean up connections and resources."""
        logger.info("Releasing PgCheckpointer resources")

        if not self.release_resources:
            logger.info("No resources to release")
            return

        errors = []

        # Close Redis connection
        try:
            if hasattr(self.redis, "aclose"):
                await self.redis.aclose()
            elif hasattr(self.redis, "close"):
                await self.redis.close()
            logger.debug("Redis connection closed")
        except Exception as e:
            logger.error("Error closing Redis connection: %s", e)
            errors.append(f"Redis: {e}")

        # Close PostgreSQL pool
        try:
            if self._pg_pool and not self._pg_pool.is_closing():
                await self._pg_pool.close()
            logger.debug("PostgreSQL pool closed")
        except Exception as e:
            logger.error("Error closing PostgreSQL pool: %s", e)
            errors.append(f"PostgreSQL: {e}")

        if errors:
            error_msg = f"Errors during resource cleanup: {'; '.join(errors)}"
            logger.warning(error_msg)
            # Don't raise - cleanup should be best effort
        else:
            logger.info("All resources released successfully")
Attributes
cache_ttl instance-attribute
cache_ttl = get('cache_ttl', DEFAULT_CACHE_TTL)
id_type instance-attribute
id_type = get('id_type', try_get('generated_id_type', 'string'))
redis instance-attribute
redis = _create_redis_pool(redis, redis_pool, redis_url, redis_pool_config or {})
release_resources instance-attribute
release_resources = get('release_resources', False)
schema instance-attribute
schema = schema
user_id_type instance-attribute
user_id_type = get('user_id_type', 'string')
Functions
__init__
__init__(postgres_dsn=None, pg_pool=None, pool_config=None, redis_url=None, redis=None, redis_pool=None, redis_pool_config=None, schema='public', **kwargs)

Initializes PgCheckpointer with PostgreSQL and Redis connections.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `postgres_dsn` | `str` | PostgreSQL connection string. | `None` |
| `pg_pool` | `Any` | Existing asyncpg Pool instance. | `None` |
| `pool_config` | `dict` | Configuration for new pg pool creation. | `None` |
| `redis_url` | `str` | Redis connection URL. | `None` |
| `redis` | `Any` | Existing Redis instance. | `None` |
| `redis_pool` | `Any` | Existing Redis ConnectionPool. | `None` |
| `redis_pool_config` | `dict` | Configuration for new redis pool creation. | `None` |
| `schema` | `str` | PostgreSQL schema name. Defaults to "public". | `'public'` |
| `**kwargs` | | Additional configuration options. | `{}` |

Raises:

| Type | Description |
| --- | --- |
| `ImportError` | If required dependencies are missing. |
| `ValueError` | If required connection details are missing. |

Source code in pyagenity/checkpointer/pg_checkpointer.py
def __init__(
    self,
    # postgress connection details
    postgres_dsn: str | None = None,
    pg_pool: Any | None = None,
    pool_config: dict | None = None,
    # redis connection details
    redis_url: str | None = None,
    redis: Any | None = None,
    redis_pool: Any | None = None,
    redis_pool_config: dict | None = None,
    # database schema
    schema: str = "public",
    # other configurations - combine to reduce args
    **kwargs,
):
    """
    Initializes PgCheckpointer with PostgreSQL and Redis connections.

    Args:
        postgres_dsn (str, optional): PostgreSQL connection string.
        pg_pool (Any, optional): Existing asyncpg Pool instance.
        pool_config (dict, optional): Configuration for new pg pool creation.
        redis_url (str, optional): Redis connection URL.
        redis (Any, optional): Existing Redis instance.
        redis_pool (Any, optional): Existing Redis ConnectionPool.
        redis_pool_config (dict, optional): Configuration for new redis pool creation.
        schema (str, optional): PostgreSQL schema name. Defaults to "public".
        **kwargs: Additional configuration options.

    Raises:
        ImportError: If required dependencies are missing.
        ValueError: If required connection details are missing.
    """
    # Check for required dependencies
    if not HAS_ASYNCPG:
        raise ImportError(
            "PgCheckpointer requires 'asyncpg' package. "
            "Install with: pip install pyagenity[pg_checkpoint]"
        )

    if not HAS_REDIS:
        raise ImportError(
            "PgCheckpointer requires 'redis' package. "
            "Install with: pip install pyagenity[pg_checkpoint]"
        )

    self.user_id_type = kwargs.get("user_id_type", "string")
    # allow explicit override via kwargs, fallback to InjectQ, then default
    self.id_type = kwargs.get(
        "id_type", InjectQ.get_instance().try_get("generated_id_type", "string")
    )
    self.cache_ttl = kwargs.get("cache_ttl", DEFAULT_CACHE_TTL)
    self.release_resources = kwargs.get("release_resources", False)

    # Validate schema name to prevent SQL injection
    if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", schema):
        raise ValueError(
            f"Invalid schema name: {schema}. Schema must match pattern ^[a-zA-Z_][a-zA-Z0-9_]*$"
        )
    self.schema = schema

    self._schema_initialized = False
    self._loop: asyncio.AbstractEventLoop | None = None

    # Store pool configuration for lazy initialization
    self._pg_pool_config = {
        "pg_pool": pg_pool,
        "postgres_dsn": postgres_dsn,
        "pool_config": pool_config or {},
    }

    # Initialize pool immediately if provided, otherwise defer
    if pg_pool is not None:
        self._pg_pool = pg_pool
    else:
        self._pg_pool = None

    # Now check and initialize connections
    if not pg_pool and not postgres_dsn:
        raise ValueError("Either postgres_dsn or pg_pool must be provided.")

    if not redis and not redis_url and not redis_pool:
        raise ValueError("Either redis_url, redis_pool or redis instance must be provided.")

    # Initialize Redis connection (synchronous)
    self.redis = self._create_redis_pool(redis, redis_pool, redis_url, redis_pool_config or {})
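Because the schema name is interpolated directly into table-qualified SQL, the constructor rejects anything that is not a plain identifier. That guard can be exercised in isolation (a standalone illustration of the same regex, not the library's own helper):

```python
import re

# Schema names end up inside SQL strings, so they must be plain identifiers:
# a letter or underscore followed by letters, digits, or underscores.
_SCHEMA_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")

def is_valid_schema(schema: str) -> bool:
    """Return True if the schema name is safe to interpolate into SQL."""
    return bool(_SCHEMA_RE.match(schema))
```

Anything containing whitespace, punctuation, or a leading digit fails the check, which is why `PgCheckpointer` raises `ValueError` before any query is ever built.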
aclean_thread async
aclean_thread(config)

Clean/delete a thread and all associated data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Configuration dictionary. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Any \| None` | None |

Raises:

| Type | Description |
| --- | --- |
| `Exception` | If cleaning fails. |

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aclean_thread(self, config: dict[str, Any]) -> Any | None:
    """
    Clean/delete a thread and all associated data.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any | None: None

    Raises:
        Exception: If cleaning fails.
    """
    """Clean/delete a thread and all associated data."""
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    logger.debug("Cleaning thread thread_id=%s, user_id=%s", thread_id, user_id)

    try:
        # Delete thread (cascade will handle messages and states) with retry logic
        async def _clean_thread():
            async with (await self._get_pg_pool()).acquire() as conn:
                await conn.execute(
                    f"DELETE FROM {self._get_table_name('threads')} "  # noqa: S608
                    f"WHERE thread_id = $1 AND user_id = $2",
                    thread_id,
                    user_id,
                )

        await self._retry_on_connection_error(_clean_thread, max_retries=3)

        # Clean from Redis cache
        cache_key = self._get_thread_key(thread_id, user_id)
        await self.redis.delete(cache_key)

        logger.debug("Thread cleaned: thread_id=%s, user_id=%s", thread_id, user_id)

    except Exception as e:
        logger.error("Failed to clean thread thread_id=%s: %s", thread_id, e)
        raise
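Every database call above is routed through `_retry_on_connection_error` with `max_retries=3`. Its implementation is not shown in this section, but the idea can be sketched as a generic async retry wrapper (an assumption-based sketch with illustrative names, not the library's actual code):

```python
import asyncio

async def retry_async(fn, max_retries=3, base_delay=0.0):
    """Await an async callable, retrying on ConnectionError up to max_retries times."""
    for attempt in range(max_retries):
        try:
            return await fn()
        except ConnectionError:
            if attempt == max_retries - 1:
                raise  # out of retries: surface the error to the caller
            await asyncio.sleep(base_delay * (2 ** attempt))  # simple backoff
```

Wrapping each closure (such as `_clean_thread`) this way keeps transient pool disconnects from failing an otherwise healthy operation.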
aclear_state async
aclear_state(config)

Clear state from PostgreSQL and Redis cache.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Configuration dictionary. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Any` | `Any` | None |

Raises:

| Type | Description |
| --- | --- |
| `Exception` | If clearing fails. |

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aclear_state(self, config: dict[str, Any]) -> Any:
    """
    Clear state from PostgreSQL and Redis cache.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any: None

    Raises:
        Exception: If clearing fails.
    """
    """Clear state from PostgreSQL and Redis cache."""
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    logger.debug("Clearing state for thread_id=%s, user_id=%s", thread_id, user_id)

    try:
        # Clear from PostgreSQL with retry logic
        async def _clear_state():
            async with (await self._get_pg_pool()).acquire() as conn:
                await conn.execute(
                    f"DELETE FROM {self._get_table_name('states')} WHERE thread_id = $1",  # noqa: S608
                    thread_id,
                )

        await self._retry_on_connection_error(_clear_state, max_retries=3)

        # Clear from Redis cache
        cache_key = self._get_thread_key(thread_id, user_id)
        await self.redis.delete(cache_key)

        logger.debug("State cleared for thread_id=%s", thread_id)

    except Exception as e:
        logger.error("Failed to clear state for thread_id=%s: %s", thread_id, e)
        raise
adelete_message async
adelete_message(config, message_id)

Delete a message by ID.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Configuration dictionary. | *required* |
| `message_id` | `str \| int` | Message identifier. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Any \| None` | None |

Raises:

| Type | Description |
| --- | --- |
| `Exception` | If deletion fails. |

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def adelete_message(
    self,
    config: dict[str, Any],
    message_id: str | int,
) -> Any | None:
    """
    Delete a message by ID.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Any | None: None

    Raises:
        Exception: If deletion fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id = config.get("thread_id")

    logger.debug("Deleting message_id=%s for thread_id=%s", message_id, thread_id)

    try:

        async def _delete_message():
            async with (await self._get_pg_pool()).acquire() as conn:
                if thread_id:
                    await conn.execute(
                        f"DELETE FROM {self._get_table_name('messages')} "  # noqa: S608
                        f"WHERE message_id = $1 AND thread_id = $2",
                        message_id,
                        thread_id,
                    )
                else:
                    await conn.execute(
                        f"DELETE FROM {self._get_table_name('messages')} WHERE message_id = $1",  # noqa: S608
                        message_id,
                    )

        await self._retry_on_connection_error(_delete_message, max_retries=3)
        logger.debug("Deleted message_id=%s", message_id)
        return None

    except Exception as e:
        logger.error("Failed to delete message_id=%s: %s", message_id, e)
        raise
aget_message async
aget_message(config, message_id)

Retrieve a single message by ID.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Configuration dictionary. | *required* |
| `message_id` | `str \| int` | Message identifier. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Message` | `Message` | Retrieved message object. |

Raises:

| Type | Description |
| --- | --- |
| `Exception` | If retrieval fails. |

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a single message by ID.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.

    Raises:
        Exception: If retrieval fails.
    """
    """Retrieve a single message by ID."""
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id = config.get("thread_id")

    logger.debug("Retrieving message_id=%s for thread_id=%s", message_id, thread_id)

    try:

        async def _get_message():
            async with (await self._get_pg_pool()).acquire() as conn:
                query = f"""
                    SELECT message_id, thread_id, role, content, tool_calls,
                           tool_call_id, reasoning, created_at, total_tokens,
                           usages, meta
                    FROM {self._get_table_name("messages")}
                    WHERE message_id = $1
                """  # noqa: S608
                if thread_id:
                    query += " AND thread_id = $2"
                    return await conn.fetchrow(query, message_id, thread_id)
                return await conn.fetchrow(query, message_id)

        row = await self._retry_on_connection_error(_get_message, max_retries=3)

        if not row:
            raise ValueError(f"Message not found: {message_id}")

        return self._row_to_message(row)

    except Exception as e:
        logger.error("Failed to retrieve message_id=%s: %s", message_id, e)
        raise
aget_state async
aget_state(config)

Retrieve state from PostgreSQL.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Configuration dictionary. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `StateT \| None` | Retrieved state or None. |

Raises:

| Type | Description |
| --- | --- |
| `Exception` | If retrieval fails. |

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aget_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve state from PostgreSQL.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.

    Raises:
        Exception: If retrieval fails.
    """
    """Retrieve state from PostgreSQL."""
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)
    state_class = config.get("state_class", AgentState)

    logger.debug("Retrieving state for thread_id=%s, user_id=%s", thread_id, user_id)

    try:

        async def _get_state():
            async with (await self._get_pg_pool()).acquire() as conn:
                return await conn.fetchrow(
                    f"""
                    SELECT state_data FROM {self._get_table_name("states")}
                    WHERE thread_id = $1
                    ORDER BY created_at DESC
                    LIMIT 1
                    """,  # noqa: S608
                    thread_id,
                )

        row = await self._retry_on_connection_error(_get_state, max_retries=3)

        if row:
            logger.debug("State found for thread_id=%s", thread_id)
            return self._deserialize_state(row["state_data"], state_class)

        logger.debug("No state found for thread_id=%s", thread_id)
        return None

    except Exception as e:
        logger.error("Failed to retrieve state for thread_id=%s: %s", thread_id, e)
        raise
aget_state_cache async
aget_state_cache(config)

Get state from Redis cache, fallback to PostgreSQL if miss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Configuration dictionary. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `StateT \| None` | State object or None. |

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Get state from Redis cache, fallback to PostgreSQL if miss.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: State object or None.
    """
    """Get state from Redis cache, fallback to PostgreSQL if miss."""
    # Schema might be needed if we fall back to DB
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)
    state_class = config.get("state_class", AgentState)

    logger.debug("Getting cached state for thread_id=%s, user_id=%s", thread_id, user_id)

    try:
        # Try Redis first
        cache_key = self._get_thread_key(thread_id, user_id)
        cached_data = await self.redis.get(cache_key)

        if cached_data:
            logger.debug("Cache hit for thread_id=%s", thread_id)
            return self._deserialize_state(cached_data.decode(), state_class)

        # Cache miss - fallback to PostgreSQL
        logger.debug("Cache miss for thread_id=%s, falling back to PostgreSQL", thread_id)
        state = await self.aget_state(config)

        # Cache the result for next time
        if state:
            await self.aput_state_cache(config, state)

        return state

    except Exception as e:
        logger.error("Failed to get cached state for thread_id=%s: %s", thread_id, e)
        # Fallback to PostgreSQL on error
        return await self.aget_state(config)
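`aget_state_cache` implements the classic cache-aside pattern: try Redis first, fall back to PostgreSQL on a miss, and repopulate the cache with the result. With a plain dict standing in for Redis and a callable standing in for the database read, the flow reduces to (illustrative stand-ins, not the real API):

```python
def get_cached(cache: dict, key: str, load_from_db):
    """Return the cached value, falling back to the loader and caching the result."""
    value = cache.get(key)
    if value is not None:
        return value  # cache hit: skip the database entirely
    value = load_from_db()  # cache miss: read from the source of truth
    if value is not None:
        cache[key] = value  # repopulate so the next reader hits the cache
    return value
```

Note that on any Redis error the real method also degrades to `aget_state`, so the cache is strictly an optimization, never a correctness dependency.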
aget_thread async
aget_thread(config)

Get thread information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Configuration dictionary. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ThreadInfo \| None` | Thread information object or None. |

Raises:

| Type | Description |
| --- | --- |
| `Exception` | If retrieval fails. |

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aget_thread(
    self,
    config: dict[str, Any],
) -> ThreadInfo | None:
    """
    Get thread information.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.

    Raises:
        Exception: If retrieval fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    logger.debug("Retrieving thread info for thread_id=%s, user_id=%s", thread_id, user_id)

    try:

        async def _get_thread():
            async with (await self._get_pg_pool()).acquire() as conn:
                return await conn.fetchrow(
                    f"""
                    SELECT thread_id, thread_name, user_id, created_at, updated_at, meta
                    FROM {self._get_table_name("threads")}
                    WHERE thread_id = $1 AND user_id = $2
                    """,  # noqa: S608
                    thread_id,
                    user_id,
                )

        row = await self._retry_on_connection_error(_get_thread, max_retries=3)

        if row:
            meta_dict = {}
            if row["meta"]:
                meta_dict = (
                    json.loads(row["meta"]) if isinstance(row["meta"], str) else row["meta"]
                )
            return ThreadInfo(
                thread_id=thread_id,
                thread_name=row["thread_name"] if row else None,
                user_id=user_id,
                metadata=meta_dict,
                run_id=meta_dict.get("run_id"),
            )

        logger.debug("Thread not found for thread_id=%s, user_id=%s", thread_id, user_id)
        return None

    except Exception as e:
        logger.error("Failed to retrieve thread info for thread_id=%s: %s", thread_id, e)
        raise
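Both `aget_thread` and `alist_threads` normalize the `meta` column, which the driver may hand back either as a JSON string or as an already-decoded dict depending on configured codecs (or as NULL). That normalization step in isolation:

```python
import json

def normalize_meta(raw):
    """Return meta as a dict whether the row held a JSON string, a dict, or NULL."""
    if not raw:
        return {}
    return json.loads(raw) if isinstance(raw, str) else raw
```

Normalizing up front means downstream lookups such as `meta_dict.get("run_id")` work identically for every driver configuration.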
alist_messages async
alist_messages(config, search=None, offset=None, limit=None)

List messages for a thread with optional search and pagination.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Configuration dictionary. | *required* |
| `search` | `str` | Search string. | `None` |
| `offset` | `int` | Offset for pagination. | `None` |
| `limit` | `int` | Limit for pagination. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `list[Message]` | List of message objects. |

Raises:

| Type | Description |
| --- | --- |
| `Exception` | If listing fails. |

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def alist_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages for a thread with optional search and pagination.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.

    Raises:
        Exception: If listing fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id = config.get("thread_id")

    if not thread_id:
        raise ValueError("thread_id must be provided in config")

    logger.debug("Listing messages for thread_id=%s", thread_id)

    try:

        async def _list_messages():
            async with (await self._get_pg_pool()).acquire() as conn:
                # Build query with optional search
                query = f"""
                    SELECT message_id, thread_id, role, content, tool_calls,
                           tool_call_id, reasoning, created_at, total_tokens,
                           usages, meta
                    FROM {self._get_table_name("messages")}
                    WHERE thread_id = $1
                """  # noqa: S608
                params = [thread_id]
                param_count = 1

                if search:
                    param_count += 1
                    query += f" AND content ILIKE ${param_count}"
                    params.append(f"%{search}%")

                query += " ORDER BY created_at ASC"

                if limit:
                    param_count += 1
                    query += f" LIMIT ${param_count}"
                    params.append(limit)

                if offset:
                    param_count += 1
                    query += f" OFFSET ${param_count}"
                    params.append(offset)

                return await conn.fetch(query, *params)

        rows = await self._retry_on_connection_error(_list_messages, max_retries=3)
        if not rows:
            rows = []
        messages = [self._row_to_message(row) for row in rows]

        logger.debug("Found %d messages for thread_id=%s", len(messages), thread_id)
        return messages

    except Exception as e:
        logger.error("Failed to list messages for thread_id=%s: %s", thread_id, e)
        raise
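The listing methods build their SQL incrementally, appending a numbered `$N` placeholder for each optional filter so that user-supplied values never reach the query text itself. The same pattern, extracted into a standalone helper (names and defaults here are illustrative):

```python
def build_list_query(base: str, search=None, order_by="created_at ASC",
                     limit=None, offset=None, params=None):
    """Append optional ILIKE / LIMIT / OFFSET clauses with numbered placeholders."""
    params = list(params or [])
    query = base
    if search:
        params.append(f"%{search}%")       # the pattern travels as a parameter,
        query += f" AND content ILIKE ${len(params)}"  # never as query text
    query += f" ORDER BY {order_by}"
    if limit:
        params.append(limit)
        query += f" LIMIT ${len(params)}"
    if offset:
        params.append(offset)
        query += f" OFFSET ${len(params)}"
    return query, params
```

Because only the placeholder numbers are interpolated, the resulting query stays safe to pass to `conn.fetch(query, *params)` regardless of what the search string contains.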
alist_threads async
alist_threads(config, search=None, offset=None, limit=None)

List threads for a user with optional search and pagination.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Configuration dictionary. | *required* |
| `search` | `str` | Search string. | `None` |
| `offset` | `int` | Offset for pagination. | `None` |
| `limit` | `int` | Limit for pagination. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `list[ThreadInfo]` | List of thread information objects. |

Raises:

| Type | Description |
| --- | --- |
| `Exception` | If listing fails. |

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def alist_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List threads for a user with optional search and pagination.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.

    Raises:
        Exception: If listing fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    user_id = config.get("user_id")

    if not user_id:
        raise ValueError("user_id must be provided in config")

    logger.debug("Listing threads for user_id=%s", user_id)

    try:

        async def _list_threads():
            async with (await self._get_pg_pool()).acquire() as conn:
                # Build query with optional search
                query = f"""
                    SELECT thread_id, thread_name, user_id, created_at, updated_at, meta
                    FROM {self._get_table_name("threads")}
                    WHERE user_id = $1
                """  # noqa: S608
                params = [user_id]
                param_count = 1

                if search:
                    param_count += 1
                    query += f" AND thread_name ILIKE ${param_count}"
                    params.append(f"%{search}%")

                query += " ORDER BY updated_at DESC"

                if limit:
                    param_count += 1
                    query += f" LIMIT ${param_count}"
                    params.append(limit)

                if offset:
                    param_count += 1
                    query += f" OFFSET ${param_count}"
                    params.append(offset)

                return await conn.fetch(query, *params)

        rows = await self._retry_on_connection_error(_list_threads, max_retries=3)
        if not rows:
            rows = []

        threads = []
        for row in rows:
            meta_dict = {}
            if row["meta"]:
                meta_dict = (
                    json.loads(row["meta"]) if isinstance(row["meta"], str) else row["meta"]
                )
            threads.append(
                ThreadInfo(
                    thread_id=row["thread_id"],
                    thread_name=row["thread_name"],
                    user_id=row["user_id"],
                    metadata=meta_dict,
                    run_id=meta_dict.get("run_id"),
                    updated_at=row["updated_at"],
                )
            )
        logger.debug("Found %d threads for user_id=%s", len(threads), user_id)
        return threads

    except Exception as e:
        logger.error("Failed to list threads for user_id=%s: %s", user_id, e)
        raise
aput_messages async
aput_messages(config, messages, metadata=None)

Store messages in PostgreSQL.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `dict` | Configuration dictionary. | *required* |
| `messages` | `list[Message]` | List of messages to store. | *required* |
| `metadata` | `dict` | Additional metadata. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Any` | `Any` | None |

Raises:

| Type | Description |
| --- | --- |
| `Exception` | If storing fails. |

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aput_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> Any:
    """
    Store messages in PostgreSQL.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        Any: None

    Raises:
        Exception: If storing fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    if not messages:
        logger.debug("No messages to store for thread_id=%s", thread_id)
        return

    logger.debug("Storing %d messages for thread_id=%s", len(messages), thread_id)

    try:
        # Ensure thread exists
        await self._ensure_thread_exists(thread_id, user_id, config)

        # Store messages in batch with retry logic
        async def _store_messages():
            async with (await self._get_pg_pool()).acquire() as conn, conn.transaction():
                for message in messages:
                    # content_value = message.content
                    # if not isinstance(content_value, str):
                    #     try:
                    #         content_value = json.dumps(content_value)
                    #     except Exception:
                    #         content_value = str(content_value)
                    await conn.execute(
                        f"""
                            INSERT INTO {self._get_table_name("messages")} (
                                message_id, thread_id, role, content, tool_calls,
                                tool_call_id, reasoning, total_tokens, usages, meta
                            )
                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                            ON CONFLICT (message_id) DO UPDATE SET
                                content = EXCLUDED.content,
                                reasoning = EXCLUDED.reasoning,
                                usages = EXCLUDED.usages,
                                updated_at = NOW()
                            """,  # noqa: S608
                        message.message_id,
                        thread_id,
                        message.role,
                        json.dumps(
                            [block.model_dump(mode="json") for block in message.content]
                        ),
                        json.dumps(message.tools_calls) if message.tools_calls else None,
                        getattr(message, "tool_call_id", None),
                        message.reasoning,
                        message.usages.total_tokens if message.usages else 0,
                        json.dumps(message.usages.model_dump()) if message.usages else None,
                        json.dumps({**(metadata or {}), **(message.metadata or {})}),
                    )

        await self._retry_on_connection_error(_store_messages, max_retries=3)
        logger.debug("Stored %d messages for thread_id=%s", len(messages), thread_id)

    except Exception as e:
        logger.error("Failed to store messages for thread_id=%s: %s", thread_id, e)
        raise
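The `INSERT ... ON CONFLICT ... DO UPDATE` statement above makes `aput_messages` idempotent: re-storing a message with the same `message_id` updates it rather than failing. A minimal sketch of the same upsert semantics, using SQLite in place of PostgreSQL/asyncpg (table and helper names here are hypothetical, not part of PyAgenity):

```python
import json
import sqlite3

# In-memory stand-in for the messages table; SQLite supports the same
# ON CONFLICT ... DO UPDATE clause as PostgreSQL (since SQLite 3.24).
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE messages (message_id TEXT PRIMARY KEY, role TEXT, content TEXT)"
)

def put_message(message_id: str, role: str, content: list[dict]) -> None:
    # Re-inserting the same message_id updates content instead of raising.
    conn.execute(
        """
        INSERT INTO messages (message_id, role, content)
        VALUES (?, ?, ?)
        ON CONFLICT (message_id) DO UPDATE SET content = excluded.content
        """,
        (message_id, role, json.dumps(content)),
    )

put_message("m1", "user", [{"type": "text", "text": "hello"}])
put_message("m1", "user", [{"type": "text", "text": "hello again"}])  # upsert
rows = conn.execute("SELECT COUNT(*) FROM messages").fetchone()
```

After the second call the table still holds a single row whose content reflects the latest write.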
aput_state async
aput_state(config, state)

Store state in PostgreSQL and optionally cache in Redis.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Raises:

Type Description
StorageError

If storing fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aput_state(
    self,
    config: dict[str, Any],
    state: StateT,
) -> StateT:
    """
    Store state in PostgreSQL and optionally cache in Redis.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.

    Raises:
        StorageError: If storing fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    logger.debug("Storing state for thread_id=%s, user_id=%s", thread_id, user_id)
    metrics.counter("pg_checkpointer.save_state.attempts").inc()

    with metrics.timer("pg_checkpointer.save_state.duration"):
        try:
            # Ensure thread exists first
            await self._ensure_thread_exists(thread_id, user_id, config)

            # Store in PostgreSQL with retry logic
            state_json = self._serialize_state_fast(state)

            async def _store_state():
                async with (await self._get_pg_pool()).acquire() as conn:
                    await conn.execute(
                        f"""
                        INSERT INTO {self._get_table_name("states")}
                            (thread_id, state_data, meta)
                        VALUES ($1, $2, $3)
                        ON CONFLICT DO NOTHING
                        """,  # noqa: S608
                        thread_id,
                        state_json,
                        json.dumps(config.get("meta", {})),
                    )

            await self._retry_on_connection_error(_store_state, max_retries=3)
            logger.debug("State stored successfully for thread_id=%s", thread_id)
            metrics.counter("pg_checkpointer.save_state.success").inc()
            return state

        except Exception as e:
            metrics.counter("pg_checkpointer.save_state.error").inc()
            logger.error("Failed to store state for thread_id=%s: %s", thread_id, e)
            if asyncpg and hasattr(asyncpg, "ConnectionDoesNotExistError"):
                connection_errors = (
                    asyncpg.ConnectionDoesNotExistError,
                    asyncpg.InterfaceError,
                )
                if isinstance(e, connection_errors):
                    raise TransientStorageError(f"Connection issue storing state: {e}") from e
            raise StorageError(f"Failed to store state: {e}") from e
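`aput_state` funnels its write through `_retry_on_connection_error`, retrying only on transient connection failures. A hypothetical sketch of such a retry helper (the real implementation may differ in backoff and error classes; `ConnectionError` stands in for asyncpg's connection exceptions):

```python
import asyncio

async def retry_on_connection_error(func, max_retries: int = 3, delay: float = 0.0):
    """Retry an async callable on connection errors, up to max_retries attempts."""
    last_exc: Exception | None = None
    for attempt in range(1, max_retries + 1):
        try:
            return await func()
        except ConnectionError as e:  # stand-in for asyncpg connection errors
            last_exc = e
            if attempt < max_retries:
                await asyncio.sleep(delay)
    raise last_exc

calls = {"n": 0}

async def flaky_store():
    # Fails twice with a transient error, then succeeds.
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("connection reset")
    return "stored"

result = asyncio.run(retry_on_connection_error(flaky_store, max_retries=3))
```

With `max_retries=3` the flaky callable succeeds on its third attempt; a fourth failure would propagate the last exception to the caller.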
aput_state_cache async
aput_state_cache(config, state)

Cache state in Redis with TTL.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Type Description
Any | None

Any | None: True if cached, None if failed.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
    """
    Cache state in Redis with TTL.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        Any | None: True if cached, None if failed.
    """
    """Cache state in Redis with TTL."""
    # No DB access, but keep consistent
    thread_id, user_id = self._validate_config(config)

    logger.debug("Caching state for thread_id=%s, user_id=%s", thread_id, user_id)

    try:
        cache_key = self._get_thread_key(thread_id, user_id)
        state_json = self._serialize_state(state)
        await self.redis.setex(cache_key, self.cache_ttl, state_json)
        logger.debug("State cached with key=%s, ttl=%d", cache_key, self.cache_ttl)
        return True

    except Exception as e:
        logger.error("Failed to cache state for thread_id=%s: %s", thread_id, e)
        # Don't raise - caching is optional
        return None
aput_thread async
aput_thread(config, thread_info)

Create or update thread information.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Type Description
Any | None

Any | None: None

Raises:

Type Description
Exception

If storing fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aput_thread(
    self,
    config: dict[str, Any],
    thread_info: ThreadInfo,
) -> Any | None:
    """
    Create or update thread information.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        Any | None: None

    Raises:
        Exception: If storing fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    logger.debug("Storing thread info for thread_id=%s, user_id=%s", thread_id, user_id)

    try:
        thread_name = thread_info.thread_name or f"Thread {thread_id}"
        meta = thread_info.metadata or {}
        user_id = thread_info.user_id or user_id
        meta.update(
            {
                "run_id": thread_info.run_id,
            }
        )

        async def _put_thread():
            async with (await self._get_pg_pool()).acquire() as conn:
                await conn.execute(
                    f"""
                    INSERT INTO {self._get_table_name("threads")}
                        (thread_id, thread_name, user_id, meta)
                    VALUES ($1, $2, $3, $4)
                    ON CONFLICT (thread_id) DO UPDATE SET
                        thread_name = EXCLUDED.thread_name,
                        meta = EXCLUDED.meta,
                        updated_at = NOW()
                    """,  # noqa: S608
                    thread_id,
                    thread_name,
                    user_id,
                    json.dumps(meta),
                )

        await self._retry_on_connection_error(_put_thread, max_retries=3)
        logger.debug("Thread info stored for thread_id=%s", thread_id)

    except Exception as e:
        logger.error("Failed to store thread info for thread_id=%s: %s", thread_id, e)
        raise
arelease async
arelease()

Clean up connections and resources.

Returns:

Type Description
Any | None

Any | None: None

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def arelease(self) -> Any | None:
    """
    Clean up connections and resources.

    Returns:
        Any | None: None
    """
    """Clean up connections and resources."""
    logger.info("Releasing PgCheckpointer resources")

    if not self.release_resources:
        logger.info("No resources to release")
        return

    errors = []

    # Close Redis connection
    try:
        if hasattr(self.redis, "aclose"):
            await self.redis.aclose()
        elif hasattr(self.redis, "close"):
            await self.redis.close()
        logger.debug("Redis connection closed")
    except Exception as e:
        logger.error("Error closing Redis connection: %s", e)
        errors.append(f"Redis: {e}")

    # Close PostgreSQL pool
    try:
        if self._pg_pool and not self._pg_pool.is_closing():
            await self._pg_pool.close()
        logger.debug("PostgreSQL pool closed")
    except Exception as e:
        logger.error("Error closing PostgreSQL pool: %s", e)
        errors.append(f"PostgreSQL: {e}")

    if errors:
        error_msg = f"Errors during resource cleanup: {'; '.join(errors)}"
        logger.warning(error_msg)
        # Don't raise - cleanup should be best effort
    else:
        logger.info("All resources released successfully")
asetup async
asetup()

Asynchronous setup method. Initializes database schema.

Returns:

Name Type Description
Any Any

True if setup completed.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def asetup(self) -> Any:
    """
    Asynchronous setup method. Initializes database schema.

    Returns:
        Any: True if setup completed.
    """
    """Async setup method - initializes database schema."""
    logger.info(
        "Setting up PgCheckpointer (async)",
        extra={
            "id_type": self.id_type,
            "user_id_type": self.user_id_type,
            "schema": self.schema,
        },
    )
    await self._initialize_schema()
    logger.info("PgCheckpointer setup completed")
    return True
clean_thread
clean_thread(config)

Clean/delete thread synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def clean_thread(self, config: dict[str, Any]) -> Any | None:
    """
    Clean/delete thread synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aclean_thread(config))
clear_state
clear_state(config)

Clear agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def clear_state(self, config: dict[str, Any]) -> Any:
    """
    Clear agent state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any: Implementation-defined result.
    """
    return run_coroutine(self.aclear_state(config))
delete_message
delete_message(config, message_id)

Delete a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def delete_message(self, config: dict[str, Any], message_id: str | int) -> Any | None:
    """
    Delete a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.adelete_message(config, message_id))
get_message
get_message(config, message_id)

Retrieve a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
Message Message

Retrieved message object.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.
    """
    return run_coroutine(self.aget_message(config, message_id))
get_state
get_state(config)

Retrieve agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Retrieved state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.
    """
    return run_coroutine(self.aget_state(config))
get_state_cache
get_state_cache(config)

Retrieve agent state from cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Cached state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state from cache synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Cached state or None.
    """
    return run_coroutine(self.aget_state_cache(config))
get_thread
get_thread(config)

Retrieve thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
ThreadInfo | None

ThreadInfo | None: Thread information object or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_thread(self, config: dict[str, Any]) -> ThreadInfo | None:
    """
    Retrieve thread info synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.
    """
    return run_coroutine(self.aget_thread(config))
list_messages
list_messages(config, search=None, offset=None, limit=None)

List messages synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[Message]

list[Message]: List of message objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
def list_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.
    """
    return run_coroutine(self.alist_messages(config, search, offset, limit))
list_threads
list_threads(config, search=None, offset=None, limit=None)

List threads synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[ThreadInfo]

list[ThreadInfo]: List of thread information objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
def list_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List threads synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.
    """
    return run_coroutine(self.alist_threads(config, search, offset, limit))
put_messages
put_messages(config, messages, metadata=None)

Store messages synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
messages list[Message]

List of messages to store.

required
metadata dict

Additional metadata.

None

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> Any:
    """
    Store messages synchronously.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        Any: Implementation-defined result.
    """
    return run_coroutine(self.aput_messages(config, messages, metadata))
put_state
put_state(config, state)

Store agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_state(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store agent state synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.
    """
    return run_coroutine(self.aput_state(config, state))
put_state_cache
put_state_cache(config, state)

Store agent state in cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
    """
    Store agent state in cache synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aput_state_cache(config, state))
put_thread
put_thread(config, thread_info)

Store thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_thread(self, config: dict[str, Any], thread_info: ThreadInfo) -> Any | None:
    """
    Store thread info synchronously.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aput_thread(config, thread_info))
release
release()

Release resources synchronously.

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def release(self) -> Any | None:
    """
    Release resources synchronously.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.arelease())
setup
setup()

Synchronous setup method for checkpointer.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def setup(self) -> Any:
    """
    Synchronous setup method for checkpointer.

    Returns:
        Any: Implementation-defined setup result.
    """
    return run_coroutine(self.asetup())
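Every sync wrapper above delegates to its async counterpart via `run_coroutine`. A hypothetical sketch of how such a helper can bridge sync callers to async methods (the real `run_coroutine` in PyAgenity may handle the running-loop case differently, e.g. by using a worker thread):

```python
import asyncio
from collections.abc import Coroutine
from typing import Any

def run_coroutine(coro: Coroutine[Any, Any, Any]) -> Any:
    """Drive a coroutine to completion from synchronous code."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread: safe to run the coroutine here.
        return asyncio.run(coro)
    # Inside a running loop, a real implementation would hand the coroutine
    # to a worker thread; raising keeps this sketch simple.
    raise RuntimeError("run_coroutine called from a running event loop")

async def asetup() -> bool:
    return True

result = run_coroutine(asetup())
```

From plain synchronous code this executes the coroutine and returns its result, which is exactly how `setup()` can wrap `asetup()`.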

Modules

base_checkpointer

Classes:

Name Description
BaseCheckpointer

Abstract base class for checkpointing agent state, messages, and threads.

Attributes:

Name Type Description
StateT
logger
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound='AgentState')
logger module-attribute
logger = getLogger(__name__)
Classes
BaseCheckpointer

Bases: ABC

Abstract base class for checkpointing agent state, messages, and threads.

This class defines the contract for all checkpointer implementations, supporting both async and sync methods. Subclasses should implement async methods for optimal performance. Sync methods are provided for compatibility.

Usage
  • Async-first design: subclasses should implement async def methods.
  • If a subclass provides only a sync def, it will be executed in a worker thread automatically using asyncio.run.
  • Callers always use the async APIs (await cp.put_state(...), etc.).

Class Type Parameters:

Name Bound or Constraints Description Default
StateT AgentState

Type of agent state (must inherit from AgentState).

required

Methods:

Name Description
aclean_thread

Clean/delete thread asynchronously.

aclear_state

Clear agent state asynchronously.

adelete_message

Delete a specific message asynchronously.

aget_message

Retrieve a specific message asynchronously.

aget_state

Retrieve agent state asynchronously.

aget_state_cache

Retrieve agent state from cache asynchronously.

aget_thread

Retrieve thread info asynchronously.

alist_messages

List messages asynchronously with optional filtering.

alist_threads

List threads asynchronously with optional filtering.

aput_messages

Store messages asynchronously.

aput_state

Store agent state asynchronously.

aput_state_cache

Store agent state in cache asynchronously.

aput_thread

Store thread info asynchronously.

arelease

Release resources asynchronously.

asetup

Asynchronous setup method for checkpointer.

clean_thread

Clean/delete thread synchronously.

clear_state

Clear agent state synchronously.

delete_message

Delete a specific message synchronously.

get_message

Retrieve a specific message synchronously.

get_state

Retrieve agent state synchronously.

get_state_cache

Retrieve agent state from cache synchronously.

get_thread

Retrieve thread info synchronously.

list_messages

List messages synchronously with optional filtering.

list_threads

List threads synchronously with optional filtering.

put_messages

Store messages synchronously.

put_state

Store agent state synchronously.

put_state_cache

Store agent state in cache synchronously.

put_thread

Store thread info synchronously.

release

Release resources synchronously.

setup

Synchronous setup method for checkpointer.

Source code in pyagenity/checkpointer/base_checkpointer.py
class BaseCheckpointer[StateT: AgentState](ABC):
    """
    Abstract base class for checkpointing agent state, messages, and threads.

    This class defines the contract for all checkpointer implementations, supporting both
    async and sync methods.
    Subclasses should implement async methods for optimal performance.
    Sync methods are provided for compatibility.

    Usage:
        - Async-first design: subclasses should implement `async def` methods.
        - If a subclass provides only a sync `def`, it will be executed in a worker thread
            automatically using `asyncio.run`.
        - Callers always use the async APIs (`await cp.put_state(...)`, etc.).

    Type Args:
        StateT: Type of agent state (must inherit from AgentState).
    """

    ###########################
    #### SETUP ################
    ###########################
    def setup(self) -> Any:
        """
        Synchronous setup method for checkpointer.

        Returns:
            Any: Implementation-defined setup result.
        """
        return run_coroutine(self.asetup())

    @abstractmethod
    async def asetup(self) -> Any:
        """
        Asynchronous setup method for checkpointer.

        Returns:
            Any: Implementation-defined setup result.
        """
        raise NotImplementedError

    # -------------------------
    # State methods Async
    # -------------------------
    @abstractmethod
    async def aput_state(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store agent state asynchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to store.

        Returns:
            StateT: The stored state object.
        """
        raise NotImplementedError

    @abstractmethod
    async def aget_state(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve agent state asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Retrieved state or None.
        """
        raise NotImplementedError

    @abstractmethod
    async def aclear_state(self, config: dict[str, Any]) -> Any:
        """
        Clear agent state asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any: Implementation-defined result.
        """
        raise NotImplementedError

    @abstractmethod
    async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
        """
        Store agent state in cache asynchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to cache.

        Returns:
            Any | None: Implementation-defined result.
        """
        raise NotImplementedError

    @abstractmethod
    async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve agent state from cache asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Cached state or None.
        """
        raise NotImplementedError

    # -------------------------
    # State methods Sync
    # -------------------------
    def put_state(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store agent state synchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to store.

        Returns:
            StateT: The stored state object.
        """
        return run_coroutine(self.aput_state(config, state))

    def get_state(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve agent state synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Retrieved state or None.
        """
        return run_coroutine(self.aget_state(config))

    def clear_state(self, config: dict[str, Any]) -> Any:
        """
        Clear agent state synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any: Implementation-defined result.
        """
        return run_coroutine(self.aclear_state(config))

    def put_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
        """
        Store agent state in cache synchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to cache.

        Returns:
            Any | None: Implementation-defined result.
        """
        return run_coroutine(self.aput_state_cache(config, state))

    def get_state_cache(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve agent state from cache synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Cached state or None.
        """
        return run_coroutine(self.aget_state_cache(config))

    # -------------------------
    # Message methods async
    # -------------------------
    @abstractmethod
    async def aput_messages(
        self,
        config: dict[str, Any],
        messages: list[Message],
        metadata: dict[str, Any] | None = None,
    ) -> Any:
        """
        Store messages asynchronously.

        Args:
            config (dict): Configuration dictionary.
            messages (list[Message]): List of messages to store.
            metadata (dict, optional): Additional metadata.

        Returns:
            Any: Implementation-defined result.
        """
        raise NotImplementedError

    @abstractmethod
    async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
        """
        Retrieve a specific message asynchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Message: Retrieved message object.
        """
        raise NotImplementedError

    @abstractmethod
    async def alist_messages(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[Message]:
        """
        List messages asynchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[Message]: List of message objects.
        """
        raise NotImplementedError

    @abstractmethod
    async def adelete_message(self, config: dict[str, Any], message_id: str | int) -> Any | None:
        """
        Delete a specific message asynchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Any | None: Implementation-defined result.
        """
        raise NotImplementedError

    # -------------------------
    # Message methods sync
    # -------------------------
    def put_messages(
        self,
        config: dict[str, Any],
        messages: list[Message],
        metadata: dict[str, Any] | None = None,
    ) -> Any:
        """
        Store messages synchronously.

        Args:
            config (dict): Configuration dictionary.
            messages (list[Message]): List of messages to store.
            metadata (dict, optional): Additional metadata.

        Returns:
            Any: Implementation-defined result.
        """
        return run_coroutine(self.aput_messages(config, messages, metadata))

    def get_message(self, config: dict[str, Any], message_id: str | int) -> Message:
        """
        Retrieve a specific message synchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Message: Retrieved message object.
        """
        return run_coroutine(self.aget_message(config, message_id))

    def list_messages(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[Message]:
        """
        List messages synchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[Message]: List of message objects.
        """
        return run_coroutine(self.alist_messages(config, search, offset, limit))

    def delete_message(self, config: dict[str, Any], message_id: str | int) -> Any | None:
        """
        Delete a specific message synchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Any | None: Implementation-defined result.
        """
        return run_coroutine(self.adelete_message(config, message_id))

    # -------------------------
    # Thread methods async
    # -------------------------
    @abstractmethod
    async def aput_thread(
        self,
        config: dict[str, Any],
        thread_info: ThreadInfo,
    ) -> Any | None:
        """
        Store thread info asynchronously.

        Args:
            config (dict): Configuration dictionary.
            thread_info (ThreadInfo): Thread information object.

        Returns:
            Any | None: Implementation-defined result.
        """
        raise NotImplementedError

    @abstractmethod
    async def aget_thread(
        self,
        config: dict[str, Any],
    ) -> ThreadInfo | None:
        """
        Retrieve thread info asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            ThreadInfo | None: Thread information object or None.
        """
        raise NotImplementedError

    @abstractmethod
    async def alist_threads(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[ThreadInfo]:
        """
        List threads asynchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[ThreadInfo]: List of thread information objects.
        """
        raise NotImplementedError

    @abstractmethod
    async def aclean_thread(self, config: dict[str, Any]) -> Any | None:
        """
        Clean/delete thread asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any | None: Implementation-defined result.
        """
        raise NotImplementedError

    # -------------------------
    # Thread methods sync
    # -------------------------
    def put_thread(self, config: dict[str, Any], thread_info: ThreadInfo) -> Any | None:
        """
        Store thread info synchronously.

        Args:
            config (dict): Configuration dictionary.
            thread_info (ThreadInfo): Thread information object.

        Returns:
            Any | None: Implementation-defined result.
        """
        return run_coroutine(self.aput_thread(config, thread_info))

    def get_thread(self, config: dict[str, Any]) -> ThreadInfo | None:
        """
        Retrieve thread info synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            ThreadInfo | None: Thread information object or None.
        """
        return run_coroutine(self.aget_thread(config))

    def list_threads(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[ThreadInfo]:
        """
        List threads synchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[ThreadInfo]: List of thread information objects.
        """
        return run_coroutine(self.alist_threads(config, search, offset, limit))

    def clean_thread(self, config: dict[str, Any]) -> Any | None:
        """
        Clean/delete thread synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any | None: Implementation-defined result.
        """
        return run_coroutine(self.aclean_thread(config))

    # -------------------------
    # Clean Resources
    # -------------------------
    def release(self) -> Any | None:
        """
        Release resources synchronously.

        Returns:
            Any | None: Implementation-defined result.
        """
        return run_coroutine(self.arelease())

    @abstractmethod
    async def arelease(self) -> Any | None:
        """
        Release resources asynchronously.

        Returns:
            Any | None: Implementation-defined result.
        """
        raise NotImplementedError
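The async-first design above (every sync method delegates to its `a`-prefixed async counterpart through `run_coroutine`) can be illustrated with a self-contained sketch. This is a hypothetical simplification, not PyAgenity's actual implementation: `run_coroutine` is approximated here with `asyncio.run`, and only a two-method state API is shown.

```python
from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from typing import Any


def run_coroutine(coro):
    """Drive a coroutine to completion from sync code (simplified analog)."""
    return asyncio.run(coro)


class MiniCheckpointer(ABC):
    """Sketch of the async-first pattern: sync wrappers over async methods."""

    @abstractmethod
    async def aput_state(self, config: dict[str, Any], state: Any) -> Any: ...

    @abstractmethod
    async def aget_state(self, config: dict[str, Any]) -> Any | None: ...

    def put_state(self, config: dict[str, Any], state: Any) -> Any:
        # Sync facade: delegate to the async implementation.
        return run_coroutine(self.aput_state(config, state))

    def get_state(self, config: dict[str, Any]) -> Any | None:
        return run_coroutine(self.aget_state(config))


class DictCheckpointer(MiniCheckpointer):
    """Dict-backed implementation keyed by thread_id from the config."""

    def __init__(self) -> None:
        self._states: dict[str, Any] = {}

    async def aput_state(self, config, state):
        self._states[str(config["thread_id"])] = state
        return state

    async def aget_state(self, config):
        return self._states.get(str(config["thread_id"]))


cp = DictCheckpointer()
cp.put_state({"thread_id": "t1"}, {"step": 3})
print(cp.get_state({"thread_id": "t1"}))  # {'step': 3}
```

Subclasses only implement the async methods; the sync facade comes for free from the base class, which is why all the abstract methods above are async.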
Functions
aclean_thread abstractmethod async
aclean_thread(config)

Clean/delete thread asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aclean_thread(self, config: dict[str, Any]) -> Any | None:
    """
    Clean/delete thread asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any | None: Implementation-defined result.
    """
    raise NotImplementedError
aclear_state abstractmethod async
aclear_state(config)

Clear agent state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aclear_state(self, config: dict[str, Any]) -> Any:
    """
    Clear agent state asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any: Implementation-defined result.
    """
    raise NotImplementedError
adelete_message abstractmethod async
adelete_message(config, message_id)

Delete a specific message asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def adelete_message(self, config: dict[str, Any], message_id: str | int) -> Any | None:
    """
    Delete a specific message asynchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Any | None: Implementation-defined result.
    """
    raise NotImplementedError
aget_message abstractmethod async
aget_message(config, message_id)

Retrieve a specific message asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
Message Message

Retrieved message object.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a specific message asynchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.
    """
    raise NotImplementedError
aget_state abstractmethod async
aget_state(config)

Retrieve agent state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Retrieved state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aget_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.
    """
    raise NotImplementedError
aget_state_cache abstractmethod async
aget_state_cache(config)

Retrieve agent state from cache asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Cached state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state from cache asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Cached state or None.
    """
    raise NotImplementedError
aget_thread abstractmethod async
aget_thread(config)

Retrieve thread info asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
ThreadInfo | None

ThreadInfo | None: Thread information object or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aget_thread(
    self,
    config: dict[str, Any],
) -> ThreadInfo | None:
    """
    Retrieve thread info asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.
    """
    raise NotImplementedError
alist_messages abstractmethod async
alist_messages(config, search=None, offset=None, limit=None)

List messages asynchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[Message]

list[Message]: List of message objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def alist_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages asynchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.
    """
    raise NotImplementedError
alist_threads abstractmethod async
alist_threads(config, search=None, offset=None, limit=None)

List threads asynchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[ThreadInfo]

list[ThreadInfo]: List of thread information objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def alist_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List threads asynchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.
    """
    raise NotImplementedError
aput_messages abstractmethod async
aput_messages(config, messages, metadata=None)

Store messages asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
messages list[Message]

List of messages to store.

required
metadata dict

Additional metadata.

None

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aput_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> Any:
    """
    Store messages asynchronously.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        Any: Implementation-defined result.
    """
    raise NotImplementedError
aput_state abstractmethod async
aput_state(config, state)

Store agent state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aput_state(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store agent state asynchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.
    """
    raise NotImplementedError
aput_state_cache abstractmethod async
aput_state_cache(config, state)

Store agent state in cache asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
    """
    Store agent state in cache asynchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        Any | None: Implementation-defined result.
    """
    raise NotImplementedError
aput_thread abstractmethod async
aput_thread(config, thread_info)

Store thread info asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def aput_thread(
    self,
    config: dict[str, Any],
    thread_info: ThreadInfo,
) -> Any | None:
    """
    Store thread info asynchronously.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        Any | None: Implementation-defined result.
    """
    raise NotImplementedError
arelease abstractmethod async
arelease()

Release resources asynchronously.

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def arelease(self) -> Any | None:
    """
    Release resources asynchronously.

    Returns:
        Any | None: Implementation-defined result.
    """
    raise NotImplementedError
asetup abstractmethod async
asetup()

Asynchronous setup method for checkpointer.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/checkpointer/base_checkpointer.py
@abstractmethod
async def asetup(self) -> Any:
    """
    Asynchronous setup method for checkpointer.

    Returns:
        Any: Implementation-defined setup result.
    """
    raise NotImplementedError
clean_thread
clean_thread(config)

Clean/delete thread synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def clean_thread(self, config: dict[str, Any]) -> Any | None:
    """
    Clean/delete thread synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aclean_thread(config))
clear_state
clear_state(config)

Clear agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def clear_state(self, config: dict[str, Any]) -> Any:
    """
    Clear agent state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any: Implementation-defined result.
    """
    return run_coroutine(self.aclear_state(config))
delete_message
delete_message(config, message_id)

Delete a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def delete_message(self, config: dict[str, Any], message_id: str | int) -> Any | None:
    """
    Delete a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.adelete_message(config, message_id))
get_message
get_message(config, message_id)

Retrieve a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
Message Message

Retrieved message object.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.
    """
    return run_coroutine(self.aget_message(config, message_id))
get_state
get_state(config)

Retrieve agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Retrieved state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.
    """
    return run_coroutine(self.aget_state(config))
get_state_cache
get_state_cache(config)

Retrieve agent state from cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Cached state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state from cache synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Cached state or None.
    """
    return run_coroutine(self.aget_state_cache(config))
get_thread
get_thread(config)

Retrieve thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
ThreadInfo | None

ThreadInfo | None: Thread information object or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_thread(self, config: dict[str, Any]) -> ThreadInfo | None:
    """
    Retrieve thread info synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.
    """
    return run_coroutine(self.aget_thread(config))
list_messages
list_messages(config, search=None, offset=None, limit=None)

List messages synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[Message]

list[Message]: List of message objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
def list_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.
    """
    return run_coroutine(self.alist_messages(config, search, offset, limit))
list_threads
list_threads(config, search=None, offset=None, limit=None)

List threads synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[ThreadInfo]

list[ThreadInfo]: List of thread information objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
def list_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List threads synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.
    """
    return run_coroutine(self.alist_threads(config, search, offset, limit))
put_messages
put_messages(config, messages, metadata=None)

Store messages synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
messages list[Message]

List of messages to store.

required
metadata dict

Additional metadata.

None

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> Any:
    """
    Store messages synchronously.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        Any: Implementation-defined result.
    """
    return run_coroutine(self.aput_messages(config, messages, metadata))
put_state
put_state(config, state)

Store agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_state(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store agent state synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.
    """
    return run_coroutine(self.aput_state(config, state))
put_state_cache
put_state_cache(config, state)

Store agent state in cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
    """
    Store agent state in cache synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aput_state_cache(config, state))
put_thread
put_thread(config, thread_info)

Store thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_thread(self, config: dict[str, Any], thread_info: ThreadInfo) -> Any | None:
    """
    Store thread info synchronously.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aput_thread(config, thread_info))
release
release()

Release resources synchronously.

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def release(self) -> Any | None:
    """
    Release resources synchronously.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.arelease())
setup
setup()

Synchronous setup method for checkpointer.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def setup(self) -> Any:
    """
    Synchronous setup method for checkpointer.

    Returns:
        Any: Implementation-defined setup result.
    """
    return run_coroutine(self.asetup())
Functions
in_memory_checkpointer

Classes:

Name Description
InMemoryCheckpointer

In-memory implementation of BaseCheckpointer.

Attributes:

Name Type Description
StateT
logger
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound='AgentState')
logger module-attribute
logger = getLogger(__name__)
Classes
InMemoryCheckpointer

Bases: BaseCheckpointer[StateT]

In-memory implementation of BaseCheckpointer.

Stores all agent state, messages, and thread info in memory using Python dictionaries. Data is lost when the process ends. Designed for testing and ephemeral use cases. Async-first design using asyncio locks for concurrent access.

Attributes:

Name Type Description
_states dict

Stores agent states by thread key.

_state_cache dict

Stores cached agent states by thread key.

_messages dict

Stores messages by thread key.

_message_metadata dict

Stores message metadata by thread key.

_threads dict

Stores thread info by thread key.

_state_lock Lock

Lock for state operations.

_messages_lock Lock

Lock for message operations.

_threads_lock Lock

Lock for thread operations.

Methods:

Name Description
__init__

Initialize all in-memory storage and locks.

aclean_thread

Clean/delete thread asynchronously.

aclear_state

Clear state asynchronously.

adelete_message

Delete a specific message asynchronously.

aget_message

Retrieve a specific message asynchronously.

aget_state

Retrieve state asynchronously.

aget_state_cache

Retrieve state cache asynchronously.

aget_thread

Retrieve thread info asynchronously.

alist_messages

List messages asynchronously with optional filtering.

alist_threads

List all threads asynchronously with optional filtering.

aput_messages

Store messages asynchronously.

aput_state

Store state asynchronously.

aput_state_cache

Store state cache asynchronously.

aput_thread

Store thread info asynchronously.

arelease

Release resources asynchronously.

asetup

Asynchronous setup method. No setup required for in-memory checkpointer.

clean_thread

Clean/delete thread synchronously.

clear_state

Clear state synchronously.

delete_message

Delete a specific message synchronously.

get_message

Retrieve a specific message synchronously.

get_state

Retrieve state synchronously.

get_state_cache

Retrieve state cache synchronously.

get_thread

Retrieve thread info synchronously.

list_messages

List messages synchronously with optional filtering.

list_threads

List all threads synchronously with optional filtering.

put_messages

Store messages synchronously.

put_state

Store state synchronously.

put_state_cache

Store state cache synchronously.

put_thread

Store thread info synchronously.

release

Release resources synchronously.

setup

Synchronous setup method. No setup required for in-memory checkpointer.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
class InMemoryCheckpointer[StateT: AgentState](BaseCheckpointer[StateT]):
    """
    In-memory implementation of BaseCheckpointer.

    Stores all agent state, messages, and thread info in memory using Python dictionaries.
    Data is lost when the process ends. Designed for testing and ephemeral use cases.
    Async-first design using asyncio locks for concurrent access.

    Attributes:
        _states (dict): Stores agent states by thread key.
        _state_cache (dict): Stores cached agent states by thread key.
        _messages (dict): Stores messages by thread key.
        _message_metadata (dict): Stores message metadata by thread key.
        _threads (dict): Stores thread info by thread key.
        _state_lock (asyncio.Lock): Lock for state operations.
        _messages_lock (asyncio.Lock): Lock for message operations.
        _threads_lock (asyncio.Lock): Lock for thread operations.
    """

    def __init__(self):
        """
        Initialize all in-memory storage and locks.
        """
        # State storage
        self._states: dict[str, StateT] = {}
        self._state_cache: dict[str, StateT] = {}

        # Message storage - organized by config key
        self._messages: dict[str, list[Message]] = defaultdict(list)
        self._message_metadata: dict[str, dict[str, Any]] = {}

        # Thread storage
        self._threads: dict[str, dict[str, Any]] = {}

        # Async locks for concurrent access
        self._state_lock = asyncio.Lock()
        self._messages_lock = asyncio.Lock()
        self._threads_lock = asyncio.Lock()

    def setup(self) -> Any:
        """
        Synchronous setup method. No setup required for in-memory checkpointer.
        """
        logger.debug("InMemoryCheckpointer setup not required")

    async def asetup(self) -> Any:
        """
        Asynchronous setup method. No setup required for in-memory checkpointer.
        """
        logger.debug("InMemoryCheckpointer async setup not required")

    def _get_config_key(self, config: dict[str, Any]) -> str:
        """
        Generate a string key from config dict for storage indexing.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            str: Key for indexing storage.
        """
        """Generate a string key from config dict for storage indexing."""
        # Sort keys for consistent hashing
        thread_id = config.get("thread_id", "")
        return str(thread_id)

    # -------------------------
    # State methods Async
    # -------------------------
    async def aput_state(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store state asynchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to store.

        Returns:
            StateT: The stored state object.
        """
        """Store state asynchronously."""
        key = self._get_config_key(config)
        async with self._state_lock:
            self._states[key] = state
            logger.debug(f"Stored state for key: {key}")
            return state

    async def aget_state(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve state asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Retrieved state or None.
        """
        """Retrieve state asynchronously."""
        key = self._get_config_key(config)
        async with self._state_lock:
            state = self._states.get(key)
            logger.debug(f"Retrieved state for key: {key}, found: {state is not None}")
            return state

    async def aclear_state(self, config: dict[str, Any]) -> bool:
        """
        Clear state asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            bool: True if cleared.
        """
        """Clear state asynchronously."""
        key = self._get_config_key(config)
        async with self._state_lock:
            if key in self._states:
                del self._states[key]
                logger.debug(f"Cleared state for key: {key}")
            return True

    async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store state cache asynchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to cache.

        Returns:
            StateT: The cached state object.
        """
        """Store state cache asynchronously."""
        key = self._get_config_key(config)
        async with self._state_lock:
            self._state_cache[key] = state
            logger.debug(f"Stored state cache for key: {key}")
            return state

    async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve state cache asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Cached state or None.
        """
        """Retrieve state cache asynchronously."""
        key = self._get_config_key(config)
        async with self._state_lock:
            cache = self._state_cache.get(key)
            logger.debug(f"Retrieved state cache for key: {key}, found: {cache is not None}")
            return cache

    # -------------------------
    # State methods Sync
    # -------------------------
    def put_state(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store state synchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to store.

        Returns:
            StateT: The stored state object.
        """
        """Store state synchronously."""
        key = self._get_config_key(config)
        # For sync methods, we'll use a simple approach without locks
        # In a real async-first system, sync methods might not be used
        self._states[key] = state
        logger.debug(f"Stored state for key: {key}")
        return state

    def get_state(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve state synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Retrieved state or None.
        """
        """Retrieve state synchronously."""
        key = self._get_config_key(config)
        state = self._states.get(key)
        logger.debug(f"Retrieved state for key: {key}, found: {state is not None}")
        return state

    def clear_state(self, config: dict[str, Any]) -> bool:
        """
        Clear state synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            bool: True if cleared.
        """
        """Clear state synchronously."""
        key = self._get_config_key(config)
        if key in self._states:
            del self._states[key]
            logger.debug(f"Cleared state for key: {key}")
        return True

    def put_state_cache(self, config: dict[str, Any], state: StateT) -> StateT:
        """
        Store state cache synchronously.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to cache.

        Returns:
            StateT: The cached state object.
        """
        """Store state cache synchronously."""
        key = self._get_config_key(config)
        self._state_cache[key] = state
        logger.debug(f"Stored state cache for key: {key}")
        return state

    def get_state_cache(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve state cache synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Cached state or None.
        """
        """Retrieve state cache synchronously."""
        key = self._get_config_key(config)
        cache = self._state_cache.get(key)
        logger.debug(f"Retrieved state cache for key: {key}, found: {cache is not None}")
        return cache

    # -------------------------
    # Message methods async
    # -------------------------
    async def aput_messages(
        self,
        config: dict[str, Any],
        messages: list[Message],
        metadata: dict[str, Any] | None = None,
    ) -> bool:
        """
        Store messages asynchronously.

        Args:
            config (dict): Configuration dictionary.
            messages (list[Message]): List of messages to store.
            metadata (dict, optional): Additional metadata.

        Returns:
            bool: True if stored.
        """
        key = self._get_config_key(config)
        async with self._messages_lock:
            self._messages[key].extend(messages)
            if metadata:
                self._message_metadata[key] = metadata
            logger.debug(f"Stored {len(messages)} messages for key: {key}")
            return True

    async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
        """
        Retrieve a specific message asynchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Message: Retrieved message object.

        Raises:
            IndexError: If message not found.
        """
        """Retrieve a specific message asynchronously."""
        key = self._get_config_key(config)
        async with self._messages_lock:
            messages = self._messages.get(key, [])
            for msg in messages:
                if msg.message_id == message_id:
                    return msg
            raise IndexError(f"Message with ID {message_id} not found for config key: {key}")

    async def alist_messages(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[Message]:
        """
        List messages asynchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[Message]: List of message objects.
        """
        key = self._get_config_key(config)
        async with self._messages_lock:
            messages = self._messages.get(key, [])

            # Apply search filter if provided
            if search:
                # Simple string search in message content
                messages = [
                    msg
                    for msg in messages
                    if hasattr(msg, "content") and search.lower() in str(msg.content).lower()
                ]

            # Apply offset and limit
            start = offset or 0
            end = (start + limit) if limit else None
            return messages[start:end]

    async def adelete_message(self, config: dict[str, Any], message_id: str | int) -> bool:
        """
        Delete a specific message asynchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            bool: True if deleted.

        Raises:
            IndexError: If message not found.
        """
        """Delete a specific message asynchronously."""
        key = self._get_config_key(config)
        async with self._messages_lock:
            messages = self._messages.get(key, [])
            for msg in messages:
                if msg.message_id == message_id:
                    messages.remove(msg)
                    logger.debug(f"Deleted message with ID {message_id} for key: {key}")
                    return True
            raise IndexError(f"Message with ID {message_id} not found for config key: {key}")

    # -------------------------
    # Message methods sync
    # -------------------------
    def put_messages(
        self,
        config: dict[str, Any],
        messages: list[Message],
        metadata: dict[str, Any] | None = None,
    ) -> bool:
        """
        Store messages synchronously.

        Args:
            config (dict): Configuration dictionary.
            messages (list[Message]): List of messages to store.
            metadata (dict, optional): Additional metadata.

        Returns:
            bool: True if stored.
        """
        key = self._get_config_key(config)
        self._messages[key].extend(messages)
        if metadata:
            self._message_metadata[key] = metadata

        logger.debug(f"Stored {len(messages)} messages for key: {key}")
        return True

    def get_message(self, config: dict[str, Any], message_id: str | int) -> Message:
        """
        Retrieve a specific message synchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Message: Retrieved message object.

        Raises:
            IndexError: If message not found.
        """
        key = self._get_config_key(config)
        messages = self._messages.get(key, [])
        for msg in messages:
            if msg.message_id == message_id:
                return msg
        raise IndexError(f"Message with ID {message_id} not found for config key: {key}")

    def list_messages(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[Message]:
        """
        List messages synchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[Message]: List of message objects.
        """
        key = self._get_config_key(config)
        messages = self._messages.get(key, [])

        # Apply search filter if provided
        if search:
            messages = [
                msg
                for msg in messages
                if hasattr(msg, "content") and search.lower() in str(msg.content).lower()
            ]

        # Apply offset and limit
        start = offset or 0
        end = (start + limit) if limit else None
        return messages[start:end]

    def delete_message(self, config: dict[str, Any], message_id: str | int) -> bool:
        """
        Delete a specific message synchronously.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            bool: True if deleted.

        Raises:
            IndexError: If message not found.
        """
        """Delete a specific message synchronously."""
        key = self._get_config_key(config)
        messages = self._messages.get(key, [])
        for msg in messages:
            if msg.message_id == message_id:
                messages.remove(msg)
                logger.debug(f"Deleted message with ID {message_id} for key: {key}")
                return True
        raise IndexError(f"Message with ID {message_id} not found for config key: {key}")

    # -------------------------
    # Thread methods async
    # -------------------------
    async def aput_thread(
        self,
        config: dict[str, Any],
        thread_info: ThreadInfo,
    ) -> bool:
        """
        Store thread info asynchronously.

        Args:
            config (dict): Configuration dictionary.
            thread_info (ThreadInfo): Thread information object.

        Returns:
            bool: True if stored.
        """
        key = self._get_config_key(config)
        async with self._threads_lock:
            self._threads[key] = thread_info.model_dump()
            logger.debug(f"Stored thread info for key: {key}")
            return True

    async def aget_thread(
        self,
        config: dict[str, Any],
    ) -> ThreadInfo | None:
        """
        Retrieve thread info asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            ThreadInfo | None: Thread information object or None.
        """
        key = self._get_config_key(config)
        async with self._threads_lock:
            thread = self._threads.get(key)
            logger.debug(f"Retrieved thread for key: {key}, found: {thread is not None}")
            return ThreadInfo.model_validate(thread) if thread else None

    async def alist_threads(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[ThreadInfo]:
        """
        List all threads asynchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[ThreadInfo]: List of thread information objects.
        """
        async with self._threads_lock:
            threads = list(self._threads.values())

            # Apply search filter if provided
            if search:
                threads = [
                    thread
                    for thread in threads
                    if any(search.lower() in str(value).lower() for value in thread.values())
                ]

            # Apply offset and limit
            start = offset or 0
            end = (start + limit) if limit else None
            return [ThreadInfo.model_validate(thread) for thread in threads[start:end]]

    async def aclean_thread(self, config: dict[str, Any]) -> bool:
        """
        Clean/delete thread asynchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            bool: True if cleaned.
        """
        """Clean/delete thread asynchronously."""
        key = self._get_config_key(config)
        async with self._threads_lock:
            if key in self._threads:
                del self._threads[key]
                logger.debug(f"Cleaned thread for key: {key}")
                return True
        return False

    # -------------------------
    # Thread methods sync
    # -------------------------
    def put_thread(self, config: dict[str, Any], thread_info: ThreadInfo) -> bool:
        """
        Store thread info synchronously.

        Args:
            config (dict): Configuration dictionary.
            thread_info (ThreadInfo): Thread information object.

        Returns:
            bool: True if stored.
        """
        """Store thread info synchronously."""
        key = self._get_config_key(config)
        self._threads[key] = thread_info.model_dump()
        logger.debug(f"Stored thread info for key: {key}")
        return True

    def get_thread(self, config: dict[str, Any]) -> ThreadInfo | None:
        """
        Retrieve thread info synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            ThreadInfo | None: Thread information object or None.
        """
        """Retrieve thread info synchronously."""
        key = self._get_config_key(config)
        thread = self._threads.get(key)
        logger.debug(f"Retrieved thread for key: {key}, found: {thread is not None}")
        return ThreadInfo.model_validate(thread) if thread else None

    def list_threads(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[ThreadInfo]:
        """
        List all threads synchronously with optional filtering.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[ThreadInfo]: List of thread information objects.
        """
        threads = list(self._threads.values())

        # Apply search filter if provided
        if search:
            threads = [
                thread
                for thread in threads
                if any(search.lower() in str(value).lower() for value in thread.values())
            ]

        # Apply offset and limit
        start = offset or 0
        end = (start + limit) if limit else None
        return [ThreadInfo.model_validate(thread) for thread in threads[start:end]]

    def clean_thread(self, config: dict[str, Any]) -> bool:
        """
        Clean/delete thread synchronously.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            bool: True if cleaned.
        """
        """Clean/delete thread synchronously."""
        key = self._get_config_key(config)
        if key in self._threads:
            del self._threads[key]
            logger.debug(f"Cleaned thread for key: {key}")
            return True
        return False

    # -------------------------
    # Clean Resources
    # -------------------------
    async def arelease(self) -> bool:
        """
        Release resources asynchronously.

        Returns:
            bool: True if released.
        """
        """Release resources asynchronously."""
        async with self._state_lock, self._messages_lock, self._threads_lock:
            self._states.clear()
            self._state_cache.clear()
            self._messages.clear()
            self._message_metadata.clear()
            self._threads.clear()
            logger.info("Released all in-memory resources")
            return True

    def release(self) -> bool:
        """
        Release resources synchronously.

        Returns:
            bool: True if released.
        """
        """Release resources synchronously."""
        self._states.clear()
        self._state_cache.clear()
        self._messages.clear()
        self._message_metadata.clear()
        self._threads.clear()
        logger.info("Released all in-memory resources")
        return True
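The `list_messages` and `list_threads` methods share one search/pagination idiom. Extracted and simplified into a standalone sketch (`paginate` is a hypothetical helper name):

```python
def paginate(items, search=None, offset=None, limit=None):
    """Mirror of the search/offset/limit handling in list_messages/list_threads."""
    if search:
        # Case-insensitive substring match against the string form of each item.
        items = [it for it in items if search.lower() in str(it).lower()]
    start = offset or 0
    end = (start + limit) if limit else None  # None slices to the end of the list
    return items[start:end]


msgs = ["hello world", "goodbye", "hello again", "farewell"]
print(paginate(msgs, search="hello"))     # ['hello world', 'hello again']
print(paginate(msgs, offset=1, limit=2))  # ['goodbye', 'hello again']
```

One consequence of the falsy check worth noting: `limit=0` is treated the same as "no limit", so callers cannot request an empty page.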
Functions
__init__
__init__()

Initialize all in-memory storage and locks.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def __init__(self):
    """
    Initialize all in-memory storage and locks.
    """
    # State storage
    self._states: dict[str, StateT] = {}
    self._state_cache: dict[str, StateT] = {}

    # Message storage - organized by config key
    self._messages: dict[str, list[Message]] = defaultdict(list)
    self._message_metadata: dict[str, dict[str, Any]] = {}

    # Thread storage
    self._threads: dict[str, dict[str, Any]] = {}

    # Async locks for concurrent access
    self._state_lock = asyncio.Lock()
    self._messages_lock = asyncio.Lock()
    self._threads_lock = asyncio.Lock()
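Because `_messages` is a `defaultdict(list)`, the message-storing methods can extend a per-thread list without first checking whether the key exists. A small illustration — `put_messages` here is a simplified free function, not the real method:

```python
from collections import defaultdict

messages: dict[str, list[str]] = defaultdict(list)


def put_messages(thread_id: str, new: list[str]) -> None:
    # No existence check needed: a missing key yields a fresh empty list.
    messages[thread_id].extend(new)


put_messages("t1", ["hi"])
put_messages("t1", ["there"])
print(messages["t1"])  # ['hi', 'there']
```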
aclean_thread async
aclean_thread(config)

Clean/delete thread asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
bool bool

True if cleaned.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aclean_thread(self, config: dict[str, Any]) -> bool:
    """
    Clean/delete thread asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        bool: True if cleaned.
    """
    """Clean/delete thread asynchronously."""
    key = self._get_config_key(config)
    async with self._threads_lock:
        if key in self._threads:
            del self._threads[key]
            logger.debug(f"Cleaned thread for key: {key}")
            return True
    return False
aclear_state async
aclear_state(config)

Clear state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
bool bool

True if cleared.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aclear_state(self, config: dict[str, Any]) -> bool:
    """
    Clear state asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        bool: True if cleared.
    """
    """Clear state asynchronously."""
    key = self._get_config_key(config)
    async with self._state_lock:
        if key in self._states:
            del self._states[key]
            logger.debug(f"Cleared state for key: {key}")
        return True
adelete_message async
adelete_message(config, message_id)

Delete a specific message asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
bool bool

True if deleted.

Raises:

Type Description
IndexError

If message not found.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def adelete_message(self, config: dict[str, Any], message_id: str | int) -> bool:
    """
    Delete a specific message asynchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        bool: True if deleted.

    Raises:
        IndexError: If message not found.
    """
    """Delete a specific message asynchronously."""
    key = self._get_config_key(config)
    async with self._messages_lock:
        messages = self._messages.get(key, [])
        for msg in messages:
            if msg.message_id == message_id:
                messages.remove(msg)
                logger.debug(f"Deleted message with ID {message_id} for key: {key}")
                return True
        raise IndexError(f"Message with ID {message_id} not found for config key: {key}")
aget_message async
aget_message(config, message_id)

Retrieve a specific message asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
Message Message

Retrieved message object.

Raises:

Type Description
IndexError

If message not found.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a specific message asynchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.

    Raises:
        IndexError: If message not found.
    """
    """Retrieve a specific message asynchronously."""
    key = self._get_config_key(config)
    async with self._messages_lock:
        messages = self._messages.get(key, [])
        for msg in messages:
            if msg.message_id == message_id:
                return msg
        raise IndexError(f"Message with ID {message_id} not found for config key: {key}")
aget_state async
aget_state(config)

Retrieve state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Retrieved state or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aget_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve state asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.
    """
    """Retrieve state asynchronously."""
    key = self._get_config_key(config)
    async with self._state_lock:
        state = self._states.get(key)
        logger.debug(f"Retrieved state for key: {key}, found: {state is not None}")
        return state
aget_state_cache async
aget_state_cache(config)

Retrieve state cache asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Cached state or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve state cache asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Cached state or None.
    """
    """Retrieve state cache asynchronously."""
    key = self._get_config_key(config)
    async with self._state_lock:
        cache = self._state_cache.get(key)
        logger.debug(f"Retrieved state cache for key: {key}, found: {cache is not None}")
        return cache
aget_thread async
aget_thread(config)

Retrieve thread info asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
ThreadInfo | None

ThreadInfo | None: Thread information object or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aget_thread(
    self,
    config: dict[str, Any],
) -> ThreadInfo | None:
    """
    Retrieve thread info asynchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.
    """
    key = self._get_config_key(config)
    async with self._threads_lock:
        thread = self._threads.get(key)
        logger.debug(f"Retrieved thread for key: {key}, found: {thread is not None}")
        return ThreadInfo.model_validate(thread) if thread else None
alist_messages async
alist_messages(config, search=None, offset=None, limit=None)

List messages asynchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[Message]

list[Message]: List of message objects.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def alist_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages asynchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.
    """
    key = self._get_config_key(config)
    async with self._messages_lock:
        messages = self._messages.get(key, [])

        # Apply search filter if provided
        if search:
            # Simple string search in message content
            messages = [
                msg
                for msg in messages
                if hasattr(msg, "content") and search.lower() in str(msg.content).lower()
            ]

        # Apply offset and limit
        start = offset or 0
        end = (start + limit) if limit else None
        return messages[start:end]
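The search and pagination behavior above can be sketched independently of PyAgenity (the helper and dict-based messages below are illustrative, not library APIs): search is a case-insensitive substring match on the stringified content, and `offset`/`limit` reduce to a list slice. Note that because `limit` is tested for truthiness, `limit=0` behaves like "no limit".

```python
def filter_messages(messages, search=None, offset=None, limit=None):
    """Mirror of the search + slice logic shown in alist_messages/list_messages."""
    if search:
        # Case-insensitive substring match against the message content
        messages = [m for m in messages if search.lower() in str(m["content"]).lower()]
    start = offset or 0
    end = (start + limit) if limit else None  # limit=0 falls through to "no limit"
    return messages[start:end]

msgs = [{"id": i, "content": f"note {i}"} for i in range(5)]
```

The offset is applied after the search filter, so page boundaries refer to positions within the filtered result, not the raw message list.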
alist_threads async
alist_threads(config, search=None, offset=None, limit=None)

List all threads asynchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[ThreadInfo]

list[ThreadInfo]: List of thread information objects.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def alist_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List all threads asynchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.
    """
    async with self._threads_lock:
        threads = list(self._threads.values())

        # Apply search filter if provided
        if search:
            threads = [
                thread
                for thread in threads
                if any(search.lower() in str(value).lower() for value in thread.values())
            ]

        # Apply offset and limit
        start = offset or 0
        end = (start + limit) if limit else None
        return [ThreadInfo.model_validate(thread) for thread in threads[start:end]]
aput_messages async
aput_messages(config, messages, metadata=None)

Store messages asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
messages list[Message]

List of messages to store.

required
metadata dict

Additional metadata.

None

Returns:

Name Type Description
bool bool

True if stored.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aput_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> bool:
    """
    Store messages asynchronously.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        bool: True if stored.
    """
    key = self._get_config_key(config)
    async with self._messages_lock:
        self._messages[key].extend(messages)
        if metadata:
            self._message_metadata[key] = metadata
        logger.debug(f"Stored {len(messages)} messages for key: {key}")
        return True
aput_state async
aput_state(config, state)

Store state asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aput_state(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store state asynchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.
    """
    """Store state asynchronously."""
    key = self._get_config_key(config)
    async with self._state_lock:
        self._states[key] = state
        logger.debug(f"Stored state for key: {key}")
        return state
aput_state_cache async
aput_state_cache(config, state)

Store state cache asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Name Type Description
StateT StateT

The cached state object.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store state cache asynchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        StateT: The cached state object.
    """
    """Store state cache asynchronously."""
    key = self._get_config_key(config)
    async with self._state_lock:
        self._state_cache[key] = state
        logger.debug(f"Stored state cache for key: {key}")
        return state
aput_thread async
aput_thread(config, thread_info)

Store thread info asynchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Name Type Description
bool bool

True if stored.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def aput_thread(
    self,
    config: dict[str, Any],
    thread_info: ThreadInfo,
) -> bool:
    """
    Store thread info asynchronously.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        bool: True if stored.
    """
    key = self._get_config_key(config)
    async with self._threads_lock:
        self._threads[key] = thread_info.model_dump()
        logger.debug(f"Stored thread info for key: {key}")
        return True
arelease async
arelease()

Release resources asynchronously.

Returns:

Name Type Description
bool bool

True if released.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def arelease(self) -> bool:
    """
    Release resources asynchronously.

    Returns:
        bool: True if released.
    """
    """Release resources asynchronously."""
    async with self._state_lock, self._messages_lock, self._threads_lock:
        self._states.clear()
        self._state_cache.clear()
        self._messages.clear()
        self._message_metadata.clear()
        self._threads.clear()
        logger.info("Released all in-memory resources")
        return True
asetup async
asetup()

Asynchronous setup method. No setup required for in-memory checkpointer.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
async def asetup(self) -> Any:
    """
    Asynchronous setup method. No setup required for in-memory checkpointer.
    """
    logger.debug("InMemoryCheckpointer async setup not required")
clean_thread
clean_thread(config)

Clean/delete thread synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
bool bool

True if cleaned.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def clean_thread(self, config: dict[str, Any]) -> bool:
    """
    Clean/delete thread synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        bool: True if cleaned.
    """
    """Clean/delete thread synchronously."""
    key = self._get_config_key(config)
    if key in self._threads:
        del self._threads[key]
        logger.debug(f"Cleaned thread for key: {key}")
        return True
    return False
clear_state
clear_state(config)

Clear state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
bool bool

True if cleared.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def clear_state(self, config: dict[str, Any]) -> bool:
    """
    Clear state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        bool: True if cleared.
    """
    """Clear state synchronously."""
    key = self._get_config_key(config)
    if key in self._states:
        del self._states[key]
        logger.debug(f"Cleared state for key: {key}")
    return True
delete_message
delete_message(config, message_id)

Delete a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
bool bool

True if deleted.

Raises:

Type Description
IndexError

If message not found.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def delete_message(self, config: dict[str, Any], message_id: str | int) -> bool:
    """
    Delete a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        bool: True if deleted.

    Raises:
        IndexError: If message not found.
    """
    """Delete a specific message synchronously."""
    key = self._get_config_key(config)
    messages = self._messages.get(key, [])
    for msg in messages:
        if msg.message_id == message_id:
            messages.remove(msg)
            logger.debug(f"Deleted message with ID {message_id} for key: {key}")
            return True
    raise IndexError(f"Message with ID {message_id} not found for config key: {key}")
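The success/failure contract above (return `True` on deletion, raise `IndexError` otherwise) can be exercised with a small stand-in message type (the `FakeMessage` class and free-standing helper below are hypothetical, not the PyAgenity `Message` class):

```python
from dataclasses import dataclass

@dataclass
class FakeMessage:
    message_id: int
    content: str

def delete_message(messages, message_id):
    """Mirror of delete_message: remove the first match in place or raise IndexError."""
    for msg in messages:
        if msg.message_id == message_id:
            messages.remove(msg)
            return True
    raise IndexError(f"Message with ID {message_id} not found")

inbox = [FakeMessage(1, "hi"), FakeMessage(2, "bye")]
```

Because the list is mutated in place, a successful delete is visible to every holder of the same list, which matches the shared `self._messages[key]` list in the checkpointer.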
get_message
get_message(config, message_id)

Retrieve a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
Message Message

Retrieved message object.

Raises:

Type Description
IndexError

If message not found.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def get_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.

    Raises:
        IndexError: If message not found.
    """
    key = self._get_config_key(config)
    messages = self._messages.get(key, [])
    for msg in messages:
        if msg.message_id == message_id:
            return msg
    raise IndexError(f"Message with ID {message_id} not found for config key: {key}")
get_state
get_state(config)

Retrieve state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Retrieved state or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def get_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.
    """
    """Retrieve state synchronously."""
    key = self._get_config_key(config)
    state = self._states.get(key)
    logger.debug(f"Retrieved state for key: {key}, found: {state is not None}")
    return state
get_state_cache
get_state_cache(config)

Retrieve state cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Cached state or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def get_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve state cache synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Cached state or None.
    """
    """Retrieve state cache synchronously."""
    key = self._get_config_key(config)
    cache = self._state_cache.get(key)
    logger.debug(f"Retrieved state cache for key: {key}, found: {cache is not None}")
    return cache
get_thread
get_thread(config)

Retrieve thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
ThreadInfo | None

ThreadInfo | None: Thread information object or None.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def get_thread(self, config: dict[str, Any]) -> ThreadInfo | None:
    """
    Retrieve thread info synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.
    """
    """Retrieve thread info synchronously."""
    key = self._get_config_key(config)
    thread = self._threads.get(key)
    logger.debug(f"Retrieved thread for key: {key}, found: {thread is not None}")
    return ThreadInfo.model_validate(thread) if thread else None
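`put_thread` stores `thread_info.model_dump()` (a plain dict) and `get_thread` rebuilds the object with `ThreadInfo.model_validate`. That serialize-on-write, validate-on-read round-trip can be sketched with a dataclass stand-in for the pydantic model (`FakeThreadInfo` and the module-level store are illustrative):

```python
from dataclasses import dataclass, asdict

@dataclass
class FakeThreadInfo:
    thread_id: str
    user_id: str

_threads = {}

def put_thread(key, thread_info):
    # Store a plain dict, mirroring thread_info.model_dump()
    _threads[key] = asdict(thread_info)
    return True

def get_thread(key):
    # Rebuild the object on the way out, mirroring ThreadInfo.model_validate(...)
    raw = _threads.get(key)
    return FakeThreadInfo(**raw) if raw else None
```

Storing plain dicts keeps the backing store serialization-agnostic; only the read path needs to know the model class.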
list_messages
list_messages(config, search=None, offset=None, limit=None)

List messages synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[Message]

list[Message]: List of message objects.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def list_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.
    """
    key = self._get_config_key(config)
    messages = self._messages.get(key, [])

    # Apply search filter if provided
    if search:
        messages = [
            msg
            for msg in messages
            if hasattr(msg, "content") and search.lower() in str(msg.content).lower()
        ]

    # Apply offset and limit
    start = offset or 0
    end = (start + limit) if limit else None
    return messages[start:end]
list_threads
list_threads(config, search=None, offset=None, limit=None)

List all threads synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[ThreadInfo]

list[ThreadInfo]: List of thread information objects.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def list_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List all threads synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.
    """
    threads = list(self._threads.values())

    # Apply search filter if provided
    if search:
        threads = [
            thread
            for thread in threads
            if any(search.lower() in str(value).lower() for value in thread.values())
        ]

    # Apply offset and limit
    start = offset or 0
    end = (start + limit) if limit else None
    return [ThreadInfo.model_validate(thread) for thread in threads[start:end]]
put_messages
put_messages(config, messages, metadata=None)

Store messages synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
messages list[Message]

List of messages to store.

required
metadata dict

Additional metadata.

None

Returns:

Name Type Description
bool bool

True if stored.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def put_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> bool:
    """
    Store messages synchronously.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        bool: True if stored.
    """
    key = self._get_config_key(config)
    self._messages[key].extend(messages)
    if metadata:
        self._message_metadata[key] = metadata

    logger.debug(f"Stored {len(messages)} messages for key: {key}")
    return True
put_state
put_state(config, state)

Store state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def put_state(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store state synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.
    """
    """Store state synchronously."""
    key = self._get_config_key(config)
    # For sync methods, we'll use a simple approach without locks
    # In a real async-first system, sync methods might not be used
    self._states[key] = state
    logger.debug(f"Stored state for key: {key}")
    return state
put_state_cache
put_state_cache(config, state)

Store state cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Name Type Description
StateT StateT

The cached state object.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def put_state_cache(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store state cache synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        StateT: The cached state object.
    """
    """Store state cache synchronously."""
    key = self._get_config_key(config)
    self._state_cache[key] = state
    logger.debug(f"Stored state cache for key: {key}")
    return state
put_thread
put_thread(config, thread_info)

Store thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Name Type Description
bool bool

True if stored.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def put_thread(self, config: dict[str, Any], thread_info: ThreadInfo) -> bool:
    """
    Store thread info synchronously.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        bool: True if stored.
    """
    """Store thread info synchronously."""
    key = self._get_config_key(config)
    self._threads[key] = thread_info.model_dump()
    logger.debug(f"Stored thread info for key: {key}")
    return True
release
release()

Release resources synchronously.

Returns:

Name Type Description
bool bool

True if released.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def release(self) -> bool:
    """
    Release resources synchronously.

    Returns:
        bool: True if released.
    """
    """Release resources synchronously."""
    self._states.clear()
    self._state_cache.clear()
    self._messages.clear()
    self._message_metadata.clear()
    self._threads.clear()
    logger.info("Released all in-memory resources")
    return True
setup
setup()

Synchronous setup method. No setup required for in-memory checkpointer.

Source code in pyagenity/checkpointer/in_memory_checkpointer.py
def setup(self) -> Any:
    """
    Synchronous setup method. No setup required for in-memory checkpointer.
    """
    logger.debug("InMemoryCheckpointer setup not required")
pg_checkpointer

Classes:

Name Description
PgCheckpointer

Implements a checkpointer using PostgreSQL and Redis for persistent and cached state management.

Attributes:

Name Type Description
DEFAULT_CACHE_TTL
HAS_ASYNCPG
HAS_REDIS
ID_TYPE_MAP
StateT
logger
Attributes
DEFAULT_CACHE_TTL module-attribute
DEFAULT_CACHE_TTL = 86400
HAS_ASYNCPG module-attribute
HAS_ASYNCPG = True
HAS_REDIS module-attribute
HAS_REDIS = True
ID_TYPE_MAP module-attribute
ID_TYPE_MAP = {'string': 'VARCHAR(255)', 'int': 'SERIAL', 'bigint': 'BIGSERIAL'}
StateT module-attribute
StateT = TypeVar('StateT', bound='AgentState')
logger module-attribute
logger = getLogger(__name__)
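ID_TYPE_MAP translates the configurable ID type into a PostgreSQL column type, so the lookup reduces to a dict access. A hedged sketch of how such a mapping might be applied when building a DDL fragment (the `id_column_sql` helper is illustrative, not PgCheckpointer's actual schema code):

```python
ID_TYPE_MAP = {"string": "VARCHAR(255)", "int": "SERIAL", "bigint": "BIGSERIAL"}

def id_column_sql(name, id_type):
    """Render a column definition for the configured ID type."""
    try:
        return f"{name} {ID_TYPE_MAP[id_type]}"
    except KeyError:
        raise ValueError(f"Unsupported id type: {id_type!r}") from None
```

Raising `ValueError` for an unknown type matches the class's documented behavior of rejecting invalid configuration up front rather than emitting broken SQL.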
Classes
PgCheckpointer

Bases: BaseCheckpointer[StateT]

Implements a checkpointer using PostgreSQL and Redis for persistent and cached state management.

This class provides asynchronous and synchronous methods for storing, retrieving, and managing agent states, messages, and threads. PostgreSQL is used for durable storage, while Redis provides fast caching with TTL.

Features
  • Async-first design with sync fallbacks
  • Configurable ID types (string, int, bigint)
  • Connection pooling for both PostgreSQL and Redis
  • Proper error handling and resource management
  • Schema migration support

Parameters:

Name Type Description Default
postgres_dsn str

PostgreSQL connection string.

None
pg_pool Any

Existing asyncpg Pool instance.

None
pool_config dict

Configuration for new pg pool creation.

None
redis_url str

Redis connection URL.

None
redis Any

Existing Redis instance.

None
redis_pool Any

Existing Redis ConnectionPool.

None
redis_pool_config dict

Configuration for new redis pool creation.

None
**kwargs

Additional configuration options:
  • user_id_type: Type for user_id fields ('string', 'int', 'bigint')
  • cache_ttl: Redis cache TTL in seconds
  • release_resources: Whether to release resources on cleanup

{}

Raises:

Type Description
ImportError

If required dependencies are missing.

ValueError

If required connection details are missing.

Methods:

Name Description
__init__

Initializes PgCheckpointer with PostgreSQL and Redis connections.

aclean_thread

Clean/delete a thread and all associated data.

aclear_state

Clear state from PostgreSQL and Redis cache.

adelete_message

Delete a message by ID.

aget_message

Retrieve a single message by ID.

aget_state

Retrieve state from PostgreSQL.

aget_state_cache

Get state from Redis cache, fallback to PostgreSQL if miss.

aget_thread

Get thread information.

alist_messages

List messages for a thread with optional search and pagination.

alist_threads

List threads for a user with optional search and pagination.

aput_messages

Store messages in PostgreSQL.

aput_state

Store state in PostgreSQL and optionally cache in Redis.

aput_state_cache

Cache state in Redis with TTL.

aput_thread

Create or update thread information.

arelease

Clean up connections and resources.

asetup

Asynchronous setup method. Initializes database schema.

clean_thread

Clean/delete thread synchronously.

clear_state

Clear agent state synchronously.

delete_message

Delete a specific message synchronously.

get_message

Retrieve a specific message synchronously.

get_state

Retrieve agent state synchronously.

get_state_cache

Retrieve agent state from cache synchronously.

get_thread

Retrieve thread info synchronously.

list_messages

List messages synchronously with optional filtering.

list_threads

List threads synchronously with optional filtering.

put_messages

Store messages synchronously.

put_state

Store agent state synchronously.

put_state_cache

Store agent state in cache synchronously.

put_thread

Store thread info synchronously.

release

Release resources synchronously.

setup

Synchronous setup method for checkpointer.
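The async read/write paths retry transient connection failures with exponential backoff (see `_retry_on_connection_error` in the source below). A standalone sketch of that retry loop, with a `backoff_base` parameter added here purely for illustration:

```python
import asyncio

async def retry_on_connection_error(operation, *args, max_retries=3,
                                    backoff_base=1.0, **kwargs):
    # Retry only connection-type failures; other exceptions propagate at once.
    last_exc = None
    for attempt in range(max_retries):
        try:
            return await operation(*args, **kwargs)
        except ConnectionError as e:
            last_exc = e
            if attempt < max_retries - 1:
                # Exponential backoff: backoff_base * 2**attempt seconds
                await asyncio.sleep(backoff_base * 2 ** attempt)
    raise last_exc

calls = []

async def flaky():
    # Fails twice, then succeeds -- simulating a transient outage.
    calls.append(1)
    if len(calls) < 3:
        raise ConnectionError("transient")
    return "ok"

result = asyncio.run(retry_on_connection_error(flaky, backoff_base=0.0))
```

The real implementation also catches `asyncpg.PostgresConnectionError` and `asyncpg.InterfaceError` when asyncpg is installed, and deliberately does not retry non-connection errors.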

Attributes:

Name Type Description
cache_ttl
id_type
redis
release_resources
schema
user_id_type
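State caching keys entries per thread and user and expires them after `cache_ttl` seconds. A minimal in-memory sketch of that keying and TTL scheme (the real class uses Redis; a dict stands in here for illustration only):

```python
import json
import time

def thread_key(thread_id, user_id) -> str:
    # Same key scheme as PgCheckpointer._get_thread_key
    return f"state_cache:{thread_id}:{user_id}"

class DictTTLCache:
    """Stand-in for Redis SETEX/GET semantics, for illustration only."""

    def __init__(self):
        self._data = {}

    def setex(self, key, ttl_seconds, value):
        # Store the value together with its absolute expiry time.
        self._data[key] = (time.monotonic() + ttl_seconds, value)

    def get(self, key):
        entry = self._data.get(key)
        if entry is None or entry[0] < time.monotonic():
            return None  # missing or expired
        return entry[1]

cache = DictTTLCache()
key = thread_key("t1", "u1")
cache.setex(key, 300, json.dumps({"step": 1}))
restored = json.loads(cache.get(key))
```

On a cache miss, `aget_state_cache` falls back to PostgreSQL, so the TTL bounds staleness without risking data loss.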
Source code in pyagenity/checkpointer/pg_checkpointer.py
class PgCheckpointer(BaseCheckpointer[StateT]):
    """
    Implements a checkpointer using PostgreSQL and Redis for persistent and cached state management.

    This class provides asynchronous and synchronous methods for storing, retrieving, and managing
    agent states, messages, and threads. PostgreSQL is used for durable storage, while Redis
    provides fast caching with TTL.

    Features:
        - Async-first design with sync fallbacks
        - Configurable ID types (string, int, bigint)
        - Connection pooling for both PostgreSQL and Redis
        - Proper error handling and resource management
        - Schema migration support

    Args:
        postgres_dsn (str, optional): PostgreSQL connection string.
        pg_pool (Any, optional): Existing asyncpg Pool instance.
        pool_config (dict, optional): Configuration for new pg pool creation.
        redis_url (str, optional): Redis connection URL.
        redis (Any, optional): Existing Redis instance.
        redis_pool (Any, optional): Existing Redis ConnectionPool.
        redis_pool_config (dict, optional): Configuration for new redis pool creation.
        schema (str, optional): PostgreSQL schema name. Defaults to "public".
        **kwargs: Additional configuration options:
            - user_id_type: Type for user_id fields ('string', 'int', 'bigint')
            - cache_ttl: Redis cache TTL in seconds
            - release_resources: Whether to release resources on cleanup

    Raises:
        ImportError: If required dependencies are missing.
        ValueError: If required connection details are missing.
    """

    def __init__(
        self,
        # postgres connection details
        postgres_dsn: str | None = None,
        pg_pool: Any | None = None,
        pool_config: dict | None = None,
        # redis connection details
        redis_url: str | None = None,
        redis: Any | None = None,
        redis_pool: Any | None = None,
        redis_pool_config: dict | None = None,
        # database schema
        schema: str = "public",
        # other configurations - combine to reduce args
        **kwargs,
    ):
        """
        Initializes PgCheckpointer with PostgreSQL and Redis connections.

        Args:
            postgres_dsn (str, optional): PostgreSQL connection string.
            pg_pool (Any, optional): Existing asyncpg Pool instance.
            pool_config (dict, optional): Configuration for new pg pool creation.
            redis_url (str, optional): Redis connection URL.
            redis (Any, optional): Existing Redis instance.
            redis_pool (Any, optional): Existing Redis ConnectionPool.
            redis_pool_config (dict, optional): Configuration for new redis pool creation.
            schema (str, optional): PostgreSQL schema name. Defaults to "public".
            **kwargs: Additional configuration options.

        Raises:
            ImportError: If required dependencies are missing.
            ValueError: If required connection details are missing.
        """
        # Check for required dependencies
        if not HAS_ASYNCPG:
            raise ImportError(
                "PgCheckpointer requires 'asyncpg' package. "
                "Install with: pip install pyagenity[pg_checkpoint]"
            )

        if not HAS_REDIS:
            raise ImportError(
                "PgCheckpointer requires 'redis' package. "
                "Install with: pip install pyagenity[pg_checkpoint]"
            )

        self.user_id_type = kwargs.get("user_id_type", "string")
        # allow explicit override via kwargs, fallback to InjectQ, then default
        self.id_type = kwargs.get(
            "id_type", InjectQ.get_instance().try_get("generated_id_type", "string")
        )
        self.cache_ttl = kwargs.get("cache_ttl", DEFAULT_CACHE_TTL)
        self.release_resources = kwargs.get("release_resources", False)

        # Validate schema name to prevent SQL injection
        if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", schema):
            raise ValueError(
                f"Invalid schema name: {schema}. Schema must match pattern ^[a-zA-Z_][a-zA-Z0-9_]*$"
            )
        self.schema = schema

        self._schema_initialized = False
        self._loop: asyncio.AbstractEventLoop | None = None

        # Store pool configuration for lazy initialization
        self._pg_pool_config = {
            "pg_pool": pg_pool,
            "postgres_dsn": postgres_dsn,
            "pool_config": pool_config or {},
        }

        # Initialize pool immediately if provided, otherwise defer
        if pg_pool is not None:
            self._pg_pool = pg_pool
        else:
            self._pg_pool = None

        # Now check and initialize connections
        if not pg_pool and not postgres_dsn:
            raise ValueError("Either postgres_dsn or pg_pool must be provided.")

        if not redis and not redis_url and not redis_pool:
            raise ValueError("Either redis_url, redis_pool or redis instance must be provided.")

        # Initialize Redis connection (synchronous)
        self.redis = self._create_redis_pool(redis, redis_pool, redis_url, redis_pool_config or {})

    def _create_redis_pool(
        self,
        redis: Any | None,
        redis_pool: Any | None,
        redis_url: str | None,
        redis_pool_config: dict,
    ) -> Any:
        """
        Create or use an existing Redis connection.

        Args:
            redis (Any, optional): Existing Redis instance.
            redis_pool (Any, optional): Existing Redis ConnectionPool.
            redis_url (str, optional): Redis connection URL.
            redis_pool_config (dict): Configuration for new redis pool creation.

        Returns:
            Redis: Redis connection instance.

        Raises:
            ValueError: If redis_url is not provided when creating a new connection.
        """
        if redis:
            return redis

        if redis_pool:
            return Redis(connection_pool=redis_pool)  # type: ignore

        # as we are creating new pool, redis_url must be provided
        # and we will release the resources if needed
        if not redis_url:
            raise ValueError("redis_url must be provided when creating new Redis connection")

        self.release_resources = True
        return Redis(
            connection_pool=ConnectionPool.from_url(  # type: ignore
                redis_url,
                **redis_pool_config,
            )
        )

    def _get_table_name(self, table: str) -> str:
        """
        Get the schema-qualified table name.

        Args:
            table (str): The base table name (e.g., 'threads', 'states', 'messages')

        Returns:
            str: The schema-qualified table name (e.g., '"public"."threads"')
        """
        # Validate table name to prevent SQL injection
        if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", table):
            raise ValueError(
                f"Invalid table name: {table}. Table must match pattern ^[a-zA-Z_][a-zA-Z0-9_]*$"
            )
        return f'"{self.schema}"."{table}"'

    def _create_pg_pool(self, pg_pool: Any, postgres_dsn: str | None, pool_config: dict) -> Any:
        """
        Create or use an existing PostgreSQL connection pool.

        Args:
            pg_pool (Any, optional): Existing asyncpg Pool instance.
            postgres_dsn (str, optional): PostgreSQL connection string.
            pool_config (dict): Configuration for new pg pool creation.

        Returns:
            Pool: PostgreSQL connection pool.
        """
        if pg_pool:
            return pg_pool
        # as we are creating new pool, postgres_dsn must be provided
        # and we will release the resources if needed
        self.release_resources = True
        return asyncpg.create_pool(dsn=postgres_dsn, **pool_config)  # type: ignore

    async def _get_pg_pool(self) -> Any:
        """
        Get PostgreSQL pool, creating it if necessary.

        Returns:
            Pool: PostgreSQL connection pool.
        """
        """Get PostgreSQL pool, creating it if necessary."""
        if self._pg_pool is None:
            config = self._pg_pool_config
            self._pg_pool = await self._create_pg_pool(
                config["pg_pool"], config["postgres_dsn"], config["pool_config"]
            )
        return self._pg_pool

    def _get_sql_type(self, type_name: str) -> str:
        """
        Get SQL type for given configuration type.

        Args:
            type_name (str): Type name ('string', 'int', 'bigint').

        Returns:
            str: Corresponding SQL type.
        """
        """Get SQL type for given configuration type."""
        return ID_TYPE_MAP.get(type_name, "VARCHAR(255)")

    def _get_json_serializer(self):
        """Get optimal JSON serializer based on FAST_JSON env var."""
        if os.environ.get("FAST_JSON", "0") == "1":
            try:
                import orjson

                return orjson.dumps
            except ImportError:
                try:
                    import msgspec  # type: ignore

                    return msgspec.json.encode
                except ImportError:
                    pass
        return json.dumps

    def _get_current_schema_version(self) -> int:
        """Return current expected schema version."""
        return 1  # increment when schema changes

    def _build_create_tables_sql(self) -> list[str]:
        """
        Build SQL statements for table creation with dynamic ID types.

        Returns:
            list[str]: List of SQL statements for table creation.
        """
        """Build SQL statements for table creation with dynamic ID types."""
        thread_id_type = self._get_sql_type(self.id_type)
        user_id_type = self._get_sql_type(self.user_id_type)
        message_id_type = self._get_sql_type(self.id_type)

        # For AUTO INCREMENT types, we need to handle primary key differently
        thread_pk = (
            "thread_id SERIAL PRIMARY KEY"
            if self.id_type == "int"
            else f"thread_id {thread_id_type} PRIMARY KEY"
        )
        message_pk = (
            "message_id SERIAL PRIMARY KEY"
            if self.id_type == "int"
            else f"message_id {message_id_type} PRIMARY KEY"
        )

        return [
            # Schema version tracking table
            f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("schema_version")} (
                version INT PRIMARY KEY,
                applied_at TIMESTAMPTZ DEFAULT NOW()
            )
            """,
            # Create message role enum (safe for older Postgres versions)
            (
                "DO $$\n"
                "BEGIN\n"
                "    CREATE TYPE message_role AS ENUM ('user', 'assistant', 'system', 'tool');\n"
                "EXCEPTION\n"
                "    WHEN duplicate_object THEN NULL;\n"
                "END$$;"
            ),
            # Create threads table
            f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("threads")} (
                {thread_pk},
                thread_name VARCHAR(255),
                user_id {user_id_type} NOT NULL,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW(),
                meta JSONB DEFAULT '{{}}'::jsonb
            )
            """,
            # Create states table
            f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("states")} (
                state_id SERIAL PRIMARY KEY,
                thread_id {thread_id_type} NOT NULL
                    REFERENCES {self._get_table_name("threads")}(thread_id)
                    ON DELETE CASCADE,
                state_data JSONB NOT NULL,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW(),
                meta JSONB DEFAULT '{{}}'::jsonb
            )
            """,
            # Create messages table
            f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("messages")} (
                {message_pk},
                thread_id {thread_id_type} NOT NULL
                    REFERENCES {self._get_table_name("threads")}(thread_id)
                    ON DELETE CASCADE,
                role message_role NOT NULL,
                content TEXT NOT NULL,
                tool_calls JSONB,
                tool_call_id VARCHAR(255),
                reasoning TEXT,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW(),
                total_tokens INT DEFAULT 0,
                usages JSONB DEFAULT '{{}}'::jsonb,
                meta JSONB DEFAULT '{{}}'::jsonb
            )
            """,
            # Create indexes
            f"CREATE INDEX IF NOT EXISTS idx_threads_user_id ON "
            f"{self._get_table_name('threads')}(user_id)",
            f"CREATE INDEX IF NOT EXISTS idx_states_thread_id ON "
            f"{self._get_table_name('states')}(thread_id)",
            f"CREATE INDEX IF NOT EXISTS idx_messages_thread_id ON "
            f"{self._get_table_name('messages')}(thread_id)",
        ]

    async def _check_and_apply_schema_version(self, conn) -> None:
        """Check current version and update if needed."""
        try:
            # Check if schema version exists
            row = await conn.fetchrow(
                f"SELECT version FROM {self._get_table_name('schema_version')} "  # noqa: S608
                f"ORDER BY version DESC LIMIT 1"
            )
            current_version = row["version"] if row else 0
            target_version = self._get_current_schema_version()

            if current_version < target_version:
                logger.info(
                    "Upgrading schema from version %d to %d", current_version, target_version
                )
                # Insert new version
                await conn.execute(
                    f"INSERT INTO {self._get_table_name('schema_version')} (version) VALUES ($1)",  # noqa: S608
                    target_version,
                )
        except Exception as e:
            logger.debug("Schema version check failed (expected on first run): %s", e)
            # Insert initial version
            with suppress(Exception):
                await conn.execute(
                    f"INSERT INTO {self._get_table_name('schema_version')} (version) VALUES ($1)",  # noqa: S608
                    self._get_current_schema_version(),
                )

    async def _initialize_schema(self) -> None:
        """
        Initialize database schema if not already done.

        Returns:
            None
        """
        """Initialize database schema if not already done."""
        if self._schema_initialized:
            return

        logger.debug(
            "Initializing database schema with types: id_type=%s, user_id_type=%s",
            self.id_type,
            self.user_id_type,
        )

        async with (await self._get_pg_pool()).acquire() as conn:
            try:
                sql_statements = self._build_create_tables_sql()
                for sql in sql_statements:
                    logger.debug("Executing SQL: %s", sql.strip())
                    await conn.execute(sql)

                # Check and apply schema version tracking
                await self._check_and_apply_schema_version(conn)

                self._schema_initialized = True
                logger.debug("Database schema initialized successfully")
            except Exception as e:
                logger.error("Failed to initialize database schema: %s", e)
                raise

    ###########################
    #### SETUP METHODS ########
    ###########################

    async def asetup(self) -> Any:
        """
        Asynchronous setup method. Initializes database schema.

        Returns:
            Any: True if setup completed.
        """
        """Async setup method - initializes database schema."""
        logger.info(
            "Setting up PgCheckpointer (async)",
            extra={
                "id_type": self.id_type,
                "user_id_type": self.user_id_type,
                "schema": self.schema,
            },
        )
        await self._initialize_schema()
        logger.info("PgCheckpointer setup completed")
        return True

    ###########################
    #### HELPER METHODS #######
    ###########################

    def _validate_config(self, config: dict[str, Any]) -> tuple[str | int, str | int]:
        """
        Extract and validate thread_id and user_id from config.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            tuple: (thread_id, user_id)

        Raises:
            ValueError: If required fields are missing.
        """
        """Extract and validate thread_id and user_id from config."""
        thread_id = config.get("thread_id")
        user_id = config.get("user_id")
        if not user_id:
            raise ValueError("user_id must be provided in config")

        if not thread_id:
            raise ValueError("Both thread_id must be provided in config")

        return thread_id, user_id

    def _get_thread_key(
        self,
        thread_id: str | int,
        user_id: str | int,
    ) -> str:
        """
        Get Redis cache key for thread state.

        Args:
            thread_id (str|int): Thread identifier.
            user_id (str|int): User identifier.

        Returns:
            str: Redis cache key.
        """
        return f"state_cache:{thread_id}:{user_id}"

    def _serialize_state(self, state: StateT) -> str:
        """
        Serialize state to JSON string for storage.

        Args:
            state (StateT): State object.

        Returns:
            str: JSON string.
        """
        """Serialize state to JSON string for storage."""

        def enum_handler(obj):
            if isinstance(obj, Enum):
                return obj.value
            return str(obj)

        return json.dumps(state.model_dump(), default=enum_handler)

    def _serialize_state_fast(self, state: StateT) -> str:
        """
        Serialize state using fast JSON serializer if available.

        Args:
            state (StateT): State object.

        Returns:
            str: JSON string.
        """
        serializer = self._get_json_serializer()

        def enum_handler(obj):
            if isinstance(obj, Enum):
                return obj.value
            return str(obj)

        data = state.model_dump()

        # Use fast serializer if available, otherwise fall back to json.dumps with enum handling
        if serializer is json.dumps:
            return json.dumps(data, default=enum_handler)

        # Fast serializers (orjson, msgspec) may not accept a default handler,
        # so enum values are assumed to be natively JSON-serializable here
        result = serializer(data)
        # Ensure we return a string (orjson returns bytes)
        return result.decode("utf-8") if isinstance(result, bytes) else str(result)

    def _deserialize_state(
        self,
        data: Any,
        state_class: type[StateT],
    ) -> StateT:
        """
        Deserialize JSON/JSONB back to state object.

        Args:
            data (Any): JSON string or dict/list.
            state_class (type): State class type.

        Returns:
            StateT: Deserialized state object.

        Raises:
            Exception: If deserialization fails.
        """
        try:
            if isinstance(data, bytes | bytearray):
                data = data.decode()
            if isinstance(data, str):
                return state_class.model_validate(json.loads(data))
            # Assume it's already a dict/list
            return state_class.model_validate(data)
        except Exception:
            # Last-resort: coerce to string and attempt parse, else raise
            if isinstance(data, str):
                return state_class.model_validate(json.loads(data))
            raise

    async def _retry_on_connection_error(
        self,
        operation,
        *args,
        max_retries=3,
        **kwargs,
    ):
        """
        Retry database operations on connection errors.

        Args:
            operation: Callable operation.
            *args: Arguments.
            max_retries (int): Maximum retries.
            **kwargs: Keyword arguments.

        Returns:
            Any: Result of operation or None.

        Raises:
            Exception: If all retries fail.
        """
        last_exception = None

        # Define exception types to catch (only if asyncpg is available)
        exceptions_to_catch: list[type[Exception]] = [ConnectionError]
        if HAS_ASYNCPG and asyncpg:
            exceptions_to_catch.extend([asyncpg.PostgresConnectionError, asyncpg.InterfaceError])

        exception_tuple = tuple(exceptions_to_catch)

        for attempt in range(max_retries):
            try:
                return await operation(*args, **kwargs)
            except exception_tuple as e:
                last_exception = e
                if attempt < max_retries - 1:
                    wait_time = 2**attempt  # exponential backoff
                    logger.warning(
                        "Database connection error on attempt %d/%d, retrying in %ds: %s",
                        attempt + 1,
                        max_retries,
                        wait_time,
                        e,
                    )
                    await asyncio.sleep(wait_time)
                    continue

                logger.error("Failed after %d attempts: %s", max_retries, e)
                break
            except Exception as e:
                # Don't retry on non-connection errors
                logger.error("Non-retryable error: %s", e)
                raise

        if last_exception:
            raise last_exception
        return None

    ###########################
    #### STATE METHODS ########
    ###########################

    async def aput_state(
        self,
        config: dict[str, Any],
        state: StateT,
    ) -> StateT:
        """
        Store state in PostgreSQL and optionally cache in Redis.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to store.

        Returns:
            StateT: The stored state object.

        Raises:
            StorageError: If storing fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        logger.debug("Storing state for thread_id=%s, user_id=%s", thread_id, user_id)
        metrics.counter("pg_checkpointer.save_state.attempts").inc()

        with metrics.timer("pg_checkpointer.save_state.duration"):
            try:
                # Ensure thread exists first
                await self._ensure_thread_exists(thread_id, user_id, config)

                # Store in PostgreSQL with retry logic
                state_json = self._serialize_state_fast(state)

                async def _store_state():
                    async with (await self._get_pg_pool()).acquire() as conn:
                        await conn.execute(
                            f"""
                            INSERT INTO {self._get_table_name("states")}
                                (thread_id, state_data, meta)
                            VALUES ($1, $2, $3)
                            ON CONFLICT DO NOTHING
                            """,  # noqa: S608
                            thread_id,
                            state_json,
                            json.dumps(config.get("meta", {})),
                        )

                await self._retry_on_connection_error(_store_state, max_retries=3)
                logger.debug("State stored successfully for thread_id=%s", thread_id)
                metrics.counter("pg_checkpointer.save_state.success").inc()
                return state

            except Exception as e:
                metrics.counter("pg_checkpointer.save_state.error").inc()
                logger.error("Failed to store state for thread_id=%s: %s", thread_id, e)
                if asyncpg and hasattr(asyncpg, "ConnectionDoesNotExistError"):
                    connection_errors = (
                        asyncpg.ConnectionDoesNotExistError,
                        asyncpg.InterfaceError,
                    )
                    if isinstance(e, connection_errors):
                        raise TransientStorageError(f"Connection issue storing state: {e}") from e
                raise StorageError(f"Failed to store state: {e}") from e

    async def aget_state(self, config: dict[str, Any]) -> StateT | None:
        """
        Retrieve state from PostgreSQL.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: Retrieved state or None.

        Raises:
            Exception: If retrieval fails.
        """
        """Retrieve state from PostgreSQL."""
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)
        state_class = config.get("state_class", AgentState)

        logger.debug("Retrieving state for thread_id=%s, user_id=%s", thread_id, user_id)

        try:

            async def _get_state():
                async with (await self._get_pg_pool()).acquire() as conn:
                    return await conn.fetchrow(
                        f"""
                        SELECT state_data FROM {self._get_table_name("states")}
                        WHERE thread_id = $1
                        ORDER BY created_at DESC
                        LIMIT 1
                        """,  # noqa: S608
                        thread_id,
                    )

            row = await self._retry_on_connection_error(_get_state, max_retries=3)

            if row:
                logger.debug("State found for thread_id=%s", thread_id)
                return self._deserialize_state(row["state_data"], state_class)

            logger.debug("No state found for thread_id=%s", thread_id)
            return None

        except Exception as e:
            logger.error("Failed to retrieve state for thread_id=%s: %s", thread_id, e)
            raise

    async def aclear_state(self, config: dict[str, Any]) -> Any:
        """
        Clear state from PostgreSQL and Redis cache.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any: None

        Raises:
            Exception: If clearing fails.
        """
        """Clear state from PostgreSQL and Redis cache."""
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        logger.debug("Clearing state for thread_id=%s, user_id=%s", thread_id, user_id)

        try:
            # Clear from PostgreSQL with retry logic
            async def _clear_state():
                async with (await self._get_pg_pool()).acquire() as conn:
                    await conn.execute(
                        f"DELETE FROM {self._get_table_name('states')} WHERE thread_id = $1",  # noqa: S608
                        thread_id,
                    )

            await self._retry_on_connection_error(_clear_state, max_retries=3)

            # Clear from Redis cache
            cache_key = self._get_thread_key(thread_id, user_id)
            await self.redis.delete(cache_key)

            logger.debug("State cleared for thread_id=%s", thread_id)

        except Exception as e:
            logger.error("Failed to clear state for thread_id=%s: %s", thread_id, e)
            raise

    async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
        """
        Cache state in Redis with TTL.

        Args:
            config (dict): Configuration dictionary.
            state (StateT): State object to cache.

        Returns:
            Any | None: True if cached, None if failed.
        """
        """Cache state in Redis with TTL."""
        # No DB access, but keep consistent
        thread_id, user_id = self._validate_config(config)

        logger.debug("Caching state for thread_id=%s, user_id=%s", thread_id, user_id)

        try:
            cache_key = self._get_thread_key(thread_id, user_id)
            state_json = self._serialize_state(state)
            await self.redis.setex(cache_key, self.cache_ttl, state_json)
            logger.debug("State cached with key=%s, ttl=%d", cache_key, self.cache_ttl)
            return True

        except Exception as e:
            logger.error("Failed to cache state for thread_id=%s: %s", thread_id, e)
            # Don't raise - caching is optional
            return None

    async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
        """
        Get state from Redis cache, fallback to PostgreSQL if miss.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            StateT | None: State object or None.
        """
        """Get state from Redis cache, fallback to PostgreSQL if miss."""
        # Schema might be needed if we fall back to DB
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)
        state_class = config.get("state_class", AgentState)

        logger.debug("Getting cached state for thread_id=%s, user_id=%s", thread_id, user_id)

        try:
            # Try Redis first
            cache_key = self._get_thread_key(thread_id, user_id)
            cached_data = await self.redis.get(cache_key)

            if cached_data:
                logger.debug("Cache hit for thread_id=%s", thread_id)
                return self._deserialize_state(cached_data.decode(), state_class)

            # Cache miss - fallback to PostgreSQL
            logger.debug("Cache miss for thread_id=%s, falling back to PostgreSQL", thread_id)
            state = await self.aget_state(config)

            # Cache the result for next time
            if state:
                await self.aput_state_cache(config, state)

            return state

        except Exception as e:
            logger.error("Failed to get cached state for thread_id=%s: %s", thread_id, e)
            # Fallback to PostgreSQL on error
            return await self.aget_state(config)

    async def _ensure_thread_exists(
        self,
        thread_id: str | int,
        user_id: str | int,
        config: dict[str, Any],
    ) -> None:
        """
        Ensure thread exists in database, create if not.

        Args:
            thread_id (str|int): Thread identifier.
            user_id (str|int): User identifier.
            config (dict): Configuration dictionary.

        Returns:
            None

        Raises:
            Exception: If creation fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        try:

            async def _check_and_create_thread():
                async with (await self._get_pg_pool()).acquire() as conn:
                    exists = await conn.fetchval(
                        f"SELECT 1 FROM {self._get_table_name('threads')} "  # noqa: S608
                        f"WHERE thread_id = $1 AND user_id = $2",
                        thread_id,
                        user_id,
                    )

                    if not exists:
                        thread_name = config.get("thread_name", f"Thread {thread_id}")
                        meta = json.dumps(config.get("thread_meta", {}))
                        await conn.execute(
                            f"""
                            INSERT INTO {self._get_table_name("threads")}
                                (thread_id, thread_name, user_id, meta)
                            VALUES ($1, $2, $3, $4)
                            ON CONFLICT DO NOTHING
                            """,  # noqa: S608
                            thread_id,
                            thread_name,
                            user_id,
                            meta,
                        )
                        logger.debug("Created thread: thread_id=%s, user_id=%s", thread_id, user_id)

            await self._retry_on_connection_error(_check_and_create_thread, max_retries=3)

        except Exception as e:
            logger.error("Failed to ensure thread exists: %s", e)
            raise

    ###########################
    #### MESSAGE METHODS ######
    ###########################

    async def aput_messages(
        self,
        config: dict[str, Any],
        messages: list[Message],
        metadata: dict[str, Any] | None = None,
    ) -> Any:
        """
        Store messages in PostgreSQL.

        Args:
            config (dict): Configuration dictionary.
            messages (list[Message]): List of messages to store.
            metadata (dict, optional): Additional metadata.

        Returns:
            Any: None

        Raises:
            Exception: If storing fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        if not messages:
            logger.debug("No messages to store for thread_id=%s", thread_id)
            return

        logger.debug("Storing %d messages for thread_id=%s", len(messages), thread_id)

        try:
            # Ensure thread exists
            await self._ensure_thread_exists(thread_id, user_id, config)

            # Store messages in batch with retry logic
            async def _store_messages():
                async with (await self._get_pg_pool()).acquire() as conn, conn.transaction():
                    for message in messages:
                        await conn.execute(
                            f"""
                                INSERT INTO {self._get_table_name("messages")} (
                                    message_id, thread_id, role, content, tool_calls,
                                    tool_call_id, reasoning, total_tokens, usages, meta
                                )
                                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                                ON CONFLICT (message_id) DO UPDATE SET
                                    content = EXCLUDED.content,
                                    reasoning = EXCLUDED.reasoning,
                                    usages = EXCLUDED.usages,
                                    updated_at = NOW()
                                """,  # noqa: S608
                            message.message_id,
                            thread_id,
                            message.role,
                            json.dumps(
                                [block.model_dump(mode="json") for block in message.content]
                            ),
                            json.dumps(message.tools_calls) if message.tools_calls else None,
                            getattr(message, "tool_call_id", None),
                            message.reasoning,
                            message.usages.total_tokens if message.usages else 0,
                            json.dumps(message.usages.model_dump()) if message.usages else None,
                            json.dumps({**(metadata or {}), **(message.metadata or {})}),
                        )

            await self._retry_on_connection_error(_store_messages, max_retries=3)
            logger.debug("Stored %d messages for thread_id=%s", len(messages), thread_id)

        except Exception as e:
            logger.error("Failed to store messages for thread_id=%s: %s", thread_id, e)
            raise

    async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
        """
        Retrieve a single message by ID.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Message: Retrieved message object.

        Raises:
            Exception: If retrieval fails.
        """
        """Retrieve a single message by ID."""
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id = config.get("thread_id")

        logger.debug("Retrieving message_id=%s for thread_id=%s", message_id, thread_id)

        try:

            async def _get_message():
                async with (await self._get_pg_pool()).acquire() as conn:
                    query = f"""
                        SELECT message_id, thread_id, role, content, tool_calls,
                               tool_call_id, reasoning, created_at, total_tokens,
                               usages, meta
                        FROM {self._get_table_name("messages")}
                        WHERE message_id = $1
                    """  # noqa: S608
                    if thread_id:
                        query += " AND thread_id = $2"
                        return await conn.fetchrow(query, message_id, thread_id)
                    return await conn.fetchrow(query, message_id)

            row = await self._retry_on_connection_error(_get_message, max_retries=3)

            if not row:
                raise ValueError(f"Message not found: {message_id}")

            return self._row_to_message(row)

        except Exception as e:
            logger.error("Failed to retrieve message_id=%s: %s", message_id, e)
            raise

    async def alist_messages(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[Message]:
        """
        List messages for a thread with optional search and pagination.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[Message]: List of message objects.

        Raises:
            Exception: If listing fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id = config.get("thread_id")

        if not thread_id:
            raise ValueError("thread_id must be provided in config")

        logger.debug("Listing messages for thread_id=%s", thread_id)

        try:

            async def _list_messages():
                async with (await self._get_pg_pool()).acquire() as conn:
                    # Build query with optional search
                    query = f"""
                        SELECT message_id, thread_id, role, content, tool_calls,
                               tool_call_id, reasoning, created_at, total_tokens,
                               usages, meta
                        FROM {self._get_table_name("messages")}
                        WHERE thread_id = $1
                    """  # noqa: S608
                    params = [thread_id]
                    param_count = 1

                    if search:
                        param_count += 1
                        query += f" AND content ILIKE ${param_count}"
                        params.append(f"%{search}%")

                    query += " ORDER BY created_at ASC"

                    if limit:
                        param_count += 1
                        query += f" LIMIT ${param_count}"
                        params.append(limit)

                    if offset:
                        param_count += 1
                        query += f" OFFSET ${param_count}"
                        params.append(offset)

                    return await conn.fetch(query, *params)

            rows = await self._retry_on_connection_error(_list_messages, max_retries=3)
            if not rows:
                rows = []
            messages = [self._row_to_message(row) for row in rows]

            logger.debug("Found %d messages for thread_id=%s", len(messages), thread_id)
            return messages

        except Exception as e:
            logger.error("Failed to list messages for thread_id=%s: %s", thread_id, e)
            raise

    async def adelete_message(
        self,
        config: dict[str, Any],
        message_id: str | int,
    ) -> Any | None:
        """
        Delete a message by ID.

        Args:
            config (dict): Configuration dictionary.
            message_id (str|int): Message identifier.

        Returns:
            Any | None: None

        Raises:
            Exception: If deletion fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id = config.get("thread_id")

        logger.debug("Deleting message_id=%s for thread_id=%s", message_id, thread_id)

        try:

            async def _delete_message():
                async with (await self._get_pg_pool()).acquire() as conn:
                    if thread_id:
                        await conn.execute(
                            f"DELETE FROM {self._get_table_name('messages')} "  # noqa: S608
                            f"WHERE message_id = $1 AND thread_id = $2",
                            message_id,
                            thread_id,
                        )
                    else:
                        await conn.execute(
                            f"DELETE FROM {self._get_table_name('messages')} WHERE message_id = $1",  # noqa: S608
                            message_id,
                        )

            await self._retry_on_connection_error(_delete_message, max_retries=3)
            logger.debug("Deleted message_id=%s", message_id)
            return None

        except Exception as e:
            logger.error("Failed to delete message_id=%s: %s", message_id, e)
            raise

    def _row_to_message(self, row) -> Message:  # noqa: PLR0912, PLR0915
        """
        Convert database row to Message object with robust JSON handling.

        Args:
            row: Database row.

        Returns:
            Message: Message object.
        """
        from pyagenity.utils.message import TokenUsages

        # Handle usages JSONB
        usages = None
        usages_raw = row["usages"]
        if usages_raw:
            try:
                usages_dict = (
                    json.loads(usages_raw)
                    if isinstance(usages_raw, str | bytes | bytearray)
                    else usages_raw
                )
                usages = TokenUsages(**usages_dict)
            except Exception:
                usages = None

        # Handle tool_calls JSONB
        tool_calls_raw = row["tool_calls"]
        if tool_calls_raw:
            try:
                tool_calls = (
                    json.loads(tool_calls_raw)
                    if isinstance(tool_calls_raw, str | bytes | bytearray)
                    else tool_calls_raw
                )
            except Exception:
                tool_calls = None
        else:
            tool_calls = None

        # Handle meta JSONB
        meta_raw = row["meta"]
        if meta_raw:
            try:
                metadata = (
                    json.loads(meta_raw)
                    if isinstance(meta_raw, str | bytes | bytearray)
                    else meta_raw
                )
            except Exception:
                metadata = {}
        else:
            metadata = {}

        # Handle content TEXT/JSONB -> list of blocks
        content_raw = row["content"]
        content_value: list[Any] = []
        if content_raw is None:
            content_value = []
        elif isinstance(content_raw, bytes | bytearray):
            try:
                parsed = json.loads(content_raw.decode())
                if isinstance(parsed, list):
                    content_value = parsed
                elif isinstance(parsed, dict):
                    content_value = [parsed]
                else:
                    content_value = [{"type": "text", "text": str(parsed), "annotations": []}]
            except Exception:
                content_value = [
                    {"type": "text", "text": content_raw.decode(errors="ignore"), "annotations": []}
                ]
        elif isinstance(content_raw, str):
            # Try JSON parse first
            try:
                parsed = json.loads(content_raw)
                if isinstance(parsed, list):
                    content_value = parsed
                elif isinstance(parsed, dict):
                    content_value = [parsed]
                else:
                    content_value = [{"type": "text", "text": content_raw, "annotations": []}]
            except Exception:
                content_value = [{"type": "text", "text": content_raw, "annotations": []}]
        elif isinstance(content_raw, list):
            content_value = content_raw
        elif isinstance(content_raw, dict):
            content_value = [content_raw]
        else:
            content_value = [{"type": "text", "text": str(content_raw), "annotations": []}]

        return Message(
            message_id=row["message_id"],
            role=row["role"],
            content=content_value,
            tools_calls=tool_calls,
            reasoning=row["reasoning"],
            timestamp=row["created_at"],
            metadata=metadata,
            usages=usages,
        )

    ###########################
    #### THREAD METHODS #######
    ###########################

    async def aput_thread(
        self,
        config: dict[str, Any],
        thread_info: ThreadInfo,
    ) -> Any | None:
        """
        Create or update thread information.

        Args:
            config (dict): Configuration dictionary.
            thread_info (ThreadInfo): Thread information object.

        Returns:
            Any | None: None

        Raises:
            Exception: If storing fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        logger.debug("Storing thread info for thread_id=%s, user_id=%s", thread_id, user_id)

        try:
            thread_name = thread_info.thread_name or f"Thread {thread_id}"
            meta = thread_info.metadata or {}
            user_id = thread_info.user_id or user_id
            meta.update(
                {
                    "run_id": thread_info.run_id,
                }
            )

            async def _put_thread():
                async with (await self._get_pg_pool()).acquire() as conn:
                    await conn.execute(
                        f"""
                        INSERT INTO {self._get_table_name("threads")}
                            (thread_id, thread_name, user_id, meta)
                        VALUES ($1, $2, $3, $4)
                        ON CONFLICT (thread_id) DO UPDATE SET
                            thread_name = EXCLUDED.thread_name,
                            meta = EXCLUDED.meta,
                            updated_at = NOW()
                        """,  # noqa: S608
                        thread_id,
                        thread_name,
                        user_id,
                        json.dumps(meta),
                    )

            await self._retry_on_connection_error(_put_thread, max_retries=3)
            logger.debug("Thread info stored for thread_id=%s", thread_id)

        except Exception as e:
            logger.error("Failed to store thread info for thread_id=%s: %s", thread_id, e)
            raise

    async def aget_thread(
        self,
        config: dict[str, Any],
    ) -> ThreadInfo | None:
        """
        Get thread information.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            ThreadInfo | None: Thread information object or None.

        Raises:
            Exception: If retrieval fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        logger.debug("Retrieving thread info for thread_id=%s, user_id=%s", thread_id, user_id)

        try:

            async def _get_thread():
                async with (await self._get_pg_pool()).acquire() as conn:
                    return await conn.fetchrow(
                        f"""
                        SELECT thread_id, thread_name, user_id, created_at, updated_at, meta
                        FROM {self._get_table_name("threads")}
                        WHERE thread_id = $1 AND user_id = $2
                        """,  # noqa: S608
                        thread_id,
                        user_id,
                    )

            row = await self._retry_on_connection_error(_get_thread, max_retries=3)

            if row:
                meta_dict = {}
                if row["meta"]:
                    meta_dict = (
                        json.loads(row["meta"]) if isinstance(row["meta"], str) else row["meta"]
                    )
                return ThreadInfo(
                    thread_id=thread_id,
                    thread_name=row["thread_name"] if row else None,
                    user_id=user_id,
                    metadata=meta_dict,
                    run_id=meta_dict.get("run_id"),
                )

            logger.debug("Thread not found for thread_id=%s, user_id=%s", thread_id, user_id)
            return None

        except Exception as e:
            logger.error("Failed to retrieve thread info for thread_id=%s: %s", thread_id, e)
            raise

    async def alist_threads(
        self,
        config: dict[str, Any],
        search: str | None = None,
        offset: int | None = None,
        limit: int | None = None,
    ) -> list[ThreadInfo]:
        """
        List threads for a user with optional search and pagination.

        Args:
            config (dict): Configuration dictionary.
            search (str, optional): Search string.
            offset (int, optional): Offset for pagination.
            limit (int, optional): Limit for pagination.

        Returns:
            list[ThreadInfo]: List of thread information objects.

        Raises:
            Exception: If listing fails.
        """
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        user_id = config.get("user_id")
        user_id = user_id or "test-user"

        if not user_id:
            raise ValueError("user_id must be provided in config")

        logger.debug("Listing threads for user_id=%s", user_id)

        try:

            async def _list_threads():
                async with (await self._get_pg_pool()).acquire() as conn:
                    # Build query with optional search
                    query = f"""
                        SELECT thread_id, thread_name, user_id, created_at, updated_at, meta
                        FROM {self._get_table_name("threads")}
                        WHERE user_id = $1
                    """  # noqa: S608
                    params = [user_id]
                    param_count = 1

                    if search:
                        param_count += 1
                        query += f" AND thread_name ILIKE ${param_count}"
                        params.append(f"%{search}%")

                    query += " ORDER BY updated_at DESC"

                    if limit:
                        param_count += 1
                        query += f" LIMIT ${param_count}"
                        params.append(limit)

                    if offset:
                        param_count += 1
                        query += f" OFFSET ${param_count}"
                        params.append(offset)

                    return await conn.fetch(query, *params)

            rows = await self._retry_on_connection_error(_list_threads, max_retries=3)
            if not rows:
                rows = []

            threads = []
            for row in rows:
                meta_dict = {}
                if row["meta"]:
                    meta_dict = (
                        json.loads(row["meta"]) if isinstance(row["meta"], str) else row["meta"]
                    )
                threads.append(
                    ThreadInfo(
                        thread_id=row["thread_id"],
                        thread_name=row["thread_name"],
                        user_id=row["user_id"],
                        metadata=meta_dict,
                        run_id=meta_dict.get("run_id"),
                        updated_at=row["updated_at"],
                    )
                )
            logger.debug("Found %d threads for user_id=%s", len(threads), user_id)
            return threads

        except Exception as e:
            logger.error("Failed to list threads for user_id=%s: %s", user_id, e)
            raise

    async def aclean_thread(self, config: dict[str, Any]) -> Any | None:
        """
        Clean/delete a thread and all associated data.

        Args:
            config (dict): Configuration dictionary.

        Returns:
            Any | None: None

        Raises:
            Exception: If cleaning fails.
        """
        """Clean/delete a thread and all associated data."""
        # Ensure schema is initialized before accessing tables
        await self._initialize_schema()
        thread_id, user_id = self._validate_config(config)

        logger.debug("Cleaning thread thread_id=%s, user_id=%s", thread_id, user_id)

        try:
            # Delete thread (cascade will handle messages and states) with retry logic
            async def _clean_thread():
                async with (await self._get_pg_pool()).acquire() as conn:
                    await conn.execute(
                        f"DELETE FROM {self._get_table_name('threads')} "  # noqa: S608
                        f"WHERE thread_id = $1 AND user_id = $2",
                        thread_id,
                        user_id,
                    )

            await self._retry_on_connection_error(_clean_thread, max_retries=3)

            # Clean from Redis cache
            cache_key = self._get_thread_key(thread_id, user_id)
            await self.redis.delete(cache_key)

            logger.debug("Thread cleaned: thread_id=%s, user_id=%s", thread_id, user_id)

        except Exception as e:
            logger.error("Failed to clean thread thread_id=%s: %s", thread_id, e)
            raise

    ###########################
    #### RESOURCE CLEANUP #####
    ###########################

    async def arelease(self) -> Any | None:
        """
        Clean up connections and resources.

        Returns:
            Any | None: None
        """
        """Clean up connections and resources."""
        logger.info("Releasing PgCheckpointer resources")

        if not self.release_resources:
            logger.info("No resources to release")
            return

        errors = []

        # Close Redis connection
        try:
            if hasattr(self.redis, "aclose"):
                await self.redis.aclose()
            elif hasattr(self.redis, "close"):
                await self.redis.close()
            logger.debug("Redis connection closed")
        except Exception as e:
            logger.error("Error closing Redis connection: %s", e)
            errors.append(f"Redis: {e}")

        # Close PostgreSQL pool
        try:
            if self._pg_pool and not self._pg_pool.is_closing():
                await self._pg_pool.close()
            logger.debug("PostgreSQL pool closed")
        except Exception as e:
            logger.error("Error closing PostgreSQL pool: %s", e)
            errors.append(f"PostgreSQL: {e}")

        if errors:
            error_msg = f"Errors during resource cleanup: {'; '.join(errors)}"
            logger.warning(error_msg)
            # Don't raise - cleanup should be best effort
        else:
            logger.info("All resources released successfully")
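The release logic above closes each connection independently, collecting errors rather than raising, so one failing backend cannot prevent the others from being cleaned up. A minimal standalone sketch of that pattern (the `FakeConn` class is a stand-in for illustration, not part of PyAgenity):

```python
class FakeConn:
    """Stand-in connection whose close() may fail (illustration only)."""

    def __init__(self, name: str, fail: bool = False):
        self.name = name
        self.fail = fail
        self.closed = False

    def close(self) -> None:
        if self.fail:
            raise RuntimeError(f"{self.name} refused to close")
        self.closed = True


def release_all(conns: list[FakeConn]) -> list[str]:
    """Close every connection, collecting errors instead of raising."""
    errors: list[str] = []
    for conn in conns:
        try:
            conn.close()
        except Exception as e:
            errors.append(f"{conn.name}: {e}")
    return errors


errors = release_all([FakeConn("redis"), FakeConn("postgres", fail=True)])
print(errors)  # ['postgres: postgres refused to close']
```

As in `arelease`, the caller can log the aggregated errors but still treat cleanup as best effort.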
Attributes
cache_ttl instance-attribute
cache_ttl = get('cache_ttl', DEFAULT_CACHE_TTL)
id_type instance-attribute
id_type = get('id_type', try_get('generated_id_type', 'string'))
redis instance-attribute
redis = _create_redis_pool(redis, redis_pool, redis_url, redis_pool_config or {})
release_resources instance-attribute
release_resources = get('release_resources', False)
schema instance-attribute
schema = schema
user_id_type instance-attribute
user_id_type = get('user_id_type', 'string')
Functions
__init__
__init__(postgres_dsn=None, pg_pool=None, pool_config=None, redis_url=None, redis=None, redis_pool=None, redis_pool_config=None, schema='public', **kwargs)

Initializes PgCheckpointer with PostgreSQL and Redis connections.

Parameters:

Name Type Description Default
postgres_dsn str

PostgreSQL connection string.

None
pg_pool Any

Existing asyncpg Pool instance.

None
pool_config dict

Configuration for new pg pool creation.

None
redis_url str

Redis connection URL.

None
redis Any

Existing Redis instance.

None
redis_pool Any

Existing Redis ConnectionPool.

None
redis_pool_config dict

Configuration for new redis pool creation.

None
schema str

PostgreSQL schema name. Defaults to "public".

'public'
**kwargs

Additional configuration options.

{}

Raises:

Type Description
ImportError

If required dependencies are missing.

ValueError

If required connection details are missing.

Source code in pyagenity/checkpointer/pg_checkpointer.py
def __init__(
    self,
    # postgres connection details
    postgres_dsn: str | None = None,
    pg_pool: Any | None = None,
    pool_config: dict | None = None,
    # redis connection details
    redis_url: str | None = None,
    redis: Any | None = None,
    redis_pool: Any | None = None,
    redis_pool_config: dict | None = None,
    # database schema
    schema: str = "public",
    # other configurations - combine to reduce args
    **kwargs,
):
    """
    Initializes PgCheckpointer with PostgreSQL and Redis connections.

    Args:
        postgres_dsn (str, optional): PostgreSQL connection string.
        pg_pool (Any, optional): Existing asyncpg Pool instance.
        pool_config (dict, optional): Configuration for new pg pool creation.
        redis_url (str, optional): Redis connection URL.
        redis (Any, optional): Existing Redis instance.
        redis_pool (Any, optional): Existing Redis ConnectionPool.
        redis_pool_config (dict, optional): Configuration for new redis pool creation.
        schema (str, optional): PostgreSQL schema name. Defaults to "public".
        **kwargs: Additional configuration options.

    Raises:
        ImportError: If required dependencies are missing.
        ValueError: If required connection details are missing.
    """
    # Check for required dependencies
    if not HAS_ASYNCPG:
        raise ImportError(
            "PgCheckpointer requires 'asyncpg' package. "
            "Install with: pip install pyagenity[pg_checkpoint]"
        )

    if not HAS_REDIS:
        raise ImportError(
            "PgCheckpointer requires 'redis' package. "
            "Install with: pip install pyagenity[pg_checkpoint]"
        )

    self.user_id_type = kwargs.get("user_id_type", "string")
    # allow explicit override via kwargs, fallback to InjectQ, then default
    self.id_type = kwargs.get(
        "id_type", InjectQ.get_instance().try_get("generated_id_type", "string")
    )
    self.cache_ttl = kwargs.get("cache_ttl", DEFAULT_CACHE_TTL)
    self.release_resources = kwargs.get("release_resources", False)

    # Validate schema name to prevent SQL injection
    if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", schema):
        raise ValueError(
            f"Invalid schema name: {schema}. Schema must match pattern ^[a-zA-Z_][a-zA-Z0-9_]*$"
        )
    self.schema = schema

    self._schema_initialized = False
    self._loop: asyncio.AbstractEventLoop | None = None

    # Store pool configuration for lazy initialization
    self._pg_pool_config = {
        "pg_pool": pg_pool,
        "postgres_dsn": postgres_dsn,
        "pool_config": pool_config or {},
    }

    # Initialize pool immediately if provided, otherwise defer
    if pg_pool is not None:
        self._pg_pool = pg_pool
    else:
        self._pg_pool = None

    # Now check and initialize connections
    if not pg_pool and not postgres_dsn:
        raise ValueError("Either postgres_dsn or pg_pool must be provided.")

    if not redis and not redis_url and not redis_pool:
        raise ValueError("Either redis_url, redis_pool or redis instance must be provided.")

    # Initialize Redis connection (synchronous)
    self.redis = self._create_redis_pool(redis, redis_pool, redis_url, redis_pool_config or {})
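Because the schema name is interpolated directly into SQL, the constructor rejects anything that does not match `^[a-zA-Z_][a-zA-Z0-9_]*$` before it can reach a query. The same check can be exercised in isolation (the `validate_schema` helper here is illustrative, not a PyAgenity API):

```python
import re

SCHEMA_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")


def validate_schema(schema: str) -> str:
    """Reject schema names that could carry SQL injection payloads."""
    if not SCHEMA_RE.match(schema):
        raise ValueError(f"Invalid schema name: {schema}")
    return schema


print(validate_schema("public"))    # public
print(validate_schema("agent_v2"))  # agent_v2
try:
    validate_schema("public; DROP TABLE threads")
except ValueError as e:
    print(e)
```

This is why `schema` is validated eagerly in `__init__` rather than at query time: a bad value fails fast, before any pool is created.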
aclean_thread async
aclean_thread(config)

Clean/delete a thread and all associated data.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
Any | None

Any | None: None

Raises:

Type Description
Exception

If cleaning fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aclean_thread(self, config: dict[str, Any]) -> Any | None:
    """
    Clean/delete a thread and all associated data.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any | None: None

    Raises:
        Exception: If cleaning fails.
    """
    """Clean/delete a thread and all associated data."""
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    logger.debug("Cleaning thread thread_id=%s, user_id=%s", thread_id, user_id)

    try:
        # Delete thread (cascade will handle messages and states) with retry logic
        async def _clean_thread():
            async with (await self._get_pg_pool()).acquire() as conn:
                await conn.execute(
                    f"DELETE FROM {self._get_table_name('threads')} "  # noqa: S608
                    f"WHERE thread_id = $1 AND user_id = $2",
                    thread_id,
                    user_id,
                )

        await self._retry_on_connection_error(_clean_thread, max_retries=3)

        # Clean from Redis cache
        cache_key = self._get_thread_key(thread_id, user_id)
        await self.redis.delete(cache_key)

        logger.debug("Thread cleaned: thread_id=%s, user_id=%s", thread_id, user_id)

    except Exception as e:
        logger.error("Failed to clean thread thread_id=%s: %s", thread_id, e)
        raise
aclear_state async
aclear_state(config)

Clear state from PostgreSQL and Redis cache.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
Any Any

None

Raises:

Type Description
Exception

If clearing fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aclear_state(self, config: dict[str, Any]) -> Any:
    """
    Clear state from PostgreSQL and Redis cache.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any: None

    Raises:
        Exception: If clearing fails.
    """
    """Clear state from PostgreSQL and Redis cache."""
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    logger.debug("Clearing state for thread_id=%s, user_id=%s", thread_id, user_id)

    try:
        # Clear from PostgreSQL with retry logic
        async def _clear_state():
            async with (await self._get_pg_pool()).acquire() as conn:
                await conn.execute(
                    f"DELETE FROM {self._get_table_name('states')} WHERE thread_id = $1",  # noqa: S608
                    thread_id,
                )

        await self._retry_on_connection_error(_clear_state, max_retries=3)

        # Clear from Redis cache
        cache_key = self._get_thread_key(thread_id, user_id)
        await self.redis.delete(cache_key)

        logger.debug("State cleared for thread_id=%s", thread_id)

    except Exception as e:
        logger.error("Failed to clear state for thread_id=%s: %s", thread_id, e)
        raise
adelete_message async
adelete_message(config, message_id)

Delete a message by ID.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Type Description
Any | None

Any | None: None

Raises:

Type Description
Exception

If deletion fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def adelete_message(
    self,
    config: dict[str, Any],
    message_id: str | int,
) -> Any | None:
    """
    Delete a message by ID.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Any | None: None

    Raises:
        Exception: If deletion fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id = config.get("thread_id")

    logger.debug("Deleting message_id=%s for thread_id=%s", message_id, thread_id)

    try:

        async def _delete_message():
            async with (await self._get_pg_pool()).acquire() as conn:
                if thread_id:
                    await conn.execute(
                        f"DELETE FROM {self._get_table_name('messages')} "  # noqa: S608
                        f"WHERE message_id = $1 AND thread_id = $2",
                        message_id,
                        thread_id,
                    )
                else:
                    await conn.execute(
                        f"DELETE FROM {self._get_table_name('messages')} WHERE message_id = $1",  # noqa: S608
                        message_id,
                    )

        await self._retry_on_connection_error(_delete_message, max_retries=3)
        logger.debug("Deleted message_id=%s", message_id)
        return None

    except Exception as e:
        logger.error("Failed to delete message_id=%s: %s", message_id, e)
        raise
aget_message async
aget_message(config, message_id)

Retrieve a single message by ID.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
Message Message

Retrieved message object.

Raises:

Type Description
Exception

If retrieval fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aget_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a single message by ID.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.

    Raises:
        Exception: If retrieval fails.
    """
    """Retrieve a single message by ID."""
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id = config.get("thread_id")

    logger.debug("Retrieving message_id=%s for thread_id=%s", message_id, thread_id)

    try:

        async def _get_message():
            async with (await self._get_pg_pool()).acquire() as conn:
                query = f"""
                    SELECT message_id, thread_id, role, content, tool_calls,
                           tool_call_id, reasoning, created_at, total_tokens,
                           usages, meta
                    FROM {self._get_table_name("messages")}
                    WHERE message_id = $1
                """  # noqa: S608
                if thread_id:
                    query += " AND thread_id = $2"
                    return await conn.fetchrow(query, message_id, thread_id)
                return await conn.fetchrow(query, message_id)

        row = await self._retry_on_connection_error(_get_message, max_retries=3)

        if not row:
            raise ValueError(f"Message not found: {message_id}")

        return self._row_to_message(row)

    except Exception as e:
        logger.error("Failed to retrieve message_id=%s: %s", message_id, e)
        raise
aget_state async
aget_state(config)

Retrieve state from PostgreSQL.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Retrieved state or None.

Raises:

Type Description
Exception

If retrieval fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aget_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve state from PostgreSQL.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.

    Raises:
        Exception: If retrieval fails.
    """
    """Retrieve state from PostgreSQL."""
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)
    state_class = config.get("state_class", AgentState)

    logger.debug("Retrieving state for thread_id=%s, user_id=%s", thread_id, user_id)

    try:

        async def _get_state():
            async with (await self._get_pg_pool()).acquire() as conn:
                return await conn.fetchrow(
                    f"""
                    SELECT state_data FROM {self._get_table_name("states")}
                    WHERE thread_id = $1
                    ORDER BY created_at DESC
                    LIMIT 1
                    """,  # noqa: S608
                    thread_id,
                )

        row = await self._retry_on_connection_error(_get_state, max_retries=3)

        if row:
            logger.debug("State found for thread_id=%s", thread_id)
            return self._deserialize_state(row["state_data"], state_class)

        logger.debug("No state found for thread_id=%s", thread_id)
        return None

    except Exception as e:
        logger.error("Failed to retrieve state for thread_id=%s: %s", thread_id, e)
        raise
aget_state_cache async
aget_state_cache(config)

Get state from Redis cache, fallback to PostgreSQL if miss.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: State object or None.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aget_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Get state from Redis cache, fallback to PostgreSQL if miss.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: State object or None.
    """
    """Get state from Redis cache, fallback to PostgreSQL if miss."""
    # Schema might be needed if we fall back to DB
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)
    state_class = config.get("state_class", AgentState)

    logger.debug("Getting cached state for thread_id=%s, user_id=%s", thread_id, user_id)

    try:
        # Try Redis first
        cache_key = self._get_thread_key(thread_id, user_id)
        cached_data = await self.redis.get(cache_key)

        if cached_data:
            logger.debug("Cache hit for thread_id=%s", thread_id)
            return self._deserialize_state(cached_data.decode(), state_class)

        # Cache miss - fallback to PostgreSQL
        logger.debug("Cache miss for thread_id=%s, falling back to PostgreSQL", thread_id)
        state = await self.aget_state(config)

        # Cache the result for next time
        if state:
            await self.aput_state_cache(config, state)

        return state

    except Exception as e:
        logger.error("Failed to get cached state for thread_id=%s: %s", thread_id, e)
        # Fallback to PostgreSQL on error
        return await self.aget_state(config)
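`aget_state_cache` implements a read-through (cache-aside) policy: try Redis, fall back to PostgreSQL on a miss, and refill the cache with the result. The same flow can be sketched with plain dicts standing in for both stores (the `CacheAsideStore` class is illustrative, not part of PyAgenity):

```python
class CacheAsideStore:
    """Read-through cache: try cache, fall back to backing store, refill."""

    def __init__(self, backing: dict):
        self.backing = backing  # stands in for PostgreSQL
        self.cache: dict = {}   # stands in for Redis
        self.hits = 0
        self.misses = 0

    def get(self, key: str):
        if key in self.cache:
            self.hits += 1
            return self.cache[key]
        self.misses += 1
        value = self.backing.get(key)
        if value is not None:
            self.cache[key] = value  # populate cache for the next read
        return value


store = CacheAsideStore({"thread:42:u1": {"step": 3}})
print(store.get("thread:42:u1"))  # miss: served from backing store, then cached
print(store.get("thread:42:u1"))  # hit: served from cache
print(store.hits, store.misses)   # 1 1
```

Note that the real method also swallows Redis errors and falls back to `aget_state`, so a cache outage degrades to slower reads rather than failures.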
aget_thread async
aget_thread(config)

Get thread information.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
ThreadInfo | None

ThreadInfo | None: Thread information object or None.

Raises:

Type Description
Exception

If retrieval fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aget_thread(
    self,
    config: dict[str, Any],
) -> ThreadInfo | None:
    """
    Get thread information.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.

    Raises:
        Exception: If retrieval fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    logger.debug("Retrieving thread info for thread_id=%s, user_id=%s", thread_id, user_id)

    try:

        async def _get_thread():
            async with (await self._get_pg_pool()).acquire() as conn:
                return await conn.fetchrow(
                    f"""
                    SELECT thread_id, thread_name, user_id, created_at, updated_at, meta
                    FROM {self._get_table_name("threads")}
                    WHERE thread_id = $1 AND user_id = $2
                    """,  # noqa: S608
                    thread_id,
                    user_id,
                )

        row = await self._retry_on_connection_error(_get_thread, max_retries=3)

        if row:
            meta_dict = {}
            if row["meta"]:
                meta_dict = (
                    json.loads(row["meta"]) if isinstance(row["meta"], str) else row["meta"]
                )
            return ThreadInfo(
                thread_id=thread_id,
                thread_name=row["thread_name"] if row else None,
                user_id=user_id,
                metadata=meta_dict,
                run_id=meta_dict.get("run_id"),
            )

        logger.debug("Thread not found for thread_id=%s, user_id=%s", thread_id, user_id)
        return None

    except Exception as e:
        logger.error("Failed to retrieve thread info for thread_id=%s: %s", thread_id, e)
        raise
alist_messages async
alist_messages(config, search=None, offset=None, limit=None)

List messages for a thread with optional search and pagination.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[Message]

list[Message]: List of message objects.

Raises:

Type Description
Exception

If listing fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def alist_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages for a thread with optional search and pagination.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.

    Raises:
        Exception: If listing fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id = config.get("thread_id")

    if not thread_id:
        raise ValueError("thread_id must be provided in config")

    logger.debug("Listing messages for thread_id=%s", thread_id)

    try:

        async def _list_messages():
            async with (await self._get_pg_pool()).acquire() as conn:
                # Build query with optional search
                query = f"""
                    SELECT message_id, thread_id, role, content, tool_calls,
                           tool_call_id, reasoning, created_at, total_tokens,
                           usages, meta
                    FROM {self._get_table_name("messages")}
                    WHERE thread_id = $1
                """  # noqa: S608
                params = [thread_id]
                param_count = 1

                if search:
                    param_count += 1
                    query += f" AND content ILIKE ${param_count}"
                    params.append(f"%{search}%")

                query += " ORDER BY created_at ASC"

                if limit:
                    param_count += 1
                    query += f" LIMIT ${param_count}"
                    params.append(limit)

                if offset:
                    param_count += 1
                    query += f" OFFSET ${param_count}"
                    params.append(offset)

                return await conn.fetch(query, *params)

        rows = await self._retry_on_connection_error(_list_messages, max_retries=3)
        if not rows:
            rows = []
        messages = [self._row_to_message(row) for row in rows]

        logger.debug("Found %d messages for thread_id=%s", len(messages), thread_id)
        return messages

    except Exception as e:
        logger.error("Failed to list messages for thread_id=%s: %s", thread_id, e)
        raise
alist_threads async
alist_threads(config, search=None, offset=None, limit=None)

List threads for a user with optional search and pagination.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[ThreadInfo]

list[ThreadInfo]: List of thread information objects.

Raises:

Type Description
Exception

If listing fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def alist_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List threads for a user with optional search and pagination.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.

    Raises:
        Exception: If listing fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    user_id = config.get("user_id")
    user_id = user_id or "test-user"

    if not user_id:
        raise ValueError("user_id must be provided in config")

    logger.debug("Listing threads for user_id=%s", user_id)

    try:

        async def _list_threads():
            async with (await self._get_pg_pool()).acquire() as conn:
                # Build query with optional search
                query = f"""
                    SELECT thread_id, thread_name, user_id, created_at, updated_at, meta
                    FROM {self._get_table_name("threads")}
                    WHERE user_id = $1
                """  # noqa: S608
                params = [user_id]
                param_count = 1

                if search:
                    param_count += 1
                    query += f" AND thread_name ILIKE ${param_count}"
                    params.append(f"%{search}%")

                query += " ORDER BY updated_at DESC"

                if limit:
                    param_count += 1
                    query += f" LIMIT ${param_count}"
                    params.append(limit)

                if offset:
                    param_count += 1
                    query += f" OFFSET ${param_count}"
                    params.append(offset)

                return await conn.fetch(query, *params)

        rows = await self._retry_on_connection_error(_list_threads, max_retries=3)
        if not rows:
            rows = []

        threads = []
        for row in rows:
            meta_dict = {}
            if row["meta"]:
                meta_dict = (
                    json.loads(row["meta"]) if isinstance(row["meta"], str) else row["meta"]
                )
            threads.append(
                ThreadInfo(
                    thread_id=row["thread_id"],
                    thread_name=row["thread_name"],
                    user_id=row["user_id"],
                    metadata=meta_dict,
                    run_id=meta_dict.get("run_id"),
                    updated_at=row["updated_at"],
                )
            )
        logger.debug("Found %d threads for user_id=%s", len(threads), user_id)
        return threads

    except Exception as e:
        logger.error("Failed to list threads for user_id=%s: %s", user_id, e)
        raise
aput_messages async
aput_messages(config, messages, metadata=None)

Store messages in PostgreSQL.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
messages list[Message]

List of messages to store.

required
metadata dict

Additional metadata.

None

Returns:

Name Type Description
Any Any

None

Raises:

Type Description
Exception

If storing fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aput_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> Any:
    """
    Store messages in PostgreSQL.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        Any: None

    Raises:
        Exception: If storing fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    if not messages:
        logger.debug("No messages to store for thread_id=%s", thread_id)
        return

    logger.debug("Storing %d messages for thread_id=%s", len(messages), thread_id)

    try:
        # Ensure thread exists
        await self._ensure_thread_exists(thread_id, user_id, config)

        # Store messages in batch with retry logic
        async def _store_messages():
            async with (await self._get_pg_pool()).acquire() as conn, conn.transaction():
                for message in messages:
                    await conn.execute(
                        f"""
                            INSERT INTO {self._get_table_name("messages")} (
                                message_id, thread_id, role, content, tool_calls,
                                tool_call_id, reasoning, total_tokens, usages, meta
                            )
                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                            ON CONFLICT (message_id) DO UPDATE SET
                                content = EXCLUDED.content,
                                reasoning = EXCLUDED.reasoning,
                                usages = EXCLUDED.usages,
                                updated_at = NOW()
                            """,  # noqa: S608
                        message.message_id,
                        thread_id,
                        message.role,
                        json.dumps(
                            [block.model_dump(mode="json") for block in message.content]
                        ),
                        json.dumps(message.tools_calls) if message.tools_calls else None,
                        getattr(message, "tool_call_id", None),
                        message.reasoning,
                        message.usages.total_tokens if message.usages else 0,
                        json.dumps(message.usages.model_dump()) if message.usages else None,
                        json.dumps({**(metadata or {}), **(message.metadata or {})}),
                    )

        await self._retry_on_connection_error(_store_messages, max_retries=3)
        logger.debug("Stored %d messages for thread_id=%s", len(messages), thread_id)

    except Exception as e:
        logger.error("Failed to store messages for thread_id=%s: %s", thread_id, e)
        raise
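The commented-out block above shows an earlier fallback strategy for non-string content; the current code always JSON-encodes the content blocks. The fallback pattern itself, in isolation (a sketch, not the library's current code path):

```python
import json


def serialize_content(content):
    """Serialize message content: pass strings through, JSON-encode the rest,
    falling back to str() for values json cannot handle."""
    if isinstance(content, str):
        return content
    try:
        return json.dumps(content)
    except TypeError:
        return str(content)


as_is = serialize_content("hello")
encoded = serialize_content([{"type": "text", "text": "hi"}])
fallback = serialize_content({1, 2})  # sets are not JSON-serializable
```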
aput_state async
aput_state(config, state)

Store state in PostgreSQL and optionally cache in Redis.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Raises:

Type Description
StorageError

If storing fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aput_state(
    self,
    config: dict[str, Any],
    state: StateT,
) -> StateT:
    """
    Store state in PostgreSQL and optionally cache in Redis.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.

    Raises:
        StorageError: If storing fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    logger.debug("Storing state for thread_id=%s, user_id=%s", thread_id, user_id)
    metrics.counter("pg_checkpointer.save_state.attempts").inc()

    with metrics.timer("pg_checkpointer.save_state.duration"):
        try:
            # Ensure thread exists first
            await self._ensure_thread_exists(thread_id, user_id, config)

            # Store in PostgreSQL with retry logic
            state_json = self._serialize_state_fast(state)

            async def _store_state():
                async with (await self._get_pg_pool()).acquire() as conn:
                    await conn.execute(
                        f"""
                        INSERT INTO {self._get_table_name("states")}
                            (thread_id, state_data, meta)
                        VALUES ($1, $2, $3)
                        ON CONFLICT DO NOTHING
                        """,  # noqa: S608
                        thread_id,
                        state_json,
                        json.dumps(config.get("meta", {})),
                    )

            await self._retry_on_connection_error(_store_state, max_retries=3)
            logger.debug("State stored successfully for thread_id=%s", thread_id)
            metrics.counter("pg_checkpointer.save_state.success").inc()
            return state

        except Exception as e:
            metrics.counter("pg_checkpointer.save_state.error").inc()
            logger.error("Failed to store state for thread_id=%s: %s", thread_id, e)
            if asyncpg and hasattr(asyncpg, "ConnectionDoesNotExistError"):
                connection_errors = (
                    asyncpg.ConnectionDoesNotExistError,
                    asyncpg.InterfaceError,
                )
                if isinstance(e, connection_errors):
                    raise TransientStorageError(f"Connection issue storing state: {e}") from e
            raise StorageError(f"Failed to store state: {e}") from e
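The `_retry_on_connection_error` helper used above is internal to `PgCheckpointer`. A minimal sketch of the underlying retry-with-backoff pattern (names, delays, and the error class are illustrative, not the library's actual implementation):

```python
import asyncio


class TransientError(Exception):
    """Stand-in for a retryable connection error."""


async def retry_on_connection_error(op, max_retries=3, base_delay=0.01):
    """Retry an async operation on transient errors with exponential backoff."""
    for attempt in range(max_retries):
        try:
            return await op()
        except TransientError:
            if attempt == max_retries - 1:
                raise  # out of retries: propagate the last failure
            await asyncio.sleep(base_delay * 2**attempt)


# Simulate an operation that fails twice, then succeeds.
calls = {"n": 0}


async def flaky_store():
    calls["n"] += 1
    if calls["n"] < 3:
        raise TransientError("connection dropped")
    return "stored"


result = asyncio.run(retry_on_connection_error(flaky_store, max_retries=3))
```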
aput_state_cache async
aput_state_cache(config, state)

Cache state in Redis with TTL.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Type Description
Any | None

Any | None: True if cached, None if failed.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aput_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
    """
    Cache state in Redis with TTL.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        Any | None: True if cached, None if failed.
    """
    """Cache state in Redis with TTL."""
    # No DB access, but keep consistent
    thread_id, user_id = self._validate_config(config)

    logger.debug("Caching state for thread_id=%s, user_id=%s", thread_id, user_id)

    try:
        cache_key = self._get_thread_key(thread_id, user_id)
        state_json = self._serialize_state(state)
        await self.redis.setex(cache_key, self.cache_ttl, state_json)
        logger.debug("State cached with key=%s, ttl=%d", cache_key, self.cache_ttl)
        return True

    except Exception as e:
        logger.error("Failed to cache state for thread_id=%s: %s", thread_id, e)
        # Don't raise - caching is optional
        return None
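`setex` stores a value together with an expiry. A minimal in-memory sketch of the same set-with-TTL semantics (purely illustrative; the method above uses a real Redis client):

```python
import time


class TTLCache:
    """Minimal setex-style cache: values expire after ttl seconds."""

    def __init__(self):
        self._data = {}  # key -> (expires_at, value)

    def setex(self, key, ttl, value):
        self._data[key] = (time.monotonic() + ttl, value)
        return True

    def get(self, key):
        entry = self._data.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if time.monotonic() >= expires_at:
            del self._data[key]  # lazy eviction on read
            return None
        return value


cache = TTLCache()
cache.setex("thread:1:user:42", 0.05, '{"step": 3}')
fresh = cache.get("thread:1:user:42")    # read within TTL
time.sleep(0.06)
expired = cache.get("thread:1:user:42")  # read past TTL
```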
aput_thread async
aput_thread(config, thread_info)

Create or update thread information.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Type Description
Any | None

Any | None: None

Raises:

Type Description
Exception

If storing fails.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def aput_thread(
    self,
    config: dict[str, Any],
    thread_info: ThreadInfo,
) -> Any | None:
    """
    Create or update thread information.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        Any | None: None

    Raises:
        Exception: If storing fails.
    """
    # Ensure schema is initialized before accessing tables
    await self._initialize_schema()
    thread_id, user_id = self._validate_config(config)

    logger.debug("Storing thread info for thread_id=%s, user_id=%s", thread_id, user_id)

    try:
        thread_name = thread_info.thread_name or f"Thread {thread_id}"
        meta = thread_info.metadata or {}
        user_id = thread_info.user_id or user_id
        meta.update(
            {
                "run_id": thread_info.run_id,
            }
        )

        async def _put_thread():
            async with (await self._get_pg_pool()).acquire() as conn:
                await conn.execute(
                    f"""
                    INSERT INTO {self._get_table_name("threads")}
                        (thread_id, thread_name, user_id, meta)
                    VALUES ($1, $2, $3, $4)
                    ON CONFLICT (thread_id) DO UPDATE SET
                        thread_name = EXCLUDED.thread_name,
                        meta = EXCLUDED.meta,
                        updated_at = NOW()
                    """,  # noqa: S608
                    thread_id,
                    thread_name,
                    user_id,
                    json.dumps(meta),
                )

        await self._retry_on_connection_error(_put_thread, max_retries=3)
        logger.debug("Thread info stored for thread_id=%s", thread_id)

    except Exception as e:
        logger.error("Failed to store thread info for thread_id=%s: %s", thread_id, e)
        raise
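The `ON CONFLICT (thread_id) DO UPDATE` clause makes the insert an upsert: a second write for the same thread updates the existing row instead of failing. The same semantics sketched with SQLite (illustrative only; the checkpointer itself targets PostgreSQL via asyncpg):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE threads (thread_id TEXT PRIMARY KEY, thread_name TEXT)")

upsert = """
    INSERT INTO threads (thread_id, thread_name) VALUES (?, ?)
    ON CONFLICT (thread_id) DO UPDATE SET thread_name = excluded.thread_name
"""
conn.execute(upsert, ("t-1", "Thread t-1"))
conn.execute(upsert, ("t-1", "Renamed thread"))  # same key: updates, no error

rows = conn.execute("SELECT thread_id, thread_name FROM threads").fetchall()
```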
arelease async
arelease()

Clean up connections and resources.

Returns:

Type Description
Any | None

Any | None: None

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def arelease(self) -> Any | None:
    """
    Clean up connections and resources.

    Returns:
        Any | None: None
    """
    """Clean up connections and resources."""
    logger.info("Releasing PgCheckpointer resources")

    if not self.release_resources:
        logger.info("No resources to release")
        return

    errors = []

    # Close Redis connection
    try:
        if hasattr(self.redis, "aclose"):
            await self.redis.aclose()
        elif hasattr(self.redis, "close"):
            await self.redis.close()
        logger.debug("Redis connection closed")
    except Exception as e:
        logger.error("Error closing Redis connection: %s", e)
        errors.append(f"Redis: {e}")

    # Close PostgreSQL pool
    try:
        if self._pg_pool and not self._pg_pool.is_closing():
            await self._pg_pool.close()
        logger.debug("PostgreSQL pool closed")
    except Exception as e:
        logger.error("Error closing PostgreSQL pool: %s", e)
        errors.append(f"PostgreSQL: {e}")

    if errors:
        error_msg = f"Errors during resource cleanup: {'; '.join(errors)}"
        logger.warning(error_msg)
        # Don't raise - cleanup should be best effort
    else:
        logger.info("All resources released successfully")
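The method releases each resource independently and aggregates errors rather than raising, so one failing close cannot block the others. The pattern in isolation (the resource classes here are stand-ins):

```python
class FailingResource:
    def close(self):
        raise RuntimeError("socket already closed")


class GoodResource:
    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def release_all(resources):
    """Close every resource; collect errors instead of raising (best effort)."""
    errors = []
    for name, res in resources.items():
        try:
            res.close()
        except Exception as e:
            errors.append(f"{name}: {e}")
    return errors


good = GoodResource()
errors = release_all({"redis": FailingResource(), "postgres": good})
```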
asetup async
asetup()

Asynchronous setup method. Initializes database schema.

Returns:

Name Type Description
Any Any

True if setup completed.

Source code in pyagenity/checkpointer/pg_checkpointer.py
async def asetup(self) -> Any:
    """
    Asynchronous setup method. Initializes database schema.

    Returns:
        Any: True if setup completed.
    """
    """Async setup method - initializes database schema."""
    logger.info(
        "Setting up PgCheckpointer (async)",
        extra={
            "id_type": self.id_type,
            "user_id_type": self.user_id_type,
            "schema": self.schema,
        },
    )
    await self._initialize_schema()
    logger.info("PgCheckpointer setup completed")
    return True
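`_initialize_schema` is expected to be idempotent so that `asetup` can safely be called more than once. The usual SQL idiom is `CREATE TABLE IF NOT EXISTS`, sketched here with SQLite (a stand-in; the real checkpointer issues the equivalent against PostgreSQL):

```python
import sqlite3

conn = sqlite3.connect(":memory:")


def initialize_schema(conn):
    """Create tables if missing; safe to call repeatedly."""
    conn.execute(
        "CREATE TABLE IF NOT EXISTS states ("
        " thread_id TEXT PRIMARY KEY,"
        " state_data TEXT,"
        " meta TEXT)"
    )
    return True


first = initialize_schema(conn)
second = initialize_schema(conn)  # repeat call: no error
tables = [
    r[0] for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
]
```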
clean_thread
clean_thread(config)

Clean/delete thread synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def clean_thread(self, config: dict[str, Any]) -> Any | None:
    """
    Clean/delete thread synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aclean_thread(config))
clear_state
clear_state(config)

Clear agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def clear_state(self, config: dict[str, Any]) -> Any:
    """
    Clear agent state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Any: Implementation-defined result.
    """
    return run_coroutine(self.aclear_state(config))
delete_message
delete_message(config, message_id)

Delete a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def delete_message(self, config: dict[str, Any], message_id: str | int) -> Any | None:
    """
    Delete a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.adelete_message(config, message_id))
get_message
get_message(config, message_id)

Retrieve a specific message synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
message_id str | int

Message identifier.

required

Returns:

Name Type Description
Message Message

Retrieved message object.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_message(self, config: dict[str, Any], message_id: str | int) -> Message:
    """
    Retrieve a specific message synchronously.

    Args:
        config (dict): Configuration dictionary.
        message_id (str|int): Message identifier.

    Returns:
        Message: Retrieved message object.
    """
    return run_coroutine(self.aget_message(config, message_id))
get_state
get_state(config)

Retrieve agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Retrieved state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_state(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Retrieved state or None.
    """
    return run_coroutine(self.aget_state(config))
get_state_cache
get_state_cache(config)

Retrieve agent state from cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
StateT | None

StateT | None: Cached state or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_state_cache(self, config: dict[str, Any]) -> StateT | None:
    """
    Retrieve agent state from cache synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        StateT | None: Cached state or None.
    """
    return run_coroutine(self.aget_state_cache(config))
get_thread
get_thread(config)

Retrieve thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required

Returns:

Type Description
ThreadInfo | None

ThreadInfo | None: Thread information object or None.

Source code in pyagenity/checkpointer/base_checkpointer.py
def get_thread(self, config: dict[str, Any]) -> ThreadInfo | None:
    """
    Retrieve thread info synchronously.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        ThreadInfo | None: Thread information object or None.
    """
    return run_coroutine(self.aget_thread(config))
list_messages
list_messages(config, search=None, offset=None, limit=None)

List messages synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[Message]

list[Message]: List of message objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
def list_messages(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[Message]:
    """
    List messages synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[Message]: List of message objects.
    """
    return run_coroutine(self.alist_messages(config, search, offset, limit))
list_threads
list_threads(config, search=None, offset=None, limit=None)

List threads synchronously with optional filtering.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
search str

Search string.

None
offset int

Offset for pagination.

None
limit int

Limit for pagination.

None

Returns:

Type Description
list[ThreadInfo]

list[ThreadInfo]: List of thread information objects.

Source code in pyagenity/checkpointer/base_checkpointer.py
def list_threads(
    self,
    config: dict[str, Any],
    search: str | None = None,
    offset: int | None = None,
    limit: int | None = None,
) -> list[ThreadInfo]:
    """
    List threads synchronously with optional filtering.

    Args:
        config (dict): Configuration dictionary.
        search (str, optional): Search string.
        offset (int, optional): Offset for pagination.
        limit (int, optional): Limit for pagination.

    Returns:
        list[ThreadInfo]: List of thread information objects.
    """
    return run_coroutine(self.alist_threads(config, search, offset, limit))
put_messages
put_messages(config, messages, metadata=None)

Store messages synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
messages list[Message]

List of messages to store.

required
metadata dict

Additional metadata.

None

Returns:

Name Type Description
Any Any

Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_messages(
    self,
    config: dict[str, Any],
    messages: list[Message],
    metadata: dict[str, Any] | None = None,
) -> Any:
    """
    Store messages synchronously.

    Args:
        config (dict): Configuration dictionary.
        messages (list[Message]): List of messages to store.
        metadata (dict, optional): Additional metadata.

    Returns:
        Any: Implementation-defined result.
    """
    return run_coroutine(self.aput_messages(config, messages, metadata))
put_state
put_state(config, state)

Store agent state synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to store.

required

Returns:

Name Type Description
StateT StateT

The stored state object.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_state(self, config: dict[str, Any], state: StateT) -> StateT:
    """
    Store agent state synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to store.

    Returns:
        StateT: The stored state object.
    """
    return run_coroutine(self.aput_state(config, state))
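Every synchronous method on the base checkpointer delegates to its async counterpart through `run_coroutine`. A minimal sketch of such a sync-to-async bridge (the real helper lives in `pyagenity.utils` and also handles already-running event loops, which this sketch does not):

```python
import asyncio


def run_coroutine(coro):
    """Run a coroutine from synchronous code (sketch: assumes no running loop)."""
    return asyncio.run(coro)


class Checkpointer:
    async def aput_state(self, config, state):
        # A real implementation persists to a backend; here we just echo.
        return state

    def put_state(self, config, state):
        return run_coroutine(self.aput_state(config, state))


cp = Checkpointer()
stored = cp.put_state({"thread_id": "t-1"}, {"step": 1})
```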
put_state_cache
put_state_cache(config, state)

Store agent state in cache synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
state StateT

State object to cache.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_state_cache(self, config: dict[str, Any], state: StateT) -> Any | None:
    """
    Store agent state in cache synchronously.

    Args:
        config (dict): Configuration dictionary.
        state (StateT): State object to cache.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aput_state_cache(config, state))
put_thread
put_thread(config, thread_info)

Store thread info synchronously.

Parameters:

Name Type Description Default
config dict

Configuration dictionary.

required
thread_info ThreadInfo

Thread information object.

required

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def put_thread(self, config: dict[str, Any], thread_info: ThreadInfo) -> Any | None:
    """
    Store thread info synchronously.

    Args:
        config (dict): Configuration dictionary.
        thread_info (ThreadInfo): Thread information object.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.aput_thread(config, thread_info))
release
release()

Release resources synchronously.

Returns:

Type Description
Any | None

Any | None: Implementation-defined result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def release(self) -> Any | None:
    """
    Release resources synchronously.

    Returns:
        Any | None: Implementation-defined result.
    """
    return run_coroutine(self.arelease())
setup
setup()

Synchronous setup method for checkpointer.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/checkpointer/base_checkpointer.py
def setup(self) -> Any:
    """
    Synchronous setup method for checkpointer.

    Returns:
        Any: Implementation-defined setup result.
    """
    return run_coroutine(self.asetup())
Modules

exceptions

Custom exception classes for graph operations in PyAgenity.

This package provides
  • GraphError: Base exception for graph-related errors.
  • NodeError: Exception for node-specific errors.
  • GraphRecursionError: Exception for recursion limit errors in graphs.

Modules:

Name Description
graph_error
node_error
recursion_error
storage_exceptions

Structured exception taxonomy for persistence & runtime layers.

Classes:

Name Description
GraphError

Base exception for graph-related errors.

GraphRecursionError

Exception raised when graph execution exceeds the recursion limit.

NodeError

Exception raised when a node encounters an error.

Attributes

__all__ module-attribute
__all__ = ['GraphError', 'GraphRecursionError', 'NodeError']

Classes

GraphError

Bases: Exception

Base exception for graph-related errors.

This exception is raised when an error related to graph operations occurs.

Example

>>> from pyagenity.exceptions.graph_error import GraphError
>>> raise GraphError("Invalid graph structure")

Methods:

Name Description
__init__

Initializes a GraphError with the given message.

Source code in pyagenity/exceptions/graph_error.py
class GraphError(Exception):
    """
    Base exception for graph-related errors.

    This exception is raised when an error related to graph operations occurs.

    Example:
        >>> from pyagenity.exceptions.graph_error import GraphError
        >>> raise GraphError("Invalid graph structure")
    """

    def __init__(self, message: str):
        """
        Initializes a GraphError with the given message.

        Args:
            message (str): Description of the error.
        """
        logger.error("GraphError raised: %s", message)
        super().__init__(message)
Functions
__init__
__init__(message)

Initializes a GraphError with the given message.

Parameters:

Name Type Description Default
message str

Description of the error.

required
Source code in pyagenity/exceptions/graph_error.py
def __init__(self, message: str):
    """
    Initializes a GraphError with the given message.

    Args:
        message (str): Description of the error.
    """
    logger.error("GraphError raised: %s", message)
    super().__init__(message)
GraphRecursionError

Bases: GraphError

Exception raised when graph execution exceeds the recursion limit.

This exception is used to indicate that a graph operation has recursed too deeply.

Example

>>> from pyagenity.exceptions.recursion_error import GraphRecursionError
>>> raise GraphRecursionError("Recursion limit exceeded in graph execution")

Methods:

Name Description
__init__

Initializes a GraphRecursionError with the given message.

Source code in pyagenity/exceptions/recursion_error.py
class GraphRecursionError(GraphError):
    """
    Exception raised when graph execution exceeds the recursion limit.

    This exception is used to indicate that a graph operation has recursed too deeply.

    Example:
        >>> from pyagenity.exceptions.recursion_error import GraphRecursionError
        >>> raise GraphRecursionError("Recursion limit exceeded in graph execution")
    """

    def __init__(self, message: str):
        """
        Initializes a GraphRecursionError with the given message.

        Args:
            message (str): Description of the recursion error.
        """
        logger.error("GraphRecursionError raised: %s", message)
        super().__init__(message)
Functions
__init__
__init__(message)

Initializes a GraphRecursionError with the given message.

Parameters:

Name Type Description Default
message str

Description of the recursion error.

required
Source code in pyagenity/exceptions/recursion_error.py
def __init__(self, message: str):
    """
    Initializes a GraphRecursionError with the given message.

    Args:
        message (str): Description of the recursion error.
    """
    logger.error("GraphRecursionError raised: %s", message)
    super().__init__(message)
NodeError

Bases: GraphError

Exception raised when a node encounters an error.

This exception is used for errors specific to nodes within a graph.

Example

>>> from pyagenity.exceptions.node_error import NodeError
>>> raise NodeError("Node failed to execute")

Methods:

Name Description
__init__

Initializes a NodeError with the given message.

Source code in pyagenity/exceptions/node_error.py
class NodeError(GraphError):
    """
    Exception raised when a node encounters an error.

    This exception is used for errors specific to nodes within a graph.

    Example:
        >>> from pyagenity.exceptions.node_error import NodeError
        >>> raise NodeError("Node failed to execute")
    """

    def __init__(self, message: str):
        """
        Initializes a NodeError with the given message.

        Args:
            message (str): Description of the node error.
        """
        logger.error("NodeError raised: %s", message)
        super().__init__(message)
Functions
__init__
__init__(message)

Initializes a NodeError with the given message.

Parameters:

Name Type Description Default
message str

Description of the node error.

required
Source code in pyagenity/exceptions/node_error.py
def __init__(self, message: str):
    """
    Initializes a NodeError with the given message.

    Args:
        message (str): Description of the node error.
    """
    logger.error("NodeError raised: %s", message)
    super().__init__(message)
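Because `NodeError` and `GraphRecursionError` both subclass `GraphError`, a single `except GraphError` handler catches all three, while more specific handlers can still be placed first. A self-contained sketch mirroring the hierarchy (class bodies simplified; the real classes also log on construction):

```python
class GraphError(Exception):
    """Base exception for graph-related errors."""


class NodeError(GraphError):
    """Error specific to a node within a graph."""


class GraphRecursionError(GraphError):
    """Graph execution exceeded the recursion limit."""


def classify(exc):
    """Route an exception to a handler, most specific first."""
    try:
        raise exc
    except GraphRecursionError:
        return "recursion"
    except NodeError:
        return "node"
    except GraphError:
        return "graph"


results = [
    classify(e) for e in (NodeError("n"), GraphRecursionError("r"), GraphError("g"))
]
```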

Modules

graph_error

Classes:

Name Description
GraphError

Base exception for graph-related errors.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
GraphError

Bases: Exception

Base exception for graph-related errors.

This exception is raised when an error related to graph operations occurs.

Example

>>> from pyagenity.exceptions.graph_error import GraphError
>>> raise GraphError("Invalid graph structure")

Methods:

Name Description
__init__

Initializes a GraphError with the given message.

Source code in pyagenity/exceptions/graph_error.py
class GraphError(Exception):
    """
    Base exception for graph-related errors.

    This exception is raised when an error related to graph operations occurs.

    Example:
        >>> from pyagenity.exceptions.graph_error import GraphError
        >>> raise GraphError("Invalid graph structure")
    """

    def __init__(self, message: str):
        """
        Initializes a GraphError with the given message.

        Args:
            message (str): Description of the error.
        """
        logger.error("GraphError raised: %s", message)
        super().__init__(message)
Functions
__init__
__init__(message)

Initializes a GraphError with the given message.

Parameters:

Name Type Description Default
message str

Description of the error.

required
Source code in pyagenity/exceptions/graph_error.py
def __init__(self, message: str):
    """
    Initializes a GraphError with the given message.

    Args:
        message (str): Description of the error.
    """
    logger.error("GraphError raised: %s", message)
    super().__init__(message)
node_error

Classes:

Name Description
NodeError

Exception raised when a node encounters an error.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
NodeError

Bases: GraphError

Exception raised when a node encounters an error.

This exception is used for errors specific to nodes within a graph.

Example

>>> from pyagenity.exceptions.node_error import NodeError
>>> raise NodeError("Node failed to execute")

Methods:

Name Description
__init__

Initializes a NodeError with the given message.

Source code in pyagenity/exceptions/node_error.py
class NodeError(GraphError):
    """
    Exception raised when a node encounters an error.

    This exception is used for errors specific to nodes within a graph.

    Example:
        >>> from pyagenity.exceptions.node_error import NodeError
        >>> raise NodeError("Node failed to execute")
    """

    def __init__(self, message: str):
        """
        Initializes a NodeError with the given message.

        Args:
            message (str): Description of the node error.
        """
        logger.error("NodeError raised: %s", message)
        super().__init__(message)
Functions
__init__
__init__(message)

Initializes a NodeError with the given message.

Parameters:

Name Type Description Default
message str

Description of the node error.

required
Source code in pyagenity/exceptions/node_error.py
def __init__(self, message: str):
    """
    Initializes a NodeError with the given message.

    Args:
        message (str): Description of the node error.
    """
    logger.error("NodeError raised: %s", message)
    super().__init__(message)
recursion_error

Classes:

Name Description
GraphRecursionError

Exception raised when graph execution exceeds the recursion limit.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
GraphRecursionError

Bases: GraphError

Exception raised when graph execution exceeds the recursion limit.

This exception is used to indicate that a graph operation has recursed too deeply.

Example

>>> from pyagenity.exceptions.recursion_error import GraphRecursionError
>>> raise GraphRecursionError("Recursion limit exceeded in graph execution")

Methods:

Name Description
__init__

Initializes a GraphRecursionError with the given message.

Source code in pyagenity/exceptions/recursion_error.py
class GraphRecursionError(GraphError):
    """
    Exception raised when graph execution exceeds the recursion limit.

    This exception is used to indicate that a graph operation has recursed too deeply.

    Example:
        >>> from pyagenity.exceptions.recursion_error import GraphRecursionError
        >>> raise GraphRecursionError("Recursion limit exceeded in graph execution")
    """

    def __init__(self, message: str):
        """
        Initializes a GraphRecursionError with the given message.

        Args:
            message (str): Description of the recursion error.
        """
        logger.error("GraphRecursionError raised: %s", message)
        super().__init__(message)
Functions
__init__
__init__(message)

Initializes a GraphRecursionError with the given message.

Parameters:

Name Type Description Default
message str

Description of the recursion error.

required
Source code in pyagenity/exceptions/recursion_error.py
def __init__(self, message: str):
    """
    Initializes a GraphRecursionError with the given message.

    Args:
        message (str): Description of the recursion error.
    """
    logger.error("GraphRecursionError raised: %s", message)
    super().__init__(message)
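A minimal, self-contained sketch of the fail-fast pattern this exception supports. The `GraphError`/`GraphRecursionError` classes below are local stand-ins mirroring the hierarchy above, and `run_graph` is a hypothetical runner that counts steps the way the engine's recursion limit does; it is not the framework's actual execution loop.

```python
class GraphError(Exception):
    """Stand-in for pyagenity's base graph error."""


class GraphRecursionError(GraphError):
    """Raised when graph execution exceeds the recursion limit."""


def run_graph(steps: int, recursion_limit: int = 25) -> str:
    """Hypothetical runner: each node execution consumes one step."""
    for step in range(steps):
        if step >= recursion_limit:
            raise GraphRecursionError(
                f"Recursion limit of {recursion_limit} exceeded"
            )
    return "done"


# A runaway workflow trips the limit; callers can catch the specific
# subclass (or the GraphError base) rather than a broad Exception.
try:
    run_graph(steps=100, recursion_limit=25)
except GraphRecursionError as e:
    outcome = str(e)
```

Catching `GraphRecursionError` specifically lets callers distinguish a routing loop from other graph failures instead of retrying blindly.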
storage_exceptions

Structured exception taxonomy for persistence & runtime layers.

These exceptions let higher-level orchestration decide retry / fail-fast logic instead of relying on broad except Exception blocks.

Classes:

Name Description
MetricsError

Raised when metrics emission fails (should normally be swallowed/logged).

SchemaVersionError

Raised when schema version detection or migration application fails.

SerializationError

Raised when (de)serialization of state/messages fails deterministically.

StorageError

Base class for non-retryable storage layer errors.

TransientStorageError

Retryable storage error (connection drops, timeouts).

Classes
MetricsError

Bases: Exception

Raised when metrics emission fails (should normally be swallowed/logged).

Source code in pyagenity/exceptions/storage_exceptions.py
class MetricsError(Exception):
    """Raised when metrics emission fails (should normally be swallowed/logged)."""
SchemaVersionError

Bases: StorageError

Raised when schema version detection or migration application fails.

Source code in pyagenity/exceptions/storage_exceptions.py
class SchemaVersionError(StorageError):
    """Raised when schema version detection or migration application fails."""
SerializationError

Bases: StorageError

Raised when (de)serialization of state/messages fails deterministically.

Source code in pyagenity/exceptions/storage_exceptions.py
class SerializationError(StorageError):
    """Raised when (de)serialization of state/messages fails deterministically."""
StorageError

Bases: Exception

Base class for non-retryable storage layer errors.

Source code in pyagenity/exceptions/storage_exceptions.py
class StorageError(Exception):
    """Base class for non-retryable storage layer errors."""
TransientStorageError

Bases: StorageError

Retryable storage error (connection drops, timeouts).

Source code in pyagenity/exceptions/storage_exceptions.py
class TransientStorageError(StorageError):
    """Retryable storage error (connection drops, timeouts)."""

graph

PyAgenity Graph Module - Core Workflow Engine.

This module provides the foundational components for building and executing agent workflows in PyAgenity. It implements a graph-based execution model similar to LangGraph, where workflows are defined as directed graphs of interconnected nodes that process state and execute business logic.

Architecture Overview:

The graph module follows a builder pattern for workflow construction and provides a compiled execution environment for runtime performance. The core components work together to enable complex, stateful agent interactions:

  1. StateGraph: The primary builder class for constructing workflows
  2. Node: Executable units that encapsulate functions or tool operations
  3. Edge: Connections between nodes that define execution flow
  4. CompiledGraph: The executable runtime form of a constructed graph
  5. ToolNode: Specialized node for managing and executing tools

Core Components:

StateGraph

The main entry point for building workflows. Provides a fluent API for adding nodes, connecting them with edges, and configuring execution behavior. Supports both static and conditional routing between nodes.

Node

Represents an executable unit within the graph. Wraps functions or ToolNode instances and handles dependency injection, parameter mapping, and execution context. Supports both regular and streaming execution modes.

Edge

Defines connections between nodes, supporting both static (always followed) and conditional (state-dependent) routing. Enables complex branching logic and decision trees within workflows.

CompiledGraph

The executable runtime form created by compiling a StateGraph. Provides synchronous and asynchronous execution methods, state persistence, event publishing, and comprehensive error handling.

ToolNode

A specialized registry and executor for callable functions from various sources including local functions, MCP tools, Composio integrations, and LangChain tools. Supports automatic schema generation and unified tool execution.

Key Features:

  • State Management: Persistent, typed state that flows between nodes
  • Dependency Injection: Automatic injection of framework services
  • Event Publishing: Comprehensive execution monitoring and debugging
  • Streaming Support: Real-time incremental result processing
  • Interrupts & Resume: Pauseable execution with checkpointing
  • Tool Integration: Unified interface for various tool providers
  • Type Safety: Generic typing for custom state classes
  • Error Handling: Robust error recovery and callback mechanisms

Usage Example:

```python
from pyagenity.graph import StateGraph, ToolNode
from pyagenity.utils import START, END


# Define workflow functions
def process_input(state, config):
    # Process user input
    result = analyze_input(state.context[-1].content)
    return [Message.text_message(f"Analysis: {result}")]


def generate_response(state, config):
    # Generate final response
    response = create_response(state.context)
    return [Message.text_message(response)]


# Create tools
def search_tool(query: str) -> str:
    return f"Search results for: {query}"


tools = ToolNode([search_tool])

# Build the graph
graph = StateGraph()
graph.add_node("process", process_input)
graph.add_node("search", tools)
graph.add_node("respond", generate_response)

# Define flow
graph.add_edge(START, "process")
graph.add_edge("process", "search")
graph.add_edge("search", "respond")
graph.add_edge("respond", END)

# Compile and execute
compiled = graph.compile()
result = compiled.invoke({"messages": [Message.text_message("Hello, world!")]})

# Cleanup (aclose is a coroutine; run it with asyncio from synchronous code)
import asyncio

asyncio.run(compiled.aclose())
```

Integration Points:

The graph module integrates with other PyAgenity components:

  • State Module: Provides AgentState and context management
  • Utils Module: Supplies constants, messages, and helper functions
  • Checkpointer Module: Enables state persistence and recovery
  • Publisher Module: Handles event publishing and monitoring
  • Adapters Module: Connects with external tools and services

This architecture provides a flexible, extensible foundation for building sophisticated agent workflows while maintaining simplicity for common use cases.

Modules:

Name Description
compiled_graph
edge

Graph edge representation and routing logic for PyAgenity workflows.

node

Node execution and management for PyAgenity graph workflows.

state_graph
tool_node

ToolNode package.

utils

Classes:

Name Description
CompiledGraph

A fully compiled and executable graph ready for workflow execution.

Edge

Represents a connection between two nodes in a graph workflow.

Node

Represents a node in the graph workflow.

StateGraph

Main graph class for orchestrating multi-agent workflows.

ToolNode

A unified registry and executor for callable functions from various tool providers.

Attributes

__all__ module-attribute
__all__ = ['CompiledGraph', 'Edge', 'Node', 'StateGraph', 'ToolNode']

Classes

CompiledGraph

A fully compiled and executable graph ready for workflow execution.

CompiledGraph represents the final executable form of a StateGraph after compilation. It encapsulates all the execution logic, handlers, and services needed to run agent workflows. The graph supports both synchronous and asynchronous execution with comprehensive state management, checkpointing, event publishing, and streaming capabilities.

This class is generic over state types to support custom AgentState subclasses, ensuring type safety throughout the execution process.

Key Features:

  • Synchronous and asynchronous execution methods
  • Real-time streaming with incremental results
  • State persistence and checkpointing
  • Interrupt and resume capabilities
  • Event publishing for monitoring and debugging
  • Background task management
  • Graceful error handling and recovery

Attributes:

Name Type Description
_state

The initial/template state for graph executions.

_invoke_handler InvokeHandler[StateT]

Handler for non-streaming graph execution.

_stream_handler StreamHandler[StateT]

Handler for streaming graph execution.

_checkpointer BaseCheckpointer[StateT] | None

Optional state persistence backend.

_publisher BasePublisher | None

Optional event publishing backend.

_store BaseStore | None

Optional data storage backend.

_state_graph StateGraph[StateT]

Reference to the source StateGraph.

_interrupt_before list[str]

Nodes where execution should pause before execution.

_interrupt_after list[str]

Nodes where execution should pause after execution.

_task_manager

Manager for background async tasks.

Example
# After building and compiling a StateGraph
compiled = graph.compile()

# Synchronous execution
result = compiled.invoke({"messages": [Message.text_message("Hello")]})

# Asynchronous execution with streaming
async for chunk in compiled.astream({"messages": [message]}):
    print(f"Streamed: {chunk.content}")

# Graceful cleanup
await compiled.aclose()
Note

CompiledGraph instances should be properly closed using aclose() to release resources like database connections, background tasks, and event publishers.

Methods:

Name Description
__init__
aclose

Close the graph and release any resources.

ainvoke

Execute the graph asynchronously.

astop

Request the current graph execution to stop (async).

astream

Execute the graph asynchronously with streaming support.

generate_graph

Generate the graph representation.

invoke

Execute the graph synchronously and return the final results.

stop

Request the current graph execution to stop (sync helper).

stream

Execute the graph synchronously with streaming support.

Source code in pyagenity/graph/compiled_graph.py
class CompiledGraph[StateT: AgentState]:
    """A fully compiled and executable graph ready for workflow execution.

    CompiledGraph represents the final executable form of a StateGraph after compilation.
    It encapsulates all the execution logic, handlers, and services needed to run
    agent workflows. The graph supports both synchronous and asynchronous execution
    with comprehensive state management, checkpointing, event publishing, and
    streaming capabilities.

    This class is generic over state types to support custom AgentState subclasses,
    ensuring type safety throughout the execution process.

    Key Features:
    - Synchronous and asynchronous execution methods
    - Real-time streaming with incremental results
    - State persistence and checkpointing
    - Interrupt and resume capabilities
    - Event publishing for monitoring and debugging
    - Background task management
    - Graceful error handling and recovery

    Attributes:
        _state: The initial/template state for graph executions.
        _invoke_handler: Handler for non-streaming graph execution.
        _stream_handler: Handler for streaming graph execution.
        _checkpointer: Optional state persistence backend.
        _publisher: Optional event publishing backend.
        _store: Optional data storage backend.
        _state_graph: Reference to the source StateGraph.
        _interrupt_before: Nodes where execution should pause before execution.
        _interrupt_after: Nodes where execution should pause after execution.
        _task_manager: Manager for background async tasks.

    Example:
        ```python
        # After building and compiling a StateGraph
        compiled = graph.compile()

        # Synchronous execution
        result = compiled.invoke({"messages": [Message.text_message("Hello")]})

        # Asynchronous execution with streaming
        async for chunk in compiled.astream({"messages": [message]}):
            print(f"Streamed: {chunk.content}")

        # Graceful cleanup
        await compiled.aclose()
        ```

    Note:
        CompiledGraph instances should be properly closed using aclose() to
        release resources like database connections, background tasks, and
        event publishers.
    """

    def __init__(
        self,
        state: StateT,
        checkpointer: BaseCheckpointer[StateT] | None,
        publisher: BasePublisher | None,
        store: BaseStore | None,
        state_graph: StateGraph[StateT],
        interrupt_before: list[str],
        interrupt_after: list[str],
        task_manager: BackgroundTaskManager,
    ):
        logger.info(
            f"Initializing CompiledGraph with nodes: {list(state_graph.nodes.keys())}",
        )

        # Save initial state
        self._state = state

        # create handlers
        self._invoke_handler: InvokeHandler[StateT] = InvokeHandler[StateT](
            nodes=state_graph.nodes,  # type: ignore
            edges=state_graph.edges,  # type: ignore
        )
        self._stream_handler: StreamHandler[StateT] = StreamHandler[StateT](
            nodes=state_graph.nodes,  # type: ignore
            edges=state_graph.edges,  # type: ignore
        )

        self._checkpointer: BaseCheckpointer[StateT] | None = checkpointer
        self._publisher: BasePublisher | None = publisher
        self._store: BaseStore | None = store
        self._state_graph: StateGraph[StateT] = state_graph
        self._interrupt_before: list[str] = interrupt_before
        self._interrupt_after: list[str] = interrupt_after
        # generate task manager
        self._task_manager = task_manager

    def _prepare_config(
        self,
        config: dict[str, Any] | None,
        is_stream: bool = False,
    ) -> dict[str, Any]:
        cfg = config or {}
        if "is_stream" not in cfg:
            cfg["is_stream"] = is_stream
        if "user_id" not in cfg:
            cfg["user_id"] = "test-user-id"  # mock user id
        if "run_id" not in cfg:
            cfg["run_id"] = InjectQ.get_instance().try_get("generated_id") or str(uuid4())

        if "timestamp" not in cfg:
            cfg["timestamp"] = datetime.datetime.now().isoformat()

        return cfg

    def invoke(
        self,
        input_data: dict[str, Any],
        config: dict[str, Any] | None = None,
        response_granularity: ResponseGranularity = ResponseGranularity.LOW,
    ) -> dict[str, Any]:
        """Execute the graph synchronously and return the final results.

        Runs the complete graph workflow from start to finish, handling state
        management, node execution, and result formatting. This method automatically
        detects whether to start a fresh execution or resume from an interrupted state.

        The execution is synchronous but internally uses async operations, making it
        suitable for use in non-async contexts while still benefiting from async
        capabilities for I/O operations.

        Args:
            input_data: Input dictionary for graph execution. For new executions,
                should contain 'messages' key with list of initial messages.
                For resumed executions, can contain additional data to merge.
            config: Optional configuration dictionary containing execution settings:
                - user_id: Identifier for the user/session
                - thread_id: Unique identifier for this execution thread
                - run_id: Unique identifier for this specific run
                - recursion_limit: Maximum steps before stopping (default: 25)
            response_granularity: Level of detail in the response:
                - LOW: Returns only messages (default)
                - PARTIAL: Returns context, summary, and messages
                - FULL: Returns complete state and messages

        Returns:
            Dictionary containing execution results formatted according to the
            specified granularity level. Always includes execution messages
            and may include additional state information.

        Raises:
            ValueError: If input_data is invalid for new execution.
            GraphRecursionError: If execution exceeds recursion limit.
            Various exceptions: Depending on node execution failures.

        Example:
            ```python
            # Basic execution
            result = compiled.invoke({"messages": [Message.text_message("Process this data")]})
            print(result["messages"])  # Final execution messages

            # With configuration and full details
            result = compiled.invoke(
                input_data={"messages": [message]},
                config={"user_id": "user123", "thread_id": "session456", "recursion_limit": 50},
                response_granularity=ResponseGranularity.FULL,
            )
            print(result["state"])  # Complete final state
            ```

        Note:
            This method uses asyncio.run() internally, so it should not be called
            from within an async context. Use ainvoke() instead for async execution.
        """
        logger.info(
            "Starting synchronous graph execution with %d input keys, granularity=%s",
            len(input_data) if input_data else 0,
            response_granularity,
        )
        logger.debug("Input data keys: %s", list(input_data.keys()) if input_data else [])
        # Async Will Handle Event Publish

        try:
            result = asyncio.run(self.ainvoke(input_data, config, response_granularity))
            logger.info("Synchronous graph execution completed successfully")
            return result
        except Exception as e:
            logger.exception("Synchronous graph execution failed: %s", e)
            raise

    async def ainvoke(
        self,
        input_data: dict[str, Any],
        config: dict[str, Any] | None = None,
        response_granularity: ResponseGranularity = ResponseGranularity.LOW,
    ) -> dict[str, Any]:
        """Execute the graph asynchronously.

        Auto-detects whether to start fresh execution or resume from interrupted state
        based on the AgentState's execution metadata.

        Args:
            input_data: Input dict with 'messages' key (for new execution) or
                       additional data for resuming
            config: Configuration dictionary
            response_granularity: Response parsing granularity

        Returns:
            Response dict based on granularity
        """
        cfg = self._prepare_config(config, is_stream=False)

        return await self._invoke_handler.invoke(
            input_data,
            cfg,
            self._state,
            response_granularity,
        )

    def stop(self, config: dict[str, Any]) -> dict[str, Any]:
        """Request the current graph execution to stop (sync helper).

        This sets a stop flag in the checkpointer's thread store keyed by thread_id.
        Handlers periodically check this flag and interrupt execution.
        Returns a small status dict.
        """
        return asyncio.run(self.astop(config))

    async def astop(self, config: dict[str, Any]) -> dict[str, Any]:
        """Request the current graph execution to stop (async).

        Contract:
        - Requires a valid thread_id in config
        - If no active thread or no checkpointer, returns not-running
        - If state exists and is running, set stop_requested flag in thread info
        """
        cfg = self._prepare_config(config, is_stream=bool(config.get("is_stream", False)))
        if not self._checkpointer:
            return {"ok": False, "reason": "no-checkpointer"}

        # Load state to see if this thread is running
        state = await self._checkpointer.aget_state_cache(
            cfg
        ) or await self._checkpointer.aget_state(cfg)
        if not state:
            return {"ok": False, "running": False, "reason": "no-state"}

        running = state.is_running() and not state.is_interrupted()
        # Set stop flag regardless; handlers will act if running
        if running:
            state.execution_meta.stop_current_execution = StopRequestStatus.STOP_REQUESTED
            # update cache
            # Cache update is enough; state will be picked up by running execution
            # As its running, cache will be available immediately
            await self._checkpointer.aput_state_cache(cfg, state)
            # Fixme: consider putting to main state as well
            # await self._checkpointer.aput_state(cfg, state)
            logger.info("Set stop_current_execution flag for thread_id: %s", cfg.get("thread_id"))
            return {"ok": True, "running": running}

        logger.info(
            "No running execution to stop for thread_id: %s (running=%s, interrupted=%s)",
            cfg.get("thread_id"),
            running,
            state.is_interrupted(),
        )
        return {"ok": True, "running": running, "reason": "not-running"}

    def stream(
        self,
        input_data: dict[str, Any],
        config: dict[str, Any] | None = None,
        response_granularity: ResponseGranularity = ResponseGranularity.LOW,
    ) -> Generator[Message]:
        """Execute the graph synchronously with streaming support.

        Yields Message objects containing incremental responses.
        If nodes return streaming responses, yields them directly.
        If nodes return complete responses, simulates streaming by chunking.

        Args:
            input_data: Input dict
            config: Configuration dictionary
            response_granularity: Response parsing granularity

        Yields:
            Message objects with incremental content
        """

        # For sync streaming, we'll use asyncio.run to handle the async implementation
        async def _async_stream():
            async for chunk in self.astream(input_data, config, response_granularity):
                yield chunk

        # Convert async generator to sync iteration with a dedicated event loop
        gen = _async_stream()
        loop = asyncio.new_event_loop()
        policy = asyncio.get_event_loop_policy()
        try:
            previous_loop = policy.get_event_loop()
        except Exception:
            previous_loop = None
        asyncio.set_event_loop(loop)
        logger.info("Synchronous streaming started")

        try:
            while True:
                try:
                    chunk = loop.run_until_complete(gen.__anext__())
                    yield chunk
                except StopAsyncIteration:
                    break
        finally:
            # Attempt to close the async generator cleanly
            with contextlib.suppress(Exception):
                loop.run_until_complete(gen.aclose())  # type: ignore[attr-defined]
            # Restore previous loop if any, then close created loop
            try:
                if previous_loop is not None:
                    asyncio.set_event_loop(previous_loop)
            finally:
                loop.close()
        logger.info("Synchronous streaming completed")

    async def astream(
        self,
        input_data: dict[str, Any],
        config: dict[str, Any] | None = None,
        response_granularity: ResponseGranularity = ResponseGranularity.LOW,
    ) -> AsyncIterator[Message]:
        """Execute the graph asynchronously with streaming support.

        Yields Message objects containing incremental responses.
        If nodes return streaming responses, yields them directly.
        If nodes return complete responses, simulates streaming by chunking.

        Args:
            input_data: Input dict
            config: Configuration dictionary
            response_granularity: Response parsing granularity

        Yields:
            Message objects with incremental content
        """

        cfg = self._prepare_config(config, is_stream=True)

        async for chunk in self._stream_handler.stream(
            input_data,
            cfg,
            self._state,
            response_granularity,
        ):
            yield chunk

    async def aclose(self) -> dict[str, str]:
        """Close the graph and release any resources."""
        # close checkpointer
        stats = {}
        try:
            if self._checkpointer:
                await self._checkpointer.arelease()
                logger.info("Checkpointer closed successfully")
                stats["checkpointer"] = "closed"
        except Exception as e:
            stats["checkpointer"] = f"error: {e}"
            logger.error(f"Error closing graph: {e}")

        # Close Publisher
        try:
            if self._publisher:
                await self._publisher.close()
                logger.info("Publisher closed successfully")
                stats["publisher"] = "closed"
        except Exception as e:
            stats["publisher"] = f"error: {e}"
            logger.error(f"Error closing publisher: {e}")

        # Close Store
        try:
            if self._store:
                await self._store.arelease()
                logger.info("Store closed successfully")
                stats["store"] = "closed"
        except Exception as e:
            stats["store"] = f"error: {e}"
            logger.error(f"Error closing store: {e}")

        # Wait for all background tasks to complete
        try:
            await self._task_manager.wait_for_all()
            logger.info("All background tasks completed successfully")
            stats["background_tasks"] = "completed"
        except Exception as e:
            stats["background_tasks"] = f"error: {e}"
            logger.error(f"Error waiting for background tasks: {e}")

        logger.info(f"Graph close stats: {stats}")
        # You can also return or process the stats as needed
        return stats

    def generate_graph(self) -> dict[str, Any]:
        """Generate the graph representation.

        Returns:
            A dictionary representing the graph structure.
        """
        graph = {
            "info": {},
            "nodes": [],
            "edges": [],
        }
        # Populate the graph with nodes and edges
        for node_name in self._state_graph.nodes:
            graph["nodes"].append(
                {
                    "id": str(uuid4()),
                    "name": node_name,
                }
            )

        for edge in self._state_graph.edges:
            graph["edges"].append(
                {
                    "id": str(uuid4()),
                    "source": edge.from_node,
                    "target": edge.to_node,
                }
            )

        # Add few more extra info
        graph["info"] = {
            "node_count": len(graph["nodes"]),
            "edge_count": len(graph["edges"]),
            "checkpointer": self._checkpointer is not None,
            "checkpointer_type": type(self._checkpointer).__name__ if self._checkpointer else None,
            "publisher": self._publisher is not None,
            "store": self._store is not None,
            "interrupt_before": self._interrupt_before,
            "interrupt_after": self._interrupt_after,
            "context_type": self._state_graph._context_manager.__class__.__name__,
            "id_generator": self._state_graph._id_generator.__class__.__name__,
            "id_type": self._state_graph._id_generator.id_type.value,
            "state_type": self._state.__class__.__name__,
            "state_fields": list(self._state.model_dump().keys()),
        }
        return graph
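The dictionary returned by `generate_graph` can be consumed directly, for example to render a summary line. The sample dict below is hand-built to match the structure produced in the source above (illustrative ids, not real framework output), and `summarize` is a hypothetical helper.

```python
def summarize(graph_repr: dict) -> str:
    """Render a one-line summary of a generate_graph() result."""
    info = graph_repr["info"]
    return (
        f"{info['node_count']} nodes, {info['edge_count']} edges, "
        f"checkpointer={'yes' if info['checkpointer'] else 'no'}"
    )


# Hand-written sample matching the shape built by generate_graph()
sample = {
    "info": {"node_count": 3, "edge_count": 4, "checkpointer": True},
    "nodes": [{"id": "n1", "name": "process"}],
    "edges": [{"id": "e1", "source": "process", "target": "respond"}],
}
line = summarize(sample)
```

Since node and edge ids are freshly generated UUIDs on every call, consumers should key on the `name`, `source`, and `target` fields rather than on ids.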
Functions
__init__
__init__(state, checkpointer, publisher, store, state_graph, interrupt_before, interrupt_after, task_manager)
Source code in pyagenity/graph/compiled_graph.py
def __init__(
    self,
    state: StateT,
    checkpointer: BaseCheckpointer[StateT] | None,
    publisher: BasePublisher | None,
    store: BaseStore | None,
    state_graph: StateGraph[StateT],
    interrupt_before: list[str],
    interrupt_after: list[str],
    task_manager: BackgroundTaskManager,
):
    logger.info(
        f"Initializing CompiledGraph with nodes: {list(state_graph.nodes.keys())}",
    )

    # Save initial state
    self._state = state

    # create handlers
    self._invoke_handler: InvokeHandler[StateT] = InvokeHandler[StateT](
        nodes=state_graph.nodes,  # type: ignore
        edges=state_graph.edges,  # type: ignore
    )
    self._stream_handler: StreamHandler[StateT] = StreamHandler[StateT](
        nodes=state_graph.nodes,  # type: ignore
        edges=state_graph.edges,  # type: ignore
    )

    self._checkpointer: BaseCheckpointer[StateT] | None = checkpointer
    self._publisher: BasePublisher | None = publisher
    self._store: BaseStore | None = store
    self._state_graph: StateGraph[StateT] = state_graph
    self._interrupt_before: list[str] = interrupt_before
    self._interrupt_after: list[str] = interrupt_after
    # generate task manager
    self._task_manager = task_manager
aclose async
aclose()

Close the graph and release any resources.

Source code in pyagenity/graph/compiled_graph.py
async def aclose(self) -> dict[str, str]:
    """Close the graph and release any resources."""
    # close checkpointer
    stats = {}
    try:
        if self._checkpointer:
            await self._checkpointer.arelease()
            logger.info("Checkpointer closed successfully")
            stats["checkpointer"] = "closed"
    except Exception as e:
        stats["checkpointer"] = f"error: {e}"
        logger.error(f"Error closing graph: {e}")

    # Close Publisher
    try:
        if self._publisher:
            await self._publisher.close()
            logger.info("Publisher closed successfully")
            stats["publisher"] = "closed"
    except Exception as e:
        stats["publisher"] = f"error: {e}"
        logger.error(f"Error closing publisher: {e}")

    # Close Store
    try:
        if self._store:
            await self._store.arelease()
            logger.info("Store closed successfully")
            stats["store"] = "closed"
    except Exception as e:
        stats["store"] = f"error: {e}"
        logger.error(f"Error closing store: {e}")

    # Wait for all background tasks to complete
    try:
        await self._task_manager.wait_for_all()
        logger.info("All background tasks completed successfully")
        stats["background_tasks"] = "completed"
    except Exception as e:
        stats["background_tasks"] = f"error: {e}"
        logger.error(f"Error waiting for background tasks: {e}")

    logger.info(f"Graph close stats: {stats}")
    # You can also return or process the stats as needed
    return stats
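The per-resource pattern above (close each resource in its own try/except, record a status string, and never let one failure block the rest) can be sketched independently of PyAgenity. The `close_all` helper and the resource classes below are illustrative stand-ins, not part of the library API:

```python
import asyncio


async def close_all(resources: dict) -> dict[str, str]:
    """Best-effort shutdown: close each resource, recording a status per name."""
    stats: dict[str, str] = {}
    for name, res in resources.items():
        try:
            if res is not None:  # optional resources may be absent
                await res.aclose()
                stats[name] = "closed"
        except Exception as e:  # one failure must not block the others
            stats[name] = f"error: {e}"
    return stats


class Good:
    async def aclose(self):
        pass


class Bad:
    async def aclose(self):
        raise RuntimeError("boom")


stats = asyncio.run(close_all({"checkpointer": Good(), "publisher": Bad(), "store": None}))
```

Resources set to `None` are skipped entirely, matching how `aclose()` only reports on the checkpointer, publisher, and store that were actually configured.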
ainvoke async
ainvoke(input_data, config=None, response_granularity=ResponseGranularity.LOW)

Execute the graph asynchronously.

Auto-detects whether to start fresh execution or resume from interrupted state based on the AgentState's execution metadata.

Parameters:

Name Type Description Default
input_data dict[str, Any]

Input dict with 'messages' key (for new execution) or additional data for resuming

required
config dict[str, Any] | None

Configuration dictionary

None
response_granularity ResponseGranularity

Response parsing granularity

LOW

Returns:

Type Description
dict[str, Any]

Response dict based on granularity

Source code in pyagenity/graph/compiled_graph.py
async def ainvoke(
    self,
    input_data: dict[str, Any],
    config: dict[str, Any] | None = None,
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
) -> dict[str, Any]:
    """Execute the graph asynchronously.

    Auto-detects whether to start fresh execution or resume from interrupted state
    based on the AgentState's execution metadata.

    Args:
        input_data: Input dict with 'messages' key (for new execution) or
                   additional data for resuming
        config: Configuration dictionary
        response_granularity: Response parsing granularity

    Returns:
        Response dict based on granularity
    """
    cfg = self._prepare_config(config, is_stream=False)

    return await self._invoke_handler.invoke(
        input_data,
        cfg,
        self._state,
        response_granularity,
    )
astop async
astop(config)

Request the current graph execution to stop (async).

Contract:

- Requires a valid thread_id in config
- If no active thread or no checkpointer, returns not-running
- If state exists and is running, set stop_requested flag in thread info

Source code in pyagenity/graph/compiled_graph.py
async def astop(self, config: dict[str, Any]) -> dict[str, Any]:
    """Request the current graph execution to stop (async).

    Contract:
    - Requires a valid thread_id in config
    - If no active thread or no checkpointer, returns not-running
    - If state exists and is running, set stop_requested flag in thread info
    """
    cfg = self._prepare_config(config, is_stream=bool(config.get("is_stream", False)))
    if not self._checkpointer:
        return {"ok": False, "reason": "no-checkpointer"}

    # Load state to see if this thread is running
    state = await self._checkpointer.aget_state_cache(
        cfg
    ) or await self._checkpointer.aget_state(cfg)
    if not state:
        return {"ok": False, "running": False, "reason": "no-state"}

    running = state.is_running() and not state.is_interrupted()
    # Set stop flag regardless; handlers will act if running
    if running:
        state.execution_meta.stop_current_execution = StopRequestStatus.STOP_REQUESTED
        # Cache update is enough: the running execution reads the cache,
        # so the flag is picked up immediately
        await self._checkpointer.aput_state_cache(cfg, state)
        # Fixme: consider putting to main state as well
        # await self._checkpointer.aput_state(cfg, state)
        logger.info("Set stop_current_execution flag for thread_id: %s", cfg.get("thread_id"))
        return {"ok": True, "running": running}

    logger.info(
        "No running execution to stop for thread_id: %s (running=%s, interrupted=%s)",
        cfg.get("thread_id"),
        running,
        state.is_interrupted(),
    )
    return {"ok": True, "running": running, "reason": "not-running"}
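The stop mechanism is cooperative: `astop()` only sets a flag, and the running handlers poll it between steps. A minimal, self-contained sketch of that contract (the `ThreadState`, `worker`, and flag names here are illustrative, not PyAgenity APIs):

```python
import asyncio

STOP_REQUESTED = "stop_requested"


class ThreadState:
    """Minimal stand-in for per-thread execution metadata."""

    def __init__(self):
        self.stop_flag = None


async def worker(state: ThreadState, results: list) -> None:
    for step in range(1000):
        if state.stop_flag == STOP_REQUESTED:  # handlers poll the flag each step
            break
        results.append(step)
        await asyncio.sleep(0)  # yield control so the controller can run


async def main() -> list:
    state, results = ThreadState(), []
    task = asyncio.create_task(worker(state, results))
    while len(results) < 5:  # wait until the worker has made some progress
        await asyncio.sleep(0)
    state.stop_flag = STOP_REQUESTED  # controller side: just set the flag
    await task
    return results


results = asyncio.run(main())
```

Because the flag is checked between steps rather than pre-empting the worker, a node that is mid-execution finishes its current step before the stop takes effect, which mirrors the "handlers will act if running" comment in the source.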
astream async
astream(input_data, config=None, response_granularity=ResponseGranularity.LOW)

Execute the graph asynchronously with streaming support.

Yields Message objects containing incremental responses. If nodes return streaming responses, yields them directly. If nodes return complete responses, simulates streaming by chunking.

Parameters:

Name Type Description Default
input_data dict[str, Any]

Input dict

required
config dict[str, Any] | None

Configuration dictionary

None
response_granularity ResponseGranularity

Response parsing granularity

LOW

Yields:

Type Description
AsyncIterator[Message]

Message objects with incremental content

Source code in pyagenity/graph/compiled_graph.py
async def astream(
    self,
    input_data: dict[str, Any],
    config: dict[str, Any] | None = None,
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
) -> AsyncIterator[Message]:
    """Execute the graph asynchronously with streaming support.

    Yields Message objects containing incremental responses.
    If nodes return streaming responses, yields them directly.
    If nodes return complete responses, simulates streaming by chunking.

    Args:
        input_data: Input dict
        config: Configuration dictionary
        response_granularity: Response parsing granularity

    Yields:
        Message objects with incremental content
    """

    cfg = self._prepare_config(config, is_stream=True)

    async for chunk in self._stream_handler.stream(
        input_data,
        cfg,
        self._state,
        response_granularity,
    ):
        yield chunk
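The docstring notes that complete (non-streaming) node responses are simulated as a stream by chunking. A hedged sketch of that fallback, with an illustrative `simulate_stream` helper that is not part of the library:

```python
import asyncio


async def simulate_stream(text: str, chunk_size: int = 4):
    """Fallback: slice a complete response into incremental chunks."""
    for i in range(0, len(text), chunk_size):
        yield text[i : i + chunk_size]
        await asyncio.sleep(0)  # let other tasks run between chunks


async def collect() -> list:
    return [chunk async for chunk in simulate_stream("hello world!")]


chunks = asyncio.run(collect())
```

Consumers see the same incremental interface either way; only the chunk boundaries differ between true streaming and simulated streaming.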
generate_graph
generate_graph()

Generate the graph representation.

Returns:

Type Description
dict[str, Any]

A dictionary representing the graph structure.

Source code in pyagenity/graph/compiled_graph.py
def generate_graph(self) -> dict[str, Any]:
    """Generate the graph representation.

    Returns:
        A dictionary representing the graph structure.
    """
    graph = {
        "info": {},
        "nodes": [],
        "edges": [],
    }
    # Populate the graph with nodes and edges
    for node_name in self._state_graph.nodes:
        graph["nodes"].append(
            {
                "id": str(uuid4()),
                "name": node_name,
            }
        )

    for edge in self._state_graph.edges:
        graph["edges"].append(
            {
                "id": str(uuid4()),
                "source": edge.from_node,
                "target": edge.to_node,
            }
        )

    # Add extra summary info
    graph["info"] = {
        "node_count": len(graph["nodes"]),
        "edge_count": len(graph["edges"]),
        "checkpointer": self._checkpointer is not None,
        "checkpointer_type": type(self._checkpointer).__name__ if self._checkpointer else None,
        "publisher": self._publisher is not None,
        "store": self._store is not None,
        "interrupt_before": self._interrupt_before,
        "interrupt_after": self._interrupt_after,
        "context_type": self._state_graph._context_manager.__class__.__name__,
        "id_generator": self._state_graph._id_generator.__class__.__name__,
        "id_type": self._state_graph._id_generator.id_type.value,
        "state_type": self._state.__class__.__name__,
        "state_fields": list(self._state.model_dump().keys()),
    }
    return graph
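The `{"info", "nodes", "edges"}` shape produced above can be reproduced from plain data, which is handy for testing visualization code without compiling a graph. The `describe_graph` helper below is an illustrative sketch, not a library function:

```python
from uuid import uuid4


def describe_graph(nodes: list, edges: list) -> dict:
    """Build the same {'info', 'nodes', 'edges'} shape from plain data."""
    node_entries = [{"id": str(uuid4()), "name": n} for n in nodes]
    edge_entries = [
        {"id": str(uuid4()), "source": s, "target": t} for s, t in edges
    ]
    return {
        "info": {"node_count": len(node_entries), "edge_count": len(edge_entries)},
        "nodes": node_entries,
        "edges": edge_entries,
    }


graph = describe_graph(["start", "process"], [("start", "process")])
```

Note that ids are freshly generated `uuid4` values on every call, so two calls over the same graph produce structurally equal but id-distinct results.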
invoke
invoke(input_data, config=None, response_granularity=ResponseGranularity.LOW)

Execute the graph synchronously and return the final results.

Runs the complete graph workflow from start to finish, handling state management, node execution, and result formatting. This method automatically detects whether to start a fresh execution or resume from an interrupted state.

The execution is synchronous but internally uses async operations, making it suitable for use in non-async contexts while still benefiting from async capabilities for I/O operations.

Parameters:

Name Type Description Default
input_data dict[str, Any]

Input dictionary for graph execution. For new executions, should contain 'messages' key with list of initial messages. For resumed executions, can contain additional data to merge.

required
config dict[str, Any] | None

Optional configuration dictionary containing execution settings:

- user_id: Identifier for the user/session
- thread_id: Unique identifier for this execution thread
- run_id: Unique identifier for this specific run
- recursion_limit: Maximum steps before stopping (default: 25)

None
response_granularity ResponseGranularity

Level of detail in the response:

- LOW: Returns only messages (default)
- PARTIAL: Returns context, summary, and messages
- FULL: Returns complete state and messages

LOW

Returns:

Type Description
dict[str, Any]

Dictionary containing execution results formatted according to the

dict[str, Any]

specified granularity level. Always includes execution messages

dict[str, Any]

and may include additional state information.

Raises:

Type Description
ValueError

If input_data is invalid for new execution.

GraphRecursionError

If execution exceeds recursion limit.

Various exceptions

Depending on node execution failures.

Example
# Basic execution
result = compiled.invoke({"messages": [Message.text_message("Process this data")]})
print(result["messages"])  # Final execution messages

# With configuration and full details
result = compiled.invoke(
    input_data={"messages": [message]},
    config={"user_id": "user123", "thread_id": "session456", "recursion_limit": 50},
    response_granularity=ResponseGranularity.FULL,
)
print(result["state"])  # Complete final state
Note

This method uses asyncio.run() internally, so it should not be called from within an async context. Use ainvoke() instead for async execution.

Source code in pyagenity/graph/compiled_graph.py
def invoke(
    self,
    input_data: dict[str, Any],
    config: dict[str, Any] | None = None,
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
) -> dict[str, Any]:
    """Execute the graph synchronously and return the final results.

    Runs the complete graph workflow from start to finish, handling state
    management, node execution, and result formatting. This method automatically
    detects whether to start a fresh execution or resume from an interrupted state.

    The execution is synchronous but internally uses async operations, making it
    suitable for use in non-async contexts while still benefiting from async
    capabilities for I/O operations.

    Args:
        input_data: Input dictionary for graph execution. For new executions,
            should contain 'messages' key with list of initial messages.
            For resumed executions, can contain additional data to merge.
        config: Optional configuration dictionary containing execution settings:
            - user_id: Identifier for the user/session
            - thread_id: Unique identifier for this execution thread
            - run_id: Unique identifier for this specific run
            - recursion_limit: Maximum steps before stopping (default: 25)
        response_granularity: Level of detail in the response:
            - LOW: Returns only messages (default)
            - PARTIAL: Returns context, summary, and messages
            - FULL: Returns complete state and messages

    Returns:
        Dictionary containing execution results formatted according to the
        specified granularity level. Always includes execution messages
        and may include additional state information.

    Raises:
        ValueError: If input_data is invalid for new execution.
        GraphRecursionError: If execution exceeds recursion limit.
        Various exceptions: Depending on node execution failures.

    Example:
        ```python
        # Basic execution
        result = compiled.invoke({"messages": [Message.text_message("Process this data")]})
        print(result["messages"])  # Final execution messages

        # With configuration and full details
        result = compiled.invoke(
            input_data={"messages": [message]},
            config={"user_id": "user123", "thread_id": "session456", "recursion_limit": 50},
            response_granularity=ResponseGranularity.FULL,
        )
        print(result["state"])  # Complete final state
        ```

    Note:
        This method uses asyncio.run() internally, so it should not be called
        from within an async context. Use ainvoke() instead for async execution.
    """
    logger.info(
        "Starting synchronous graph execution with %d input keys, granularity=%s",
        len(input_data) if input_data else 0,
        response_granularity,
    )
    logger.debug("Input data keys: %s", list(input_data.keys()) if input_data else [])
    # Event publishing is handled by the async implementation

    try:
        result = asyncio.run(self.ainvoke(input_data, config, response_granularity))
        logger.info("Synchronous graph execution completed successfully")
        return result
    except Exception as e:
        logger.exception("Synchronous graph execution failed: %s", e)
        raise
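The Note above warns that `asyncio.run()` fails when a loop is already running. A generic guard for that situation can be sketched as follows; `run_sync` is an illustrative helper, not a PyAgenity API:

```python
import asyncio


def run_sync(coro):
    """Run a coroutine from sync code, failing fast inside a running loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)  # no loop running: safe to drive one ourselves
    coro.close()  # avoid a "coroutine was never awaited" warning
    raise RuntimeError("already inside an event loop; use the async API instead")


async def work() -> int:
    return 42


result = run_sync(work())  # sync context: works


async def wrong():
    try:
        run_sync(work())  # async context: rejected with a clear error
    except RuntimeError as e:
        return str(e)


error = asyncio.run(wrong())
```

This is the same reason `invoke()` directs async callers to `ainvoke()` instead of attempting to nest event loops.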
stop
stop(config)

Request the current graph execution to stop (sync helper).

This sets a stop flag in the checkpointer's thread store keyed by thread_id. Handlers periodically check this flag and interrupt execution. Returns a small status dict.

Source code in pyagenity/graph/compiled_graph.py
def stop(self, config: dict[str, Any]) -> dict[str, Any]:
    """Request the current graph execution to stop (sync helper).

    This sets a stop flag in the checkpointer's thread store keyed by thread_id.
    Handlers periodically check this flag and interrupt execution.
    Returns a small status dict.
    """
    return asyncio.run(self.astop(config))
stream
stream(input_data, config=None, response_granularity=ResponseGranularity.LOW)

Execute the graph synchronously with streaming support.

Yields Message objects containing incremental responses. If nodes return streaming responses, yields them directly. If nodes return complete responses, simulates streaming by chunking.

Parameters:

Name Type Description Default
input_data dict[str, Any]

Input dict

required
config dict[str, Any] | None

Configuration dictionary

None
response_granularity ResponseGranularity

Response parsing granularity

LOW

Yields:

Type Description
Generator[Message]

Message objects with incremental content

Source code in pyagenity/graph/compiled_graph.py
def stream(
    self,
    input_data: dict[str, Any],
    config: dict[str, Any] | None = None,
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
) -> Generator[Message]:
    """Execute the graph synchronously with streaming support.

    Yields Message objects containing incremental responses.
    If nodes return streaming responses, yields them directly.
    If nodes return complete responses, simulates streaming by chunking.

    Args:
        input_data: Input dict
        config: Configuration dictionary
        response_granularity: Response parsing granularity

    Yields:
        Message objects with incremental content
    """

    # For sync streaming, we'll use asyncio.run to handle the async implementation
    async def _async_stream():
        async for chunk in self.astream(input_data, config, response_granularity):
            yield chunk

    # Convert async generator to sync iteration with a dedicated event loop
    gen = _async_stream()
    loop = asyncio.new_event_loop()
    policy = asyncio.get_event_loop_policy()
    try:
        previous_loop = policy.get_event_loop()
    except Exception:
        previous_loop = None
    asyncio.set_event_loop(loop)
    logger.info("Synchronous streaming started")

    try:
        while True:
            try:
                chunk = loop.run_until_complete(gen.__anext__())
                yield chunk
            except StopAsyncIteration:
                break
    finally:
        # Attempt to close the async generator cleanly
        with contextlib.suppress(Exception):
            loop.run_until_complete(gen.aclose())  # type: ignore[attr-defined]
        # Restore previous loop if any, then close created loop
        try:
            if previous_loop is not None:
                asyncio.set_event_loop(previous_loop)
        finally:
            loop.close()
    logger.info("Synchronous streaming completed")
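The core of the sync `stream()` implementation is driving an async generator one item at a time on a private event loop. Stripped of the loop-policy bookkeeping, the pattern looks like this (the `numbers` generator is a toy stand-in for `astream()`):

```python
import asyncio


async def numbers():
    """Toy async generator standing in for astream()."""
    for i in range(3):
        await asyncio.sleep(0)
        yield i


def to_sync(agen):
    """Drive an async generator step-by-step on a private event loop."""
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                # __anext__() returns a coroutine; run it to get the next item
                yield loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.run_until_complete(agen.aclose())  # close the generator cleanly
        loop.close()


collected = list(to_sync(numbers()))
```

Each `run_until_complete` call advances the generator exactly one yield, which is why the sync consumer sees items as soon as they are produced rather than after the whole stream finishes.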
Edge

Represents a connection between two nodes in a graph workflow.

An Edge defines the relationship and routing logic between nodes, specifying how execution should flow from one node to another. Edges can be either static (unconditional) or conditional based on runtime state evaluation.

Edges support complex routing scenarios including:

- Simple static connections between nodes
- Conditional routing based on state evaluation
- Dynamic routing with multiple possible destinations
- Decision trees and branching logic

Attributes:

Name Type Description
from_node

Name of the source node where execution originates.

to_node

Name of the destination node where execution continues.

condition

Optional callable that determines if this edge should be followed. If None, the edge is always followed (static edge).

condition_result str | None

Optional value to match against condition result for mapped conditional edges.

Example
# Static edge - always followed
static_edge = Edge("start", "process")


# Conditional edge - followed only if condition returns True
def needs_approval(state):
    return state.data.get("requires_approval", False)


conditional_edge = Edge("process", "approval", condition=needs_approval)


# Mapped conditional edge - follows based on specific condition result
def get_priority(state):
    return state.data.get("priority", "normal")


high_priority_edge = Edge("triage", "urgent", condition=get_priority)
high_priority_edge.condition_result = "high"

Methods:

Name Description
__init__

Initialize a new Edge with source, destination, and optional condition.

Source code in pyagenity/graph/edge.py
class Edge:
    """Represents a connection between two nodes in a graph workflow.

    An Edge defines the relationship and routing logic between nodes, specifying
    how execution should flow from one node to another. Edges can be either
    static (unconditional) or conditional based on runtime state evaluation.

    Edges support complex routing scenarios including:
    - Simple static connections between nodes
    - Conditional routing based on state evaluation
    - Dynamic routing with multiple possible destinations
    - Decision trees and branching logic

    Attributes:
        from_node: Name of the source node where execution originates.
        to_node: Name of the destination node where execution continues.
        condition: Optional callable that determines if this edge should be
            followed. If None, the edge is always followed (static edge).
        condition_result: Optional value to match against condition result
            for mapped conditional edges.

    Example:
        ```python
        # Static edge - always followed
        static_edge = Edge("start", "process")


        # Conditional edge - followed only if condition returns True
        def needs_approval(state):
            return state.data.get("requires_approval", False)


        conditional_edge = Edge("process", "approval", condition=needs_approval)


        # Mapped conditional edge - follows based on specific condition result
        def get_priority(state):
            return state.data.get("priority", "normal")


        high_priority_edge = Edge("triage", "urgent", condition=get_priority)
        high_priority_edge.condition_result = "high"
        ```
    """

    def __init__(
        self,
        from_node: str,
        to_node: str,
        condition: Callable | None = None,
    ):
        """Initialize a new Edge with source, destination, and optional condition.

        Args:
            from_node: Name of the source node. Must match a node name in the graph.
            to_node: Name of the destination node. Must match a node name in the graph
                or be a special constant like END.
            condition: Optional callable that takes an AgentState as argument and
                returns a value to determine if this edge should be followed.
                If None, this is a static edge that's always followed.

        Note:
            The condition function should be deterministic and side-effect free
            for predictable execution behavior. It receives the current AgentState
            and should return a boolean (for simple conditions) or a string/value
            (for mapped conditional routing).
        """
        logger.debug(
            "Creating edge from '%s' to '%s' with condition=%s",
            from_node,
            to_node,
            "yes" if condition else "no",
        )
        self.from_node = from_node
        self.to_node = to_node
        self.condition = condition
        self.condition_result: str | None = None
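The three edge kinds (static, boolean-conditional, and mapped-conditional) imply a resolution rule when picking the next node. A runnable sketch of that rule, using a minimal `Edge` stand-in that mirrors the attributes above and an illustrative `next_node` helper (not a library function):

```python
class Edge:
    """Minimal stand-in mirroring the attributes documented above."""

    def __init__(self, from_node, to_node, condition=None):
        self.from_node = from_node
        self.to_node = to_node
        self.condition = condition
        self.condition_result = None


def next_node(edges, current, state):
    """Resolve the outgoing edge for `current`.

    Static edges always match; boolean-conditional edges match when the
    condition is truthy; mapped edges match when the condition's return
    value equals their condition_result.
    """
    for edge in edges:
        if edge.from_node != current:
            continue
        if edge.condition is None:  # static edge
            return edge.to_node
        result = edge.condition(state)
        if edge.condition_result is not None:  # mapped conditional edge
            if result == edge.condition_result:
                return edge.to_node
        elif result:  # simple boolean condition
            return edge.to_node
    return None


high = Edge("triage", "urgent", condition=lambda s: s["priority"])
high.condition_result = "high"
normal = Edge("triage", "queue", condition=lambda s: s["priority"])
normal.condition_result = "normal"
edges = [high, normal]

dest = next_node(edges, "triage", {"priority": "high"})
```

Mapped edges make a single condition function fan out to many destinations: each edge carries the `condition_result` it claims, and the first matching edge wins.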
Attributes
condition instance-attribute
condition = condition
condition_result instance-attribute
condition_result = None
from_node instance-attribute
from_node = from_node
to_node instance-attribute
to_node = to_node
Functions
__init__
__init__(from_node, to_node, condition=None)

Initialize a new Edge with source, destination, and optional condition.

Parameters:

Name Type Description Default
from_node str

Name of the source node. Must match a node name in the graph.

required
to_node str

Name of the destination node. Must match a node name in the graph or be a special constant like END.

required
condition Callable | None

Optional callable that takes an AgentState as argument and returns a value to determine if this edge should be followed. If None, this is a static edge that's always followed.

None
Note

The condition function should be deterministic and side-effect free for predictable execution behavior. It receives the current AgentState and should return a boolean (for simple conditions) or a string/value (for mapped conditional routing).

Source code in pyagenity/graph/edge.py
def __init__(
    self,
    from_node: str,
    to_node: str,
    condition: Callable | None = None,
):
    """Initialize a new Edge with source, destination, and optional condition.

    Args:
        from_node: Name of the source node. Must match a node name in the graph.
        to_node: Name of the destination node. Must match a node name in the graph
            or be a special constant like END.
        condition: Optional callable that takes an AgentState as argument and
            returns a value to determine if this edge should be followed.
            If None, this is a static edge that's always followed.

    Note:
        The condition function should be deterministic and side-effect free
        for predictable execution behavior. It receives the current AgentState
        and should return a boolean (for simple conditions) or a string/value
        (for mapped conditional routing).
    """
    logger.debug(
        "Creating edge from '%s' to '%s' with condition=%s",
        from_node,
        to_node,
        "yes" if condition else "no",
    )
    self.from_node = from_node
    self.to_node = to_node
    self.condition = condition
    self.condition_result: str | None = None
Node

Represents a node in the graph workflow.

A Node encapsulates a function or ToolNode that can be executed as part of a graph workflow. It handles dependency injection, parameter mapping, and execution context management.

The Node class supports both regular callable functions and ToolNode instances for handling tool-based operations. It automatically injects dependencies based on function signatures and provides legacy parameter support.

Attributes:

Name Type Description
name str

Unique identifier for the node within the graph.

func Union[Callable, ToolNode]

The function or ToolNode to execute.

Example

def my_function(state, config):
    return {"result": "processed"}


node = Node("processor", my_function)
result = await node.execute(state, config)

Methods:

Name Description
__init__

Initialize a new Node instance with function and dependencies.

execute

Execute the node function with comprehensive context and callback support.

stream

Execute the node function with streaming output support.

Source code in pyagenity/graph/node.py
class Node:
    """Represents a node in the graph workflow.

    A Node encapsulates a function or ToolNode that can be executed as part of
    a graph workflow. It handles dependency injection, parameter mapping, and
    execution context management.

    The Node class supports both regular callable functions and ToolNode instances
    for handling tool-based operations. It automatically injects dependencies
    based on function signatures and provides legacy parameter support.

    Attributes:
        name (str): Unique identifier for the node within the graph.
        func (Union[Callable, ToolNode]): The function or ToolNode to execute.

    Example:
        >>> def my_function(state, config):
        ...     return {"result": "processed"}
        >>> node = Node("processor", my_function)
        >>> result = await node.execute(state, config)
    """

    def __init__(
        self,
        name: str,
        func: Union[Callable, "ToolNode"],
        publisher: BasePublisher | None = Inject[BasePublisher],
    ):
        """Initialize a new Node instance with function and dependencies.

        Args:
            name: Unique identifier for the node within the graph. This name
                is used for routing, logging, and referencing the node in
                graph configuration.
            func: The function or ToolNode to execute when this node is called.
                Functions should accept at least 'state' and 'config' parameters.
                ToolNode instances handle tool-based operations and provide
                their own execution logic.
            publisher: Optional event publisher for execution monitoring.
                Injected via dependency injection if not explicitly provided.
                Used for publishing node execution events and status updates.

        Note:
            The function signature is automatically analyzed to determine
            required parameters and dependency injection points. Parameters
            matching injectable service names will be automatically provided
            by the framework during execution.
        """
        logger.debug(
            "Initializing node '%s' with func=%s",
            name,
            getattr(func, "__name__", type(func).__name__),
        )
        self.name = name
        self.func = func
        self.publisher = publisher
        self.invoke_handler = InvokeNodeHandler(
            name,
            func,
        )

        self.stream_handler = StreamNodeHandler(
            name,
            func,
        )

    async def execute(
        self,
        config: dict[str, Any],
        state: AgentState,
        callback_mgr: CallbackManager = Inject[CallbackManager],
    ) -> dict[str, Any] | list[Message]:
        """Execute the node function with comprehensive context and callback support.

        Executes the node's function or ToolNode with full dependency injection,
        callback hook execution, and error handling. This method provides the
        complete execution environment including state access, configuration,
        and injected services.

        Args:
            config: Configuration dictionary containing execution context,
                user settings, thread identification, and runtime parameters.
            state: Current AgentState providing workflow context, message history,
                and shared state information accessible to the node function.
            callback_mgr: Callback manager for executing pre/post execution hooks.
                Injected via dependency injection if not explicitly provided.

        Returns:
            Either a dictionary containing updated state and execution results,
            or a list of Message objects representing the node's output.
            The return type depends on the node function's implementation.

        Raises:
            Various exceptions depending on node function behavior. All exceptions
            are handled by the callback manager's error handling hooks before
            being propagated.

        Example:
            ```python
            # Node function that returns messages
            def process_data(state, config):
                result = process(state.data)
                return [Message.text_message(f"Processed: {result}")]


            node = Node("processor", process_data)
            messages = await node.execute(config, state)
            ```

        Note:
            The node function receives dependency-injected parameters based on
            its signature. Common injectable parameters include 'state', 'config',
            'context_manager', 'publisher', and other framework services.
        """
        return await self.invoke_handler.invoke(
            config,
            state,
            callback_mgr,
        )

    async def stream(
        self,
        config: dict[str, Any],
        state: AgentState,
        callback_mgr: CallbackManager = Inject[CallbackManager],
    ) -> AsyncIterable[dict[str, Any] | Message]:
        """Execute the node function with streaming output support.

        Similar to execute() but designed for streaming scenarios where the node
        function can produce incremental results. This method provides an async
        iterator interface over the node's outputs, allowing for real-time
        processing and response streaming.

        Args:
            config: Configuration dictionary with execution context and settings.
            state: Current AgentState providing workflow context and shared state.
            callback_mgr: Callback manager for pre/post execution hook handling.

        Yields:
            Dictionary objects or Message instances representing incremental
            outputs from the node function. The exact type and frequency of
            yields depend on the node function's streaming implementation.

        Example:
            ```python
            async def streaming_processor(state, config):
                for item in large_dataset:
                    result = process_item(item)
                    yield Message.text_message(f"Processed item: {result}")


            node = Node("stream_processor", streaming_processor)
            async for output in node.stream(config, state):
                print(f"Streamed: {output.content}")
            ```

        Note:
            Not all node functions support streaming. For non-streaming functions,
            this method will yield a single result equivalent to calling execute().
            The streaming capability is determined by the node function's implementation.
        """
        result = self.stream_handler.stream(
            config,
            state,
            callback_mgr,
        )

        async for item in result:
            yield item
Attributes
func instance-attribute
func = func
invoke_handler instance-attribute
invoke_handler = InvokeNodeHandler(name, func)
name instance-attribute
name = name
publisher instance-attribute
publisher = publisher
stream_handler instance-attribute
stream_handler = StreamNodeHandler(name, func)
Functions
__init__
__init__(name, func, publisher=Inject[BasePublisher])

Initialize a new Node instance with function and dependencies.

Parameters:

Name Type Description Default
name str

Unique identifier for the node within the graph. This name is used for routing, logging, and referencing the node in graph configuration.

required
func Union[Callable, ToolNode]

The function or ToolNode to execute when this node is called. Functions should accept at least 'state' and 'config' parameters. ToolNode instances handle tool-based operations and provide their own execution logic.

required
publisher BasePublisher | None

Optional event publisher for execution monitoring. Injected via dependency injection if not explicitly provided. Used for publishing node execution events and status updates.

Inject[BasePublisher]
Note

The function signature is automatically analyzed to determine required parameters and dependency injection points. Parameters matching injectable service names will be automatically provided by the framework during execution.
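The signature analysis described above can be illustrated with the standard library's `inspect` module. This is a minimal sketch of the idea, not PyAgenity's actual analysis code, and the `INJECTABLE` set is a hypothetical stand-in for the framework's injectable service names:

```python
import inspect


def example_node(state, config, publisher=None, context_manager=None):
    """A node function whose extra parameters can be dependency-injected."""
    return state


# Hypothetical set of injectable service names; the framework matches
# parameter names in the node function's signature against names like these.
INJECTABLE = {"publisher", "context_manager", "checkpointer", "store"}

params = list(inspect.signature(example_node).parameters)
injected = [p for p in params if p in INJECTABLE]

print(params)    # ['state', 'config', 'publisher', 'context_manager']
print(injected)  # ['publisher', 'context_manager']
```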

Source code in pyagenity/graph/node.py
def __init__(
    self,
    name: str,
    func: Union[Callable, "ToolNode"],
    publisher: BasePublisher | None = Inject[BasePublisher],
):
    """Initialize a new Node instance with function and dependencies.

    Args:
        name: Unique identifier for the node within the graph. This name
            is used for routing, logging, and referencing the node in
            graph configuration.
        func: The function or ToolNode to execute when this node is called.
            Functions should accept at least 'state' and 'config' parameters.
            ToolNode instances handle tool-based operations and provide
            their own execution logic.
        publisher: Optional event publisher for execution monitoring.
            Injected via dependency injection if not explicitly provided.
            Used for publishing node execution events and status updates.

    Note:
        The function signature is automatically analyzed to determine
        required parameters and dependency injection points. Parameters
        matching injectable service names will be automatically provided
        by the framework during execution.
    """
    logger.debug(
        "Initializing node '%s' with func=%s",
        name,
        getattr(func, "__name__", type(func).__name__),
    )
    self.name = name
    self.func = func
    self.publisher = publisher
    self.invoke_handler = InvokeNodeHandler(
        name,
        func,
    )

    self.stream_handler = StreamNodeHandler(
        name,
        func,
    )
execute async
execute(config, state, callback_mgr=Inject[CallbackManager])

Execute the node function with comprehensive context and callback support.

Executes the node's function or ToolNode with full dependency injection, callback hook execution, and error handling. This method provides the complete execution environment including state access, configuration, and injected services.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dictionary containing execution context, user settings, thread identification, and runtime parameters.

required
state AgentState

Current AgentState providing workflow context, message history, and shared state information accessible to the node function.

required
callback_mgr CallbackManager

Callback manager for executing pre/post execution hooks. Injected via dependency injection if not explicitly provided.

Inject[CallbackManager]

Returns:

Type Description
dict[str, Any] | list[Message]

Either a dictionary containing updated state and execution results, or a list of Message objects representing the node's output. The return type depends on the node function's implementation.
Example
# Node function that returns messages
def process_data(state, config):
    result = process(state.data)
    return [Message.text_message(f"Processed: {result}")]


node = Node("processor", process_data)
messages = await node.execute(config, state)
Note

The node function receives dependency-injected parameters based on its signature. Common injectable parameters include 'state', 'config', 'context_manager', 'publisher', and other framework services.
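Because execute() may return either a dict or a list of Message objects, callers often normalize the result before use. A minimal, self-contained sketch of that pattern; `SimpleMessage` and `normalize_result` are hypothetical stand-ins, not PyAgenity APIs:

```python
from dataclasses import dataclass


@dataclass
class SimpleMessage:
    """Hypothetical stand-in for pyagenity's Message class."""

    content: str


def normalize_result(result):
    """Return a list of messages regardless of which form the node produced."""
    if isinstance(result, dict):
        # Dict results carry state updates; pull out any messages they include.
        return result.get("messages", [])
    return result  # already a list of messages


msgs = normalize_result({"messages": [SimpleMessage("hi")], "count": 1})
print(msgs[0].content)  # hi
```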

Source code in pyagenity/graph/node.py
async def execute(
    self,
    config: dict[str, Any],
    state: AgentState,
    callback_mgr: CallbackManager = Inject[CallbackManager],
) -> dict[str, Any] | list[Message]:
    """Execute the node function with comprehensive context and callback support.

    Executes the node's function or ToolNode with full dependency injection,
    callback hook execution, and error handling. This method provides the
    complete execution environment including state access, configuration,
    and injected services.

    Args:
        config: Configuration dictionary containing execution context,
            user settings, thread identification, and runtime parameters.
        state: Current AgentState providing workflow context, message history,
            and shared state information accessible to the node function.
        callback_mgr: Callback manager for executing pre/post execution hooks.
            Injected via dependency injection if not explicitly provided.

    Returns:
        Either a dictionary containing updated state and execution results,
        or a list of Message objects representing the node's output.
        The return type depends on the node function's implementation.

    Raises:
        Various exceptions depending on node function behavior. All exceptions
        are handled by the callback manager's error handling hooks before
        being propagated.

    Example:
        ```python
        # Node function that returns messages
        def process_data(state, config):
            result = process(state.data)
            return [Message.text_message(f"Processed: {result}")]


        node = Node("processor", process_data)
        messages = await node.execute(config, state)
        ```

    Note:
        The node function receives dependency-injected parameters based on
        its signature. Common injectable parameters include 'state', 'config',
        'context_manager', 'publisher', and other framework services.
    """
    return await self.invoke_handler.invoke(
        config,
        state,
        callback_mgr,
    )
stream async
stream(config, state, callback_mgr=Inject[CallbackManager])

Execute the node function with streaming output support.

Similar to execute() but designed for streaming scenarios where the node function can produce incremental results. This method provides an async iterator interface over the node's outputs, allowing for real-time processing and response streaming.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dictionary with execution context and settings.

required
state AgentState

Current AgentState providing workflow context and shared state.

required
callback_mgr CallbackManager

Callback manager for pre/post execution hook handling.

Inject[CallbackManager]

Yields:

Type Description
AsyncIterable[dict[str, Any] | Message]

Dictionary objects or Message instances representing incremental outputs from the node function. The exact type and frequency of yields depend on the node function's streaming implementation.

Example
async def streaming_processor(state, config):
    for item in large_dataset:
        result = process_item(item)
        yield Message.text_message(f"Processed item: {result}")


node = Node("stream_processor", streaming_processor)
async for output in node.stream(config, state):
    print(f"Streamed: {output.content}")
Note

Not all node functions support streaming. For non-streaming functions, this method will yield a single result equivalent to calling execute(). The streaming capability is determined by the node function's implementation.
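The fallback described above — exposing a non-streaming result through the same async-iterator interface as a real stream — can be sketched in plain asyncio. `as_stream` is a hypothetical helper for illustration, not a PyAgenity API:

```python
import asyncio
import inspect


async def as_stream(func, state, config):
    """Yield from a streaming node, or yield a plain node's result once."""
    result = func(state, config)
    if inspect.isasyncgen(result):
        async for item in result:
            yield item
    else:
        # Non-streaming function: a single yield, equivalent to execute().
        yield result


def plain_node(state, config):
    return {"done": True}


async def main():
    return [item async for item in as_stream(plain_node, {}, {})]


print(asyncio.run(main()))  # [{'done': True}]
```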

Source code in pyagenity/graph/node.py
async def stream(
    self,
    config: dict[str, Any],
    state: AgentState,
    callback_mgr: CallbackManager = Inject[CallbackManager],
) -> AsyncIterable[dict[str, Any] | Message]:
    """Execute the node function with streaming output support.

    Similar to execute() but designed for streaming scenarios where the node
    function can produce incremental results. This method provides an async
    iterator interface over the node's outputs, allowing for real-time
    processing and response streaming.

    Args:
        config: Configuration dictionary with execution context and settings.
        state: Current AgentState providing workflow context and shared state.
        callback_mgr: Callback manager for pre/post execution hook handling.

    Yields:
        Dictionary objects or Message instances representing incremental
        outputs from the node function. The exact type and frequency of
        yields depend on the node function's streaming implementation.

    Example:
        ```python
        async def streaming_processor(state, config):
            for item in large_dataset:
                result = process_item(item)
                yield Message.text_message(f"Processed item: {result}")


        node = Node("stream_processor", streaming_processor)
        async for output in node.stream(config, state):
            print(f"Streamed: {output.content}")
        ```

    Note:
        Not all node functions support streaming. For non-streaming functions,
        this method will yield a single result equivalent to calling execute().
        The streaming capability is determined by the node function's implementation.
    """
    result = self.stream_handler.stream(
        config,
        state,
        callback_mgr,
    )

    async for item in result:
        yield item
StateGraph

Main graph class for orchestrating multi-agent workflows.

This class provides the core functionality for building and managing stateful agent workflows. It is similar to LangGraph's StateGraph, with added support for dependency injection.

The graph is generic over state types to support custom AgentState subclasses, allowing for type-safe state management throughout the workflow execution.

Attributes:

Name Type Description
state StateT

The current state of the graph workflow.

nodes dict[str, Node]

Collection of nodes in the graph.

edges list[Edge]

Collection of edges connecting nodes.

entry_point str | None

Name of the starting node for execution.

context_manager BaseContextManager[StateT] | None

Optional context manager for handling cross-node state operations.

dependency_container DependencyContainer

Container for managing dependencies that can be injected into node functions.

compiled bool

Whether the graph has been compiled for execution.

Example

graph = StateGraph()
graph.add_node("process", process_function)
graph.add_edge(START, "process")
graph.add_edge("process", END)
compiled = graph.compile()
result = compiled.invoke({"input": "data"})

Methods:

Name Description
__init__

Initialize a new StateGraph instance.

add_conditional_edges

Add conditional routing between nodes based on runtime evaluation.

add_edge

Add a static edge between two nodes.

add_node

Add a node to the graph.

compile

Compile the graph for execution.

set_entry_point

Set the entry point for the graph.
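The add_conditional_edges method listed above routes on a condition function's return value, optionally mapped through a path_map. The routing semantics can be shown with a self-contained sketch; the state is a plain dict here rather than an AgentState, and all names are illustrative, not PyAgenity APIs:

```python
# Illustrative sketch of conditional-edge routing: the condition's return
# value is looked up in path_map to select the next node's name.
def get_category(state):
    return state.get("category", "default")


path_map = {
    "finance": "finance_processor",
    "legal": "legal_processor",
    "default": "general_processor",
}


def route(state):
    return path_map[get_category(state)]


print(route({"category": "legal"}))  # legal_processor
print(route({}))                     # general_processor
```

Note that the condition's return values must match the path_map keys exactly, which is why the docstrings recommend deterministic, side-effect-free condition functions.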

Source code in pyagenity/graph/state_graph.py
class StateGraph[StateT: AgentState]:
    """Main graph class for orchestrating multi-agent workflows.

    This class provides the core functionality for building and managing stateful
    agent workflows. It is similar to LangGraph's StateGraph, with added
    support for dependency injection.

    The graph is generic over state types to support custom AgentState subclasses,
    allowing for type-safe state management throughout the workflow execution.

    Attributes:
        state (StateT): The current state of the graph workflow.
        nodes (dict[str, Node]): Collection of nodes in the graph.
        edges (list[Edge]): Collection of edges connecting nodes.
        entry_point (str | None): Name of the starting node for execution.
        context_manager (BaseContextManager[StateT] | None): Optional context manager
            for handling cross-node state operations.
        dependency_container (DependencyContainer): Container for managing
            dependencies that can be injected into node functions.
        compiled (bool): Whether the graph has been compiled for execution.

    Example:
        >>> graph = StateGraph()
        >>> graph.add_node("process", process_function)
        >>> graph.add_edge(START, "process")
        >>> graph.add_edge("process", END)
        >>> compiled = graph.compile()
        >>> result = compiled.invoke({"input": "data"})
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
        thread_name_generator: Callable[[], str] | None = None,
    ):
        """Initialize a new StateGraph instance.

        Args:
            state: Initial state for the graph. If None, a default AgentState
                will be created.
            context_manager: Optional context manager for handling cross-node
                state operations and advanced state management patterns.
            publisher: Optional publisher for emitting events during execution.
            id_generator: Generator for message and thread identifiers.
                Defaults to DefaultIDGenerator.
            container: Dependency injection container. If None, the global
                InjectQ singleton instance is used.
            thread_name_generator: Optional callable returning generated thread
                names. If None, a default generator is used.

        Note:
            START and END nodes are automatically added to the graph upon
            initialization and accept the full node signature including
            dependencies.

        Example:
            # Basic usage with default AgentState
            >>> graph = StateGraph()

            # With custom state
            >>> custom_state = MyCustomState()
            >>> graph = StateGraph(custom_state)

            # Or using type hints for clarity
            >>> graph = StateGraph[MyCustomState](MyCustomState())
        """
        logger.info("Initializing StateGraph")
        logger.debug(
            "StateGraph init with state=%s, context_manager=%s",
            type(state).__name__ if state else "default AgentState",
            type(context_manager).__name__ if context_manager else None,
        )

        # State handling
        self._state: StateT = state if state else AgentState()  # type: ignore[assignment]

        # Graph structure
        self.nodes: dict[str, Node] = {}
        self.edges: list[Edge] = []
        self.entry_point: str | None = None

        # Services
        self._publisher: BasePublisher | None = publisher
        self._id_generator: BaseIDGenerator = id_generator
        self._context_manager: BaseContextManager[StateT] | None = context_manager
        self.thread_name_generator = thread_name_generator
        # save container for dependency injection
        # if any container is passed then we will activate that
        # otherwise we can skip it and use the default one
        if container is None:
            self._container = InjectQ.get_instance()
            logger.debug("No container provided, using global singleton instance")
        else:
            logger.debug("Using provided dependency container instance")
            self._container = container
            self._container.activate()

        # Register task_manager, for async tasks
        # This will be used to run background tasks
        self._task_manager = BackgroundTaskManager()

        # now setup the graph
        self._setup()

        # Add START and END nodes (accept full node signature including dependencies)
        logger.debug("Adding default START and END nodes")
        self.nodes[START] = Node(START, lambda state, config, **deps: state, self._publisher)  # type: ignore
        self.nodes[END] = Node(END, lambda state, config, **deps: state, self._publisher)
        logger.debug("StateGraph initialized with %d nodes", len(self.nodes))

    def _setup(self):
        """Setup the graph before compilation.

        This method can be used to perform any necessary setup or validation
        before compiling the graph for execution.
        """
        logger.info("Setting up StateGraph before compilation")
        # Placeholder for any setup logic needed before compilation
        # register dependencies

        # register state and context manager as singletons (these are nullable)
        self._container.bind_instance(
            BaseContextManager,
            self._context_manager,
            allow_none=True,
            allow_concrete=True,
        )
        self._container.bind_instance(
            BasePublisher,
            self._publisher,
            allow_none=True,
            allow_concrete=True,
        )

        # register id generator as factory
        self._container.bind_instance(
            BaseIDGenerator,
            self._id_generator,
            allow_concrete=True,
        )
        self._container.bind("generated_id_type", self._id_generator.id_type)
        # Allow async method also
        self._container.bind_factory(
            "generated_id",
            lambda: self._id_generator.generate(),
        )

        # Attach Thread name generator if provided
        if self.thread_name_generator is None:
            self.thread_name_generator = generate_dummy_thread_name

        generator = self.thread_name_generator or generate_dummy_thread_name

        self._container.bind_factory(
            "generated_thread_name",
            lambda: generator(),
        )

        # Save BackgroundTaskManager
        self._container.bind_instance(
            BackgroundTaskManager,
            self._task_manager,
            allow_concrete=False,
        )

    def add_node(
        self,
        name_or_func: str | Callable,
        func: Union[Callable, "ToolNode", None] = None,
    ) -> "StateGraph":
        """Add a node to the graph.

        This method supports two calling patterns:
        1. Pass a callable as the first argument (name inferred from function name)
        2. Pass a name string and callable/ToolNode as separate arguments

        Args:
            name_or_func: Either the node name (str) or a callable function.
                If callable, the function name will be used as the node name.
            func: The function or ToolNode to execute. Required if name_or_func
                is a string, ignored if name_or_func is callable.

        Returns:
            StateGraph: The graph instance for method chaining.

        Raises:
            ValueError: If invalid arguments are provided.

        Example:
            >>> # Method 1: Function name inferred
            >>> graph.add_node(my_function)
            >>> # Method 2: Explicit name and function
            >>> graph.add_node("process", my_function)
        """
        if callable(name_or_func) and func is None:
            # Function passed as first argument
            name = name_or_func.__name__
            func = name_or_func
            logger.debug("Adding node '%s' with inferred name from function", name)
        elif isinstance(name_or_func, str) and (callable(func) or isinstance(func, ToolNode)):
            # Name and function passed separately
            name = name_or_func
            logger.debug(
                "Adding node '%s' with explicit name and %s",
                name,
                "ToolNode" if isinstance(func, ToolNode) else "callable",
            )
        else:
            error_msg = "Invalid arguments for add_node"
            logger.error(error_msg)
            raise ValueError(error_msg)

        self.nodes[name] = Node(name, func)
        logger.info("Added node '%s' to graph (total nodes: %d)", name, len(self.nodes))
        return self

    def add_edge(
        self,
        from_node: str,
        to_node: str,
    ) -> "StateGraph":
        """Add a static edge between two nodes.

        Creates a direct connection from one node to another. If the source
        node is START, the target node becomes the entry point for the graph.

        Args:
            from_node: Name of the source node.
            to_node: Name of the target node.

        Returns:
            StateGraph: The graph instance for method chaining.

        Example:
            >>> graph.add_edge("node1", "node2")
            >>> graph.add_edge(START, "entry_node")  # Sets entry point
        """
        logger.debug("Adding edge from '%s' to '%s'", from_node, to_node)
        # Set entry point if edge is from START
        if from_node == START:
            self.entry_point = to_node
            logger.info("Set entry point to '%s'", to_node)
        self.edges.append(Edge(from_node, to_node))
        logger.debug("Added edge (total edges: %d)", len(self.edges))
        return self

    def add_conditional_edges(
        self,
        from_node: str,
        condition: Callable,
        path_map: dict[str, str] | None = None,
    ) -> "StateGraph":
        """Add conditional routing between nodes based on runtime evaluation.

        Creates dynamic routing logic where the next node is determined by evaluating
        a condition function against the current state. This enables complex branching
        logic, decision trees, and adaptive workflow routing.

        Args:
            from_node: Name of the source node where the condition is evaluated.
            condition: Callable function that takes the current AgentState and returns
                a value used for routing decisions. Should be deterministic and
                side-effect free.
            path_map: Optional dictionary mapping condition results to destination nodes.
                If provided, the condition's return value is looked up in this mapping.
                If None, the condition should return the destination node name directly.

        Returns:
            StateGraph: The graph instance for method chaining.

        Raises:
            ValueError: If the condition function or path_map configuration is invalid.

        Example:
            ```python
            # Direct routing - condition returns node name
            def route_by_priority(state):
                priority = state.data.get("priority", "normal")
                return "urgent_handler" if priority == "high" else "normal_handler"


            graph.add_conditional_edges("classifier", route_by_priority)


            # Mapped routing - condition result mapped to nodes
            def get_category(state):
                return state.data.get("category", "default")


            category_map = {
                "finance": "finance_processor",
                "legal": "legal_processor",
                "default": "general_processor",
            }
            graph.add_conditional_edges("categorizer", get_category, category_map)
            ```

        Note:
            The condition function receives the current AgentState and should return
            consistent results for the same state. If using path_map, ensure the
            condition's return values match the map keys exactly.
        """
        """Add conditional edges from a node based on a condition function.

        Creates edges that are traversed based on the result of a condition
        function. The condition function receives the current state and should
        return a value that determines which edge to follow.

        Args:
            from_node: Name of the source node.
            condition: Function that evaluates the current state and returns
                a value to determine the next node.
            path_map: Optional mapping from condition results to target nodes.
                If provided, creates multiple conditional edges. If None,
                creates a single conditional edge.

        Returns:
            StateGraph: The graph instance for method chaining.

        Example:
            >>> def route_condition(state):
            ...     return "success" if state.success else "failure"
            >>> graph.add_conditional_edges(
            ...     "processor",
            ...     route_condition,
            ...     {"success": "next_step", "failure": "error_handler"},
            ... )
        """
        # Create edges based on possible returns from condition function
        logger.debug(
            "Node '%s' adding conditional edges with path_map: %s",
            from_node,
            path_map,
        )
        if path_map:
            logger.debug(
                "Node '%s' adding conditional edges with path_map: %s", from_node, path_map
            )
            for condition_result, target_node in path_map.items():
                edge = Edge(from_node, target_node, condition)
                edge.condition_result = condition_result
                self.edges.append(edge)
        else:
            # Single conditional edge
            logger.debug("Node '%s' adding single conditional edge", from_node)
            self.edges.append(Edge(from_node, "", condition))
        return self

    def set_entry_point(self, node_name: str) -> "StateGraph":
        """Set the entry point for the graph."""
        self.entry_point = node_name
        self.add_edge(START, node_name)
        logger.info("Set entry point to '%s'", node_name)
        return self

    def compile(
        self,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> "CompiledGraph[StateT]":
        """Compile the graph for execution.

        Args:
            checkpointer: Checkpointer for state persistence. If None, an
                InMemoryCheckpointer is used.
            store: Optional store for additional data.
            interrupt_before: List of node names to interrupt before execution.
            interrupt_after: List of node names to interrupt after execution.
            callback_manager: Callback manager for executing hooks.
        """
        logger.info(
            "Compiling graph with %d nodes, %d edges, entry_point='%s'",
            len(self.nodes),
            len(self.edges),
            self.entry_point,
        )
        logger.debug(
            "Compile options: interrupt_before=%s, interrupt_after=%s",
            interrupt_before,
            interrupt_after,
        )

        if not self.entry_point:
            error_msg = "No entry point set. Use set_entry_point() or add an edge from START."
            logger.error(error_msg)
            raise GraphError(error_msg)

        # Validate graph structure
        logger.debug("Validating graph structure")
        self._validate_graph()
        logger.debug("Graph structure validated successfully")

        # Validate interrupt node names
        interrupt_before = interrupt_before or []
        interrupt_after = interrupt_after or []

        all_interrupt_nodes = set(interrupt_before + interrupt_after)
        invalid_nodes = all_interrupt_nodes - set(self.nodes.keys())
        if invalid_nodes:
            error_msg = f"Invalid interrupt nodes: {invalid_nodes}. Must be existing node names."
            logger.error(error_msg)
            raise GraphError(error_msg)

        self.compiled = True
        logger.info("Graph compilation completed successfully")
        # Import here to avoid circular import at module import time
        # Now update Checkpointer
        if checkpointer is None:
            from pyagenity.checkpointer import InMemoryCheckpointer

            checkpointer = InMemoryCheckpointer[StateT]()
            logger.debug("No checkpointer provided, using InMemoryCheckpointer")

        # Import the CompiledGraph class
        from .compiled_graph import CompiledGraph

        # Setup dependencies
        self._container.bind_instance(
            BaseCheckpointer,
            checkpointer,
            allow_concrete=True,
        )  # not null as we set default
        self._container.bind_instance(
            BaseStore,
            store,
            allow_none=True,
            allow_concrete=True,
        )
        self._container.bind_instance(
            CallbackManager,
            callback_manager,
            allow_concrete=True,
        )  # not null as we set default
        self._container.bind("interrupt_before", interrupt_before)
        self._container.bind("interrupt_after", interrupt_after)
        self._container.bind_instance(StateGraph, self)

        app = CompiledGraph(
            state=self._state,
            interrupt_after=interrupt_after,
            interrupt_before=interrupt_before,
            state_graph=self,
            checkpointer=checkpointer,
            publisher=self._publisher,
            store=store,
            task_manager=self._task_manager,
        )

        self._container.bind(CompiledGraph, app)
        # Compile the Graph, so it will optimize the dependency graph
        self._container.compile()
        return app

    def _validate_graph(self):
        """Validate the graph structure."""
        # Check for orphaned nodes
        connected_nodes = set()
        for edge in self.edges:
            connected_nodes.add(edge.from_node)
            connected_nodes.add(edge.to_node)

        all_nodes = set(self.nodes.keys())
        orphaned = all_nodes - connected_nodes
        if orphaned - {START, END}:  # START and END can be orphaned
            logger.error("Orphaned nodes detected: %s", orphaned - {START, END})
            raise GraphError(f"Orphaned nodes detected: {orphaned - {START, END}}")

        # Check that all edge targets exist
        for edge in self.edges:
            if edge.to_node and edge.to_node not in self.nodes:
                logger.error("Edge '%s' targets non-existent node: %s", edge, edge.to_node)
                raise GraphError(f"Edge targets non-existent node: {edge.to_node}")
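The orphan check in `_validate_graph` reduces to simple set arithmetic over edge endpoints. A standalone sketch (node names, edge tuples, and the START/END sentinel values here are illustrative, not the PyAgenity API):

```python
START, END = "__start__", "__end__"  # hypothetical sentinel values

def find_orphans(nodes: set[str], edges: list[tuple[str, str]]) -> set[str]:
    """Return nodes that appear in no edge; START and END may be unconnected."""
    connected = {name for edge in edges for name in edge}
    return (nodes - connected) - {START, END}

nodes = {START, END, "a", "b", "orphan"}
edges = [(START, "a"), ("a", "b"), ("b", END)]
print(find_orphans(nodes, edges))  # {'orphan'}
```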
Attributes
edges instance-attribute
edges = []
entry_point instance-attribute
entry_point = None
nodes instance-attribute
nodes = {}
thread_name_generator instance-attribute
thread_name_generator = thread_name_generator
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None, thread_name_generator=None)

Initialize a new StateGraph instance.

Parameters:

Name Type Description Default
state StateT | None

Initial state for the graph. If None, a default AgentState will be created.

None
context_manager BaseContextManager[StateT] | None

Optional context manager for handling cross-node state operations and advanced state management patterns.

None
container InjectQ | None

Dependency injection container used to resolve values injected into node functions. If None, the global InjectQ singleton instance is used.

None
publisher BasePublisher | None

Publisher for emitting events during execution

None
Note

START and END nodes are automatically added to the graph upon initialization and accept the full node signature including dependencies.

Example
Basic usage with default AgentState

graph = StateGraph()

With custom state

custom_state = MyCustomState()
graph = StateGraph(custom_state)

Or using type hints for clarity

graph = StateGraph[MyCustomState](MyCustomState())

Source code in pyagenity/graph/state_graph.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
    thread_name_generator: Callable[[], str] | None = None,
):
    """Initialize a new StateGraph instance.

    Args:
        state: Initial state for the graph. If None, a default AgentState
            will be created.
        context_manager: Optional context manager for handling cross-node
            state operations and advanced state management patterns.
        publisher: Publisher for emitting events during execution.
        id_generator: ID generator used by the graph. Defaults to
            DefaultIDGenerator().
        container: Dependency injection container used to resolve values
            injected into node functions. If None, the global InjectQ
            singleton instance is used.
        thread_name_generator: Optional callable that returns a name for
            generated threads.

    Note:
        START and END nodes are automatically added to the graph upon
        initialization and accept the full node signature including
        dependencies.

    Example:
        # Basic usage with default AgentState
        >>> graph = StateGraph()

        # With custom state
        >>> custom_state = MyCustomState()
        >>> graph = StateGraph(custom_state)

        # Or using type hints for clarity
        >>> graph = StateGraph[MyCustomState](MyCustomState())
    """
    logger.info("Initializing StateGraph")
    logger.debug(
        "StateGraph init with state=%s, context_manager=%s",
        type(state).__name__ if state else "default AgentState",
        type(context_manager).__name__ if context_manager else None,
    )

    # State handling
    self._state: StateT = state if state else AgentState()  # type: ignore[assignment]

    # Graph structure
    self.nodes: dict[str, Node] = {}
    self.edges: list[Edge] = []
    self.entry_point: str | None = None

    # Services
    self._publisher: BasePublisher | None = publisher
    self._id_generator: BaseIDGenerator = id_generator
    self._context_manager: BaseContextManager[StateT] | None = context_manager
    self.thread_name_generator = thread_name_generator
    # save container for dependency injection
    # if any container is passed then we will activate that
    # otherwise we can skip it and use the default one
    if container is None:
        self._container = InjectQ.get_instance()
        logger.debug("No container provided, using global singleton instance")
    else:
        logger.debug("Using provided dependency container instance")
        self._container = container
        self._container.activate()

    # Register task_manager, for async tasks
    # This will be used to run background tasks
    self._task_manager = BackgroundTaskManager()

    # now setup the graph
    self._setup()

    # Add START and END nodes (accept full node signature including dependencies)
    logger.debug("Adding default START and END nodes")
    self.nodes[START] = Node(START, lambda state, config, **deps: state, self._publisher)  # type: ignore
    self.nodes[END] = Node(END, lambda state, config, **deps: state, self._publisher)
    logger.debug("StateGraph initialized with %d nodes", len(self.nodes))
add_conditional_edges
add_conditional_edges(from_node, condition, path_map=None)

Add conditional routing between nodes based on runtime evaluation.

Creates dynamic routing logic where the next node is determined by evaluating a condition function against the current state. This enables complex branching logic, decision trees, and adaptive workflow routing.

Parameters:

Name Type Description Default
from_node str

Name of the source node where the condition is evaluated.

required
condition Callable

Callable function that takes the current AgentState and returns a value used for routing decisions. Should be deterministic and side-effect free.

required
path_map dict[str, str] | None

Optional dictionary mapping condition results to destination nodes. If provided, the condition's return value is looked up in this mapping. If None, the condition should return the destination node name directly.

None

Returns:

Name Type Description
StateGraph StateGraph

The graph instance for method chaining.

Raises:

Type Description
ValueError

If the condition function or path_map configuration is invalid.

Example
# Direct routing - condition returns node name
def route_by_priority(state):
    priority = state.data.get("priority", "normal")
    return "urgent_handler" if priority == "high" else "normal_handler"


graph.add_conditional_edges("classifier", route_by_priority)


# Mapped routing - condition result mapped to nodes
def get_category(state):
    return state.data.get("category", "default")


category_map = {
    "finance": "finance_processor",
    "legal": "legal_processor",
    "default": "general_processor",
}
graph.add_conditional_edges("categorizer", get_category, category_map)
Note

The condition function receives the current AgentState and should return consistent results for the same state. If using path_map, ensure the condition's return values match the map keys exactly.

Source code in pyagenity/graph/state_graph.py
def add_conditional_edges(
    self,
    from_node: str,
    condition: Callable,
    path_map: dict[str, str] | None = None,
) -> "StateGraph":
    """Add conditional routing between nodes based on runtime evaluation.

    Creates dynamic routing logic where the next node is determined by evaluating
    a condition function against the current state. This enables complex branching
    logic, decision trees, and adaptive workflow routing.

    Args:
        from_node: Name of the source node where the condition is evaluated.
        condition: Callable function that takes the current AgentState and returns
            a value used for routing decisions. Should be deterministic and
            side-effect free.
        path_map: Optional dictionary mapping condition results to destination nodes.
            If provided, the condition's return value is looked up in this mapping.
            If None, the condition should return the destination node name directly.

    Returns:
        StateGraph: The graph instance for method chaining.

    Raises:
        ValueError: If the condition function or path_map configuration is invalid.

    Example:
        ```python
        # Direct routing - condition returns node name
        def route_by_priority(state):
            priority = state.data.get("priority", "normal")
            return "urgent_handler" if priority == "high" else "normal_handler"


        graph.add_conditional_edges("classifier", route_by_priority)


        # Mapped routing - condition result mapped to nodes
        def get_category(state):
            return state.data.get("category", "default")


        category_map = {
            "finance": "finance_processor",
            "legal": "legal_processor",
            "default": "general_processor",
        }
        graph.add_conditional_edges("categorizer", get_category, category_map)
        ```

    Note:
        The condition function receives the current AgentState and should return
        consistent results for the same state. If using path_map, ensure the
        condition's return values match the map keys exactly.
    """
    """Add conditional edges from a node based on a condition function.

    Creates edges that are traversed based on the result of a condition
    function. The condition function receives the current state and should
    return a value that determines which edge to follow.

    Args:
        from_node: Name of the source node.
        condition: Function that evaluates the current state and returns
            a value to determine the next node.
        path_map: Optional mapping from condition results to target nodes.
            If provided, creates multiple conditional edges. If None,
            creates a single conditional edge.

    Returns:
        StateGraph: The graph instance for method chaining.

    Example:
        >>> def route_condition(state):
        ...     return "success" if state.success else "failure"
        >>> graph.add_conditional_edges(
        ...     "processor",
        ...     route_condition,
        ...     {"success": "next_step", "failure": "error_handler"},
        ... )
    """
    # Create edges based on possible returns from condition function
    logger.debug(
        "Node '%s' adding conditional edges with path_map: %s",
        from_node,
        path_map,
    )
    if path_map:
        for condition_result, target_node in path_map.items():
            edge = Edge(from_node, target_node, condition)
            edge.condition_result = condition_result
            self.edges.append(edge)
    else:
        # Single conditional edge
        logger.debug("Node '%s' adding single conditional edge", from_node)
        self.edges.append(Edge(from_node, "", condition))
    return self
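The two routing modes above (direct and mapped) come down to how the condition's return value is interpreted. A standalone sketch of that resolution step (plain dict state and node names are illustrative, not the PyAgenity API):

```python
def route_by_priority(state):
    # Condition function: deterministic, side-effect free.
    return "urgent" if state.get("priority") == "high" else "normal"

def resolve_target(state, condition, path_map=None):
    result = condition(state)
    # With a path_map, the condition result is looked up as a key;
    # without one, the result must be the destination node name itself.
    return path_map[result] if path_map is not None else result

path_map = {"urgent": "urgent_handler", "normal": "normal_handler"}
print(resolve_target({"priority": "high"}, route_by_priority, path_map))  # urgent_handler
print(resolve_target({"priority": "low"}, route_by_priority))  # normal
```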
add_edge
add_edge(from_node, to_node)

Add a static edge between two nodes.

Creates a direct connection from one node to another. If the source node is START, the target node becomes the entry point for the graph.

Parameters:

Name Type Description Default
from_node str

Name of the source node.

required
to_node str

Name of the target node.

required

Returns:

Name Type Description
StateGraph StateGraph

The graph instance for method chaining.

Example

graph.add_edge("node1", "node2") graph.add_edge(START, "entry_node") # Sets entry point

Source code in pyagenity/graph/state_graph.py
def add_edge(
    self,
    from_node: str,
    to_node: str,
) -> "StateGraph":
    """Add a static edge between two nodes.

    Creates a direct connection from one node to another. If the source
    node is START, the target node becomes the entry point for the graph.

    Args:
        from_node: Name of the source node.
        to_node: Name of the target node.

    Returns:
        StateGraph: The graph instance for method chaining.

    Example:
        >>> graph.add_edge("node1", "node2")
        >>> graph.add_edge(START, "entry_node")  # Sets entry point
    """
    logger.debug("Adding edge from '%s' to '%s'", from_node, to_node)
    # Set entry point if edge is from START
    if from_node == START:
        self.entry_point = to_node
        logger.info("Set entry point to '%s'", to_node)
    self.edges.append(Edge(from_node, to_node))
    logger.debug("Added edge (total edges: %d)", len(self.edges))
    return self
add_node
add_node(name_or_func, func=None)

Add a node to the graph.

This method supports two calling patterns:

1. Pass a callable as the first argument (name inferred from function name)
2. Pass a name string and callable/ToolNode as separate arguments

Parameters:

Name Type Description Default
name_or_func str | Callable

Either the node name (str) or a callable function. If callable, the function name will be used as the node name.

required
func Union[Callable, ToolNode, None]

The function or ToolNode to execute. Required if name_or_func is a string, ignored if name_or_func is callable.

None

Returns:

Name Type Description
StateGraph StateGraph

The graph instance for method chaining.

Raises:

Type Description
ValueError

If invalid arguments are provided.

Example
Method 1: Function name inferred

graph.add_node(my_function)

Method 2: Explicit name and function

graph.add_node("process", my_function)

Source code in pyagenity/graph/state_graph.py
def add_node(
    self,
    name_or_func: str | Callable,
    func: Union[Callable, "ToolNode", None] = None,
) -> "StateGraph":
    """Add a node to the graph.

    This method supports two calling patterns:
    1. Pass a callable as the first argument (name inferred from function name)
    2. Pass a name string and callable/ToolNode as separate arguments

    Args:
        name_or_func: Either the node name (str) or a callable function.
            If callable, the function name will be used as the node name.
        func: The function or ToolNode to execute. Required if name_or_func
            is a string, ignored if name_or_func is callable.

    Returns:
        StateGraph: The graph instance for method chaining.

    Raises:
        ValueError: If invalid arguments are provided.

    Example:
        >>> # Method 1: Function name inferred
        >>> graph.add_node(my_function)
        >>> # Method 2: Explicit name and function
        >>> graph.add_node("process", my_function)
    """
    if callable(name_or_func) and func is None:
        # Function passed as first argument
        name = name_or_func.__name__
        func = name_or_func
        logger.debug("Adding node '%s' with inferred name from function", name)
    elif isinstance(name_or_func, str) and (callable(func) or isinstance(func, ToolNode)):
        # Name and function passed separately
        name = name_or_func
        logger.debug(
            "Adding node '%s' with explicit name and %s",
            name,
            "ToolNode" if isinstance(func, ToolNode) else "callable",
        )
    else:
        error_msg = "Invalid arguments for add_node"
        logger.error(error_msg)
        raise ValueError(error_msg)

    self.nodes[name] = Node(name, func)
    logger.info("Added node '%s' to graph (total nodes: %d)", name, len(self.nodes))
    return self
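The two calling patterns can be illustrated with a plain dict standing in for the node registry (the `add_node` helper and `registry` here are a sketch, not the PyAgenity implementation):

```python
def add_node(registry, name_or_func, func=None):
    """Register a function under an inferred or explicit name."""
    if callable(name_or_func) and func is None:
        # Pattern 1: name inferred from the function's __name__.
        name, func = name_or_func.__name__, name_or_func
    elif isinstance(name_or_func, str) and callable(func):
        # Pattern 2: explicit name and callable.
        name = name_or_func
    else:
        raise ValueError("Invalid arguments for add_node")
    registry[name] = func
    return registry

def my_function(state, config):
    return state

registry = {}
add_node(registry, my_function)             # name inferred: "my_function"
add_node(registry, "process", my_function)  # explicit name
print(sorted(registry))  # ['my_function', 'process']
```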
compile
compile(checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())

Compile the graph for execution.

Parameters:

Name Type Description Default
checkpointer BaseCheckpointer[StateT] | None

Checkpointer for state persistence

None
store BaseStore | None

Store for additional data

None
interrupt_before list[str] | None

List of node names to interrupt before execution

None
interrupt_after list[str] | None

List of node names to interrupt after execution

None
callback_manager CallbackManager

Callback manager for executing hooks

CallbackManager()
Source code in pyagenity/graph/state_graph.py
def compile(
    self,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> "CompiledGraph[StateT]":
    """Compile the graph for execution.

    Args:
        checkpointer: Checkpointer for state persistence
        store: Store for additional data
        interrupt_before: List of node names to interrupt before execution
        interrupt_after: List of node names to interrupt after execution
        callback_manager: Callback manager for executing hooks

    Returns:
        CompiledGraph: An executable graph instance.

    Raises:
        GraphError: If no entry point is set, the graph structure is
            invalid, or any interrupt node name does not exist.
    """
    logger.info(
        "Compiling graph with %d nodes, %d edges, entry_point='%s'",
        len(self.nodes),
        len(self.edges),
        self.entry_point,
    )
    logger.debug(
        "Compile options: interrupt_before=%s, interrupt_after=%s",
        interrupt_before,
        interrupt_after,
    )

    if not self.entry_point:
        error_msg = "No entry point set. Use set_entry_point() or add an edge from START."
        logger.error(error_msg)
        raise GraphError(error_msg)

    # Validate graph structure
    logger.debug("Validating graph structure")
    self._validate_graph()
    logger.debug("Graph structure validated successfully")

    # Validate interrupt node names
    interrupt_before = interrupt_before or []
    interrupt_after = interrupt_after or []

    all_interrupt_nodes = set(interrupt_before + interrupt_after)
    invalid_nodes = all_interrupt_nodes - set(self.nodes.keys())
    if invalid_nodes:
        error_msg = f"Invalid interrupt nodes: {invalid_nodes}. Must be existing node names."
        logger.error(error_msg)
        raise GraphError(error_msg)

    self.compiled = True
    logger.info("Graph compilation completed successfully")
    # Import here to avoid circular import at module import time
    # Now update Checkpointer
    if checkpointer is None:
        from pyagenity.checkpointer import InMemoryCheckpointer

        checkpointer = InMemoryCheckpointer[StateT]()
        logger.debug("No checkpointer provided, using InMemoryCheckpointer")

    # Import the CompiledGraph class
    from .compiled_graph import CompiledGraph

    # Setup dependencies
    self._container.bind_instance(
        BaseCheckpointer,
        checkpointer,
        allow_concrete=True,
    )  # not null as we set default
    self._container.bind_instance(
        BaseStore,
        store,
        allow_none=True,
        allow_concrete=True,
    )
    self._container.bind_instance(
        CallbackManager,
        callback_manager,
        allow_concrete=True,
    )  # not null as we set default
    self._container.bind("interrupt_before", interrupt_before)
    self._container.bind("interrupt_after", interrupt_after)
    self._container.bind_instance(StateGraph, self)

    app = CompiledGraph(
        state=self._state,
        interrupt_after=interrupt_after,
        interrupt_before=interrupt_before,
        state_graph=self,
        checkpointer=checkpointer,
        publisher=self._publisher,
        store=store,
        task_manager=self._task_manager,
    )

    self._container.bind(CompiledGraph, app)
    # Compile the Graph, so it will optimize the dependency graph
    self._container.compile()
    return app
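The interrupt-name validation performed during `compile()` is a single set difference. A standalone sketch (the helper and node names are illustrative, not the PyAgenity API):

```python
def validate_interrupts(nodes, interrupt_before=None, interrupt_after=None):
    """Raise if any interrupt name is not a registered node."""
    invalid = set((interrupt_before or []) + (interrupt_after or [])) - set(nodes)
    if invalid:
        raise ValueError(f"Invalid interrupt nodes: {invalid}. Must be existing node names.")

nodes = {"classifier", "handler"}
validate_interrupts(nodes, interrupt_before=["classifier"])  # passes silently
try:
    validate_interrupts(nodes, interrupt_after=["missing"])
except ValueError as exc:
    print(exc)  # Invalid interrupt nodes: {'missing'}. Must be existing node names.
```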
set_entry_point
set_entry_point(node_name)

Set the entry point for the graph.

Source code in pyagenity/graph/state_graph.py
def set_entry_point(self, node_name: str) -> "StateGraph":
    """Set the entry point for the graph."""
    self.entry_point = node_name
    self.add_edge(START, node_name)
    logger.info("Set entry point to '%s'", node_name)
    return self
ToolNode

Bases: SchemaMixin, LocalExecMixin, MCPMixin, ComposioMixin, LangChainMixin, KwargsResolverMixin

A unified registry and executor for callable functions from various tool providers.

ToolNode serves as the central hub for managing and executing tools from multiple sources:

- Local Python functions
- MCP (Model Context Protocol) tools
- Composio adapter tools
- LangChain tools

The class uses a mixin-based architecture to separate concerns and maintain clean integration with different tool providers. It provides both synchronous and asynchronous execution methods with comprehensive event publishing and error handling.

Attributes:

Name Type Description
_funcs dict[str, Callable]

Dictionary mapping function names to callable functions.

_client Client | None

Optional MCP client for remote tool execution.

_composio ComposioAdapter | None

Optional Composio adapter for external integrations.

_langchain Any | None

Optional LangChain adapter for LangChain tools.

mcp_tools list[str]

List of available MCP tool names.

composio_tools list[str]

List of available Composio tool names.

langchain_tools list[str]

List of available LangChain tool names.

Example
# Define local tools
def weather_tool(location: str) -> str:
    return f"Weather in {location}: Sunny, 25°C"


def calculator(a: int, b: int) -> int:
    return a + b


# Create ToolNode with local functions
tools = ToolNode([weather_tool, calculator])

# Execute a tool
result = await tools.invoke(
    name="weather_tool",
    args={"location": "New York"},
    tool_call_id="call_123",
    config={"user_id": "user1"},
    state=agent_state,
)

Methods:

Name Description
__init__

Initialize ToolNode with functions and optional tool adapters.

all_tools

Get all available tools from all configured providers.

all_tools_sync

Synchronously get all available tools from all configured providers.

get_local_tool

Generate OpenAI-compatible tool definitions for all registered local functions.

invoke

Execute a specific tool by name with the provided arguments.

stream

Execute a tool with streaming support, yielding incremental results.

Source code in pyagenity/graph/tool_node/base.py
class ToolNode(
    SchemaMixin,
    LocalExecMixin,
    MCPMixin,
    ComposioMixin,
    LangChainMixin,
    KwargsResolverMixin,
):
    """A unified registry and executor for callable functions from various tool providers.

    ToolNode serves as the central hub for managing and executing tools from multiple sources:
    - Local Python functions
    - MCP (Model Context Protocol) tools
    - Composio adapter tools
    - LangChain tools

    The class uses a mixin-based architecture to separate concerns and maintain clean
    integration with different tool providers. It provides both synchronous and asynchronous
    execution methods with comprehensive event publishing and error handling.

    Attributes:
        _funcs: Dictionary mapping function names to callable functions.
        _client: Optional MCP client for remote tool execution.
        _composio: Optional Composio adapter for external integrations.
        _langchain: Optional LangChain adapter for LangChain tools.
        mcp_tools: List of available MCP tool names.
        composio_tools: List of available Composio tool names.
        langchain_tools: List of available LangChain tool names.

    Example:
        ```python
        # Define local tools
        def weather_tool(location: str) -> str:
            return f"Weather in {location}: Sunny, 25°C"


        def calculator(a: int, b: int) -> int:
            return a + b


        # Create ToolNode with local functions
        tools = ToolNode([weather_tool, calculator])

        # Execute a tool
        result = await tools.invoke(
            name="weather_tool",
            args={"location": "New York"},
            tool_call_id="call_123",
            config={"user_id": "user1"},
            state=agent_state,
        )
        ```
    """

    def __init__(
        self,
        functions: t.Iterable[t.Callable],
        client: deps.Client | None = None,  # type: ignore
        composio_adapter: ComposioAdapter | None = None,
        langchain_adapter: t.Any | None = None,
    ) -> None:
        """Initialize ToolNode with functions and optional tool adapters.

        Args:
            functions: Iterable of callable functions to register as tools. Each function
                will be registered with its `__name__` as the tool identifier.
            client: Optional MCP (Model Context Protocol) client for remote tool access.
                Requires 'fastmcp' and 'mcp' packages to be installed.
            composio_adapter: Optional Composio adapter for external integrations and
                third-party API access.
            langchain_adapter: Optional LangChain adapter for accessing LangChain tools
                and integrations.

        Raises:
            ImportError: If MCP client is provided but required packages are not installed.
            TypeError: If any item in functions is not callable.

        Note:
            When using MCP client functionality, ensure you have installed the required
            dependencies with: `pip install pyagenity[mcp]`
        """
        logger.info("Initializing ToolNode with %d functions", len(list(functions)))

        if client is not None:
            # Read flags dynamically so tests can patch pyagenity.graph.tool_node.HAS_*
            mod = sys.modules.get("pyagenity.graph.tool_node")
            has_fastmcp = getattr(mod, "HAS_FASTMCP", deps.HAS_FASTMCP) if mod else deps.HAS_FASTMCP
            has_mcp = getattr(mod, "HAS_MCP", deps.HAS_MCP) if mod else deps.HAS_MCP

            if not has_fastmcp or not has_mcp:
                raise ImportError(
                    "MCP client functionality requires 'fastmcp' and 'mcp' packages. "
                    "Install with: pip install pyagenity[mcp]"
                )
            logger.debug("ToolNode initialized with MCP client")

        self._funcs: dict[str, t.Callable] = {}
        self._client: deps.Client | None = client  # type: ignore
        self._composio: ComposioAdapter | None = composio_adapter
        self._langchain: t.Any | None = langchain_adapter

        for fn in functions:
            if not callable(fn):
                raise TypeError("ToolNode only accepts callables")
            self._funcs[fn.__name__] = fn

        self.mcp_tools: list[str] = []
        self.composio_tools: list[str] = []
        self.langchain_tools: list[str] = []

    async def _all_tools_async(self) -> list[dict]:
        tools: list[dict] = self.get_local_tool()
        tools.extend(await self._get_mcp_tool())
        tools.extend(await self._get_composio_tools())
        tools.extend(await self._get_langchain_tools())
        return tools

    async def all_tools(self) -> list[dict]:
        """Get all available tools from all configured providers.

        Retrieves and combines tool definitions from local functions, MCP client,
        Composio adapter, and LangChain adapter. Each tool definition includes
        the function schema with parameters and descriptions.

        Returns:
            List of tool definitions in OpenAI function calling format. Each dict
            contains 'type': 'function' and 'function' with name, description,
            and parameters schema.

        Example:
            ```python
            tools = await tool_node.all_tools()
            # Returns:
            # [
            #   {
            #     "type": "function",
            #     "function": {
            #       "name": "weather_tool",
            #       "description": "Get weather information for a location",
            #       "parameters": {
            #         "type": "object",
            #         "properties": {
            #           "location": {"type": "string"}
            #         },
            #         "required": ["location"]
            #       }
            #     }
            #   }
            # ]
            ```
        """
        return await self._all_tools_async()

    def all_tools_sync(self) -> list[dict]:
        """Synchronously get all available tools from all configured providers.

        This is a synchronous wrapper around the async all_tools() method.
        It uses asyncio.run() to handle async operations from MCP, Composio,
        and LangChain adapters.

        Returns:
            List of tool definitions in OpenAI function calling format.

        Note:
            Prefer using the async `all_tools()` method when possible, especially
            in async contexts, to avoid potential event loop issues.
        """
        tools: list[dict] = self.get_local_tool()
        if self._client:
            result = asyncio.run(self._get_mcp_tool())
            if result:
                tools.extend(result)
        comp = asyncio.run(self._get_composio_tools())
        if comp:
            tools.extend(comp)
        lc = asyncio.run(self._get_langchain_tools())
        if lc:
            tools.extend(lc)
        return tools

    async def invoke(  # noqa: PLR0915
        self,
        name: str,
        args: dict,
        tool_call_id: str,
        config: dict[str, t.Any],
        state: AgentState,
        callback_manager: CallbackManager = Inject[CallbackManager],
    ) -> t.Any:
        """Execute a specific tool by name with the provided arguments.

        This method handles tool execution across all configured providers (local,
        MCP, Composio, LangChain) with comprehensive error handling, event publishing,
        and callback management.

        Args:
            name: The name of the tool to execute.
            args: Dictionary of arguments to pass to the tool function.
            tool_call_id: Unique identifier for this tool execution, used for
                tracking and result correlation.
            config: Configuration dictionary containing execution context and
                user-specific settings.
            state: Current agent state for context-aware tool execution.
            callback_manager: Manager for executing pre/post execution callbacks.
                Injected via dependency injection if not provided.

        Returns:
            Message object containing tool execution results, either successful
            output or error information with appropriate status indicators.

        Raises:
            The method handles all exceptions internally and returns error Messages
            rather than raising exceptions, ensuring robust execution flow.

        Example:
            ```python
            result = await tool_node.invoke(
                name="weather_tool",
                args={"location": "Paris", "units": "metric"},
                tool_call_id="call_abc123",
                config={"user_id": "user1", "session_id": "session1"},
                state=current_agent_state,
            )

            # result is a Message with tool execution results
            print(result.content)  # Tool output or error information
            ```

        Note:
            The method publishes execution events throughout the process for
            monitoring and debugging purposes. Tool execution is routed based
            on tool provider precedence: MCP → Composio → LangChain → Local.
        """
        logger.info("Executing tool '%s' with %d arguments", name, len(args))
        logger.debug("Tool arguments: %s", args)

        event = EventModel.default(
            config,
            data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
            content_type=[ContentType.TOOL_CALL],
            event=Event.TOOL_EXECUTION,
        )
        event.node_name = name
        # Attach structured tool call block
        with contextlib.suppress(Exception):
            event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]
        publish_event(event)

        if name in self.mcp_tools:
            event.metadata["is_mcp"] = True
            publish_event(event)
            res = await self._mcp_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            # Attach tool result block mirroring the tool output
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        if name in self.composio_tools:
            event.metadata["is_composio"] = True
            publish_event(event)
            res = await self._composio_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        if name in self.langchain_tools:
            event.metadata["is_langchain"] = True
            publish_event(event)
            res = await self._langchain_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        if name in self._funcs:
            event.metadata["is_mcp"] = False
            publish_event(event)
            res = await self._internal_execute(
                name,
                args,
                tool_call_id,
                config,
                state,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        error_msg = f"Tool '{name}' not found."
        event.data["error"] = error_msg
        event.event_type = EventType.ERROR
        event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
        publish_event(event)
        return Message.tool_message(
            content=[
                ErrorBlock(message=error_msg),
                ToolResultBlock(
                    call_id=tool_call_id,
                    output=error_msg,
                    is_error=True,
                    status="failed",
                ),
            ],
        )

    async def stream(  # noqa: PLR0915
        self,
        name: str,
        args: dict,
        tool_call_id: str,
        config: dict[str, t.Any],
        state: AgentState,
        callback_manager: CallbackManager = Inject[CallbackManager],
    ) -> t.AsyncIterator[Message]:
        """Execute a tool with streaming support, yielding incremental results.

        Similar to invoke() but designed for tools that can provide streaming responses
        or when you want to process results as they become available. Currently,
        most tool providers return complete results, so this method typically yields
        a single Message with the full result.

        Args:
            name: The name of the tool to execute.
            args: Dictionary of arguments to pass to the tool function.
            tool_call_id: Unique identifier for this tool execution.
            config: Configuration dictionary containing execution context.
            state: Current agent state for context-aware tool execution.
            callback_manager: Manager for executing pre/post execution callbacks.

        Yields:
            Message objects containing tool execution results or status updates.
            For most tools, this will yield a single complete result Message.

        Example:
            ```python
            async for message in tool_node.stream(
                name="data_processor",
                args={"dataset": "large_data.csv"},
                tool_call_id="call_stream123",
                config={"user_id": "user1"},
                state=current_state,
            ):
                print(f"Received: {message.content}")
                # Process each streamed result
            ```

        Note:
            The streaming interface is designed for future expansion where tools
            may provide true streaming responses. Currently, it provides a
            consistent async iterator interface over tool results.
        """
        logger.info("Executing tool '%s' with %d arguments", name, len(args))
        logger.debug("Tool arguments: %s", args)
        event = EventModel.default(
            config,
            data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
            content_type=[ContentType.TOOL_CALL],
            event=Event.TOOL_EXECUTION,
        )
        event.node_name = "ToolNode"
        with contextlib.suppress(Exception):
            event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]

        if name in self.mcp_tools:
            event.metadata["function_type"] = "mcp"
            publish_event(event)
            message = await self._mcp_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = message.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield message
            return

        if name in self.composio_tools:
            event.metadata["function_type"] = "composio"
            publish_event(event)
            message = await self._composio_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = message.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield message
            return

        if name in self.langchain_tools:
            event.metadata["function_type"] = "langchain"
            publish_event(event)
            message = await self._langchain_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = message.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield message
            return

        if name in self._funcs:
            event.metadata["function_type"] = "internal"
            publish_event(event)

            result = await self._internal_execute(
                name,
                args,
                tool_call_id,
                config,
                state,
                callback_manager,
            )
            event.data["message"] = result.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=result.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield result
            return

        error_msg = f"Tool '{name}' not found."
        event.data["error"] = error_msg
        event.event_type = EventType.ERROR
        event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
        publish_event(event)

        yield Message.tool_message(
            content=[
                ErrorBlock(message=error_msg),
                ToolResultBlock(
                    call_id=tool_call_id,
                    output=error_msg,
                    is_error=True,
                    status="failed",
                ),
            ],
        )
Attributes
composio_tools instance-attribute
composio_tools = []
langchain_tools instance-attribute
langchain_tools = []
mcp_tools instance-attribute
mcp_tools = []
Functions
__init__
__init__(functions, client=None, composio_adapter=None, langchain_adapter=None)

Initialize ToolNode with functions and optional tool adapters.

Parameters:

Name Type Description Default
functions Iterable[Callable]

Iterable of callable functions to register as tools. Each function will be registered with its __name__ as the tool identifier.

required
client Client | None

Optional MCP (Model Context Protocol) client for remote tool access. Requires 'fastmcp' and 'mcp' packages to be installed.

None
composio_adapter ComposioAdapter | None

Optional Composio adapter for external integrations and third-party API access.

None
langchain_adapter Any | None

Optional LangChain adapter for accessing LangChain tools and integrations.

None

Raises:

Type Description
ImportError

If MCP client is provided but required packages are not installed.

TypeError

If any item in functions is not callable.

Note

When using MCP client functionality, ensure you have installed the required dependencies with: pip install pyagenity[mcp]
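The registration step itself is small. A stdlib-only sketch of what the constructor does with `functions` (the callable check and keying each tool by `__name__`), omitting the adapter and MCP wiring; `weather_tool` is an illustrative stand-in, not part of the library:

```python
import typing as t


def register_tools(functions: t.Iterable[t.Callable]) -> dict[str, t.Callable]:
    # Mirrors ToolNode.__init__: reject non-callables, key each tool by __name__.
    funcs: dict[str, t.Callable] = {}
    for fn in functions:
        if not callable(fn):
            raise TypeError("ToolNode only accepts callables")
        funcs[fn.__name__] = fn
    return funcs


def weather_tool(location: str) -> str:
    """Get weather information for a location."""
    return f"sunny in {location}"


registry = register_tools([weather_tool])
```

Because tools are keyed by `__name__`, two functions with the same name will silently overwrite each other in the registry, so local tool names should be unique.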

Source code in pyagenity/graph/tool_node/base.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def __init__(
    self,
    functions: t.Iterable[t.Callable],
    client: deps.Client | None = None,  # type: ignore
    composio_adapter: ComposioAdapter | None = None,
    langchain_adapter: t.Any | None = None,
) -> None:
    """Initialize ToolNode with functions and optional tool adapters.

    Args:
        functions: Iterable of callable functions to register as tools. Each function
            will be registered with its `__name__` as the tool identifier.
        client: Optional MCP (Model Context Protocol) client for remote tool access.
            Requires 'fastmcp' and 'mcp' packages to be installed.
        composio_adapter: Optional Composio adapter for external integrations and
            third-party API access.
        langchain_adapter: Optional LangChain adapter for accessing LangChain tools
            and integrations.

    Raises:
        ImportError: If MCP client is provided but required packages are not installed.
        TypeError: If any item in functions is not callable.

    Note:
        When using MCP client functionality, ensure you have installed the required
        dependencies with: `pip install pyagenity[mcp]`
    """
    logger.info("Initializing ToolNode with %d functions", len(list(functions)))

    if client is not None:
        # Read flags dynamically so tests can patch pyagenity.graph.tool_node.HAS_*
        mod = sys.modules.get("pyagenity.graph.tool_node")
        has_fastmcp = getattr(mod, "HAS_FASTMCP", deps.HAS_FASTMCP) if mod else deps.HAS_FASTMCP
        has_mcp = getattr(mod, "HAS_MCP", deps.HAS_MCP) if mod else deps.HAS_MCP

        if not has_fastmcp or not has_mcp:
            raise ImportError(
                "MCP client functionality requires 'fastmcp' and 'mcp' packages. "
                "Install with: pip install pyagenity[mcp]"
            )
        logger.debug("ToolNode initialized with MCP client")

    self._funcs: dict[str, t.Callable] = {}
    self._client: deps.Client | None = client  # type: ignore
    self._composio: ComposioAdapter | None = composio_adapter
    self._langchain: t.Any | None = langchain_adapter

    for fn in functions:
        if not callable(fn):
            raise TypeError("ToolNode only accepts callables")
        self._funcs[fn.__name__] = fn

    self.mcp_tools: list[str] = []
    self.composio_tools: list[str] = []
    self.langchain_tools: list[str] = []
all_tools async
all_tools()

Get all available tools from all configured providers.

Retrieves and combines tool definitions from local functions, MCP client, Composio adapter, and LangChain adapter. Each tool definition includes the function schema with parameters and descriptions.

Returns:

Type Description
list[dict]

List of tool definitions in OpenAI function calling format. Each dict contains 'type': 'function' and 'function' with name, description, and parameters schema.

Example
tools = await tool_node.all_tools()
# Returns:
# [
#   {
#     "type": "function",
#     "function": {
#       "name": "weather_tool",
#       "description": "Get weather information for a location",
#       "parameters": {
#         "type": "object",
#         "properties": {
#           "location": {"type": "string"}
#         },
#         "required": ["location"]
#       }
#     }
#   }
# ]
Source code in pyagenity/graph/tool_node/base.py
async def all_tools(self) -> list[dict]:
    """Get all available tools from all configured providers.

    Retrieves and combines tool definitions from local functions, MCP client,
    Composio adapter, and LangChain adapter. Each tool definition includes
    the function schema with parameters and descriptions.

    Returns:
        List of tool definitions in OpenAI function calling format. Each dict
        contains 'type': 'function' and 'function' with name, description,
        and parameters schema.

    Example:
        ```python
        tools = await tool_node.all_tools()
        # Returns:
        # [
        #   {
        #     "type": "function",
        #     "function": {
        #       "name": "weather_tool",
        #       "description": "Get weather information for a location",
        #       "parameters": {
        #         "type": "object",
        #         "properties": {
        #           "location": {"type": "string"}
        #         },
        #         "required": ["location"]
        #       }
        #     }
        #   }
        # ]
        ```
    """
    return await self._all_tools_async()
all_tools_sync
all_tools_sync()

Synchronously get all available tools from all configured providers.

This is a synchronous wrapper around the async all_tools() method. It uses asyncio.run() to handle async operations from MCP, Composio, and LangChain adapters.

Returns:

Type Description
list[dict]

List of tool definitions in OpenAI function calling format.

Note

Prefer using the async all_tools() method when possible, especially in async contexts, to avoid potential event loop issues.
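The event-loop caveat in the Note comes from how the wrapper is built: each async provider call is driven with `asyncio.run()`, which raises `RuntimeError` when invoked from inside an already-running loop. A minimal stdlib sketch of that pattern, with two hypothetical providers standing in for the MCP/Composio/LangChain lookups:

```python
import asyncio


async def _provider_a() -> list[dict]:
    # Stand-in for an async tool-discovery call (e.g. an MCP listing).
    return [{"type": "function", "function": {"name": "a"}}]


async def _provider_b() -> list[dict]:
    return [{"type": "function", "function": {"name": "b"}}]


def all_tools_sync_sketch() -> list[dict]:
    # Each provider is awaited via asyncio.run(), as all_tools_sync() does.
    # Calling this from within a running event loop raises RuntimeError,
    # which is why the async all_tools() is preferred in async contexts.
    tools: list[dict] = []
    tools.extend(asyncio.run(_provider_a()))
    tools.extend(asyncio.run(_provider_b()))
    return tools
```

From synchronous code (scripts, CLI entry points) this is safe; from a coroutine, await `all_tools()` instead.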

Source code in pyagenity/graph/tool_node/base.py
def all_tools_sync(self) -> list[dict]:
    """Synchronously get all available tools from all configured providers.

    This is a synchronous wrapper around the async all_tools() method.
    It uses asyncio.run() to handle async operations from MCP, Composio,
    and LangChain adapters.

    Returns:
        List of tool definitions in OpenAI function calling format.

    Note:
        Prefer using the async `all_tools()` method when possible, especially
        in async contexts, to avoid potential event loop issues.
    """
    tools: list[dict] = self.get_local_tool()
    if self._client:
        result = asyncio.run(self._get_mcp_tool())
        if result:
            tools.extend(result)
    comp = asyncio.run(self._get_composio_tools())
    if comp:
        tools.extend(comp)
    lc = asyncio.run(self._get_langchain_tools())
    if lc:
        tools.extend(lc)
    return tools
get_local_tool
get_local_tool()

Generate OpenAI-compatible tool definitions for all registered local functions.

Inspects all registered functions in _funcs and automatically generates tool schemas by analyzing function signatures, type annotations, and docstrings. Excludes injectable parameters that are provided by the framework.

Returns:

Type Description
list[dict]

List of tool definitions in OpenAI function calling format. Each definition includes the function name, description (from docstring), and complete parameter schema with types and required fields.

Example

For a function:

def calculate(a: int, b: int, operation: str = "add") -> int:
    '''Perform arithmetic calculation.'''
    return a + b if operation == "add" else a - b

Returns:

[
    {
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "Perform arithmetic calculation.",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "integer"},
                    "b": {"type": "integer"},
                    "operation": {"type": "string", "default": "add"},
                },
                "required": ["a", "b"],
            },
        },
    }
]

Note

Parameters listed in INJECTABLE_PARAMS (like 'state', 'config', 'tool_call_id') are automatically excluded from the generated schema as they are provided by the framework during execution.
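The core of this schema generation can be sketched in stdlib Python. This is a simplified approximation: the type mapping below is an assumption, while the real `SchemaMixin._annotation_to_schema` handles many more annotation kinds, and the real `INJECTABLE_PARAMS` set lives in the framework:

```python
import inspect

INJECTABLE_PARAMS = {"state", "config", "tool_call_id"}  # framework-provided
_TYPE_MAP = {int: "integer", float: "number", str: "string", bool: "boolean"}


def build_tool_schema(fn) -> dict:
    # Simplified mirror of get_local_tool() for a single function.
    sig = inspect.signature(fn)
    params = {"type": "object", "properties": {}, "required": []}
    for p_name, p in sig.parameters.items():
        if p_name in INJECTABLE_PARAMS:
            continue  # excluded: injected by the framework at execution time
        annotation = p.annotation if p.annotation is not inspect.Parameter.empty else str
        prop = {"type": _TYPE_MAP.get(annotation, "string")}
        if p.default is not inspect.Parameter.empty:
            prop["default"] = p.default
        else:
            params["required"].append(p_name)
        params["properties"][p_name] = prop
    if not params["required"]:
        params.pop("required")
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": inspect.getdoc(fn) or "No description provided.",
            "parameters": params,
        },
    }


def calculate(a: int, b: int, operation: str = "add", state=None) -> int:
    """Perform arithmetic calculation."""
    return a + b if operation == "add" else a - b


schema = build_tool_schema(calculate)
```

Note how `state` disappears from the generated schema even though it is in the signature: injectable parameters never reach the model-facing tool definition.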

Source code in pyagenity/graph/tool_node/schema.py
def get_local_tool(self) -> list[dict]:
    """Generate OpenAI-compatible tool definitions for all registered local functions.

    Inspects all registered functions in _funcs and automatically generates
    tool schemas by analyzing function signatures, type annotations, and docstrings.
    Excludes injectable parameters that are provided by the framework.

    Returns:
        List of tool definitions in OpenAI function calling format. Each
        definition includes the function name, description (from docstring),
        and complete parameter schema with types and required fields.

    Example:
        For a function:
        ```python
        def calculate(a: int, b: int, operation: str = "add") -> int:
            '''Perform arithmetic calculation.'''
            return a + b if operation == "add" else a - b
        ```

        Returns:
        ```python
        [
            {
                "type": "function",
                "function": {
                    "name": "calculate",
                    "description": "Perform arithmetic calculation.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {"type": "integer"},
                            "b": {"type": "integer"},
                            "operation": {"type": "string", "default": "add"},
                        },
                        "required": ["a", "b"],
                    },
                },
            }
        ]
        ```

    Note:
        Parameters listed in INJECTABLE_PARAMS (like 'state', 'config',
        'tool_call_id') are automatically excluded from the generated schema
        as they are provided by the framework during execution.
    """
    tools: list[dict] = []
    for name, fn in self._funcs.items():
        sig = inspect.signature(fn)
        params_schema: dict = {"type": "object", "properties": {}, "required": []}

        for p_name, p in sig.parameters.items():
            if p.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue

            if p_name in INJECTABLE_PARAMS:
                continue

            annotation = p.annotation if p.annotation is not inspect._empty else str
            prop = SchemaMixin._annotation_to_schema(annotation, p.default)
            params_schema["properties"][p_name] = prop

            if p.default is inspect._empty:
                params_schema["required"].append(p_name)

        if not params_schema["required"]:
            params_schema.pop("required")

        description = inspect.getdoc(fn) or "No description provided."

        # provider = getattr(fn, "_py_tool_provider", None)
        # tags = getattr(fn, "_py_tool_tags", None)
        # capabilities = getattr(fn, "_py_tool_capabilities", None)

        entry = {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": params_schema,
            },
        }
        # meta: dict[str, t.Any] = {}
        # if provider:
        #     meta["provider"] = provider
        # if tags:
        #     meta["tags"] = tags
        # if capabilities:
        #     meta["capabilities"] = capabilities
        # if meta:
        #     entry["x-pyagenity"] = meta

        tools.append(entry)

    return tools
invoke async
invoke(name, args, tool_call_id, config, state, callback_manager=Inject[CallbackManager])

Execute a specific tool by name with the provided arguments.

This method handles tool execution across all configured providers (local, MCP, Composio, LangChain) with comprehensive error handling, event publishing, and callback management.

Parameters:

Name Type Description Default
name str

The name of the tool to execute.

required
args dict

Dictionary of arguments to pass to the tool function.

required
tool_call_id str

Unique identifier for this tool execution, used for tracking and result correlation.

required
config dict[str, Any]

Configuration dictionary containing execution context and user-specific settings.

required
state AgentState

Current agent state for context-aware tool execution.

required
callback_manager CallbackManager

Manager for executing pre/post execution callbacks. Injected via dependency injection if not provided.

Inject[CallbackManager]

Returns:

Type Description
Any

Message object containing tool execution results, either successful output or error information with appropriate status indicators.

Example
result = await tool_node.invoke(
    name="weather_tool",
    args={"location": "Paris", "units": "metric"},
    tool_call_id="call_abc123",
    config={"user_id": "user1", "session_id": "session1"},
    state=current_agent_state,
)

# result is a Message with tool execution results
print(result.content)  # Tool output or error information
Note

The method publishes execution events throughout the process for monitoring and debugging purposes. Tool execution is routed based on tool provider precedence: MCP → Composio → LangChain → Local.
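The precedence described in the Note can be sketched as a plain dispatch function. This is an illustrative reduction of `invoke()`'s routing only; the real method also publishes events, runs callbacks, and returns `Message` objects rather than provider labels:

```python
def route_tool(
    name: str,
    mcp_tools: set[str],
    composio_tools: set[str],
    langchain_tools: set[str],
    local_funcs: dict,
) -> str:
    # Checked in the same order as invoke(): MCP, Composio, LangChain, Local.
    if name in mcp_tools:
        return "mcp"
    if name in composio_tools:
        return "composio"
    if name in langchain_tools:
        return "langchain"
    if name in local_funcs:
        return "local"
    return "not_found"
```

A consequence of this ordering: if a remote provider and a local function share a name, the remote provider wins.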

Source code in pyagenity/graph/tool_node/base.py
async def invoke(  # noqa: PLR0915
    self,
    name: str,
    args: dict,
    tool_call_id: str,
    config: dict[str, t.Any],
    state: AgentState,
    callback_manager: CallbackManager = Inject[CallbackManager],
) -> t.Any:
    """Execute a specific tool by name with the provided arguments.

    This method handles tool execution across all configured providers (local,
    MCP, Composio, LangChain) with comprehensive error handling, event publishing,
    and callback management.

    Args:
        name: The name of the tool to execute.
        args: Dictionary of arguments to pass to the tool function.
        tool_call_id: Unique identifier for this tool execution, used for
            tracking and result correlation.
        config: Configuration dictionary containing execution context and
            user-specific settings.
        state: Current agent state for context-aware tool execution.
        callback_manager: Manager for executing pre/post execution callbacks.
            Injected via dependency injection if not provided.

    Returns:
        Message object containing tool execution results, either successful
        output or error information with appropriate status indicators.

    Raises:
        The method handles all exceptions internally and returns error Messages
        rather than raising exceptions, ensuring robust execution flow.

    Example:
        ```python
        result = await tool_node.invoke(
            name="weather_tool",
            args={"location": "Paris", "units": "metric"},
            tool_call_id="call_abc123",
            config={"user_id": "user1", "session_id": "session1"},
            state=current_agent_state,
        )

        # result is a Message with tool execution results
        print(result.content)  # Tool output or error information
        ```

    Note:
        The method publishes execution events throughout the process for
        monitoring and debugging purposes. Tool execution is routed based
        on tool provider precedence: MCP → Composio → LangChain → Local.
    """
    logger.info("Executing tool '%s' with %d arguments", name, len(args))
    logger.debug("Tool arguments: %s", args)

    event = EventModel.default(
        config,
        data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
        content_type=[ContentType.TOOL_CALL],
        event=Event.TOOL_EXECUTION,
    )
    event.node_name = name
    # Attach structured tool call block
    with contextlib.suppress(Exception):
        event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]
    publish_event(event)

    if name in self.mcp_tools:
        event.metadata["is_mcp"] = True
        publish_event(event)
        res = await self._mcp_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        # Attach tool result block mirroring the tool output
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    if name in self.composio_tools:
        event.metadata["is_composio"] = True
        publish_event(event)
        res = await self._composio_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    if name in self.langchain_tools:
        event.metadata["is_langchain"] = True
        publish_event(event)
        res = await self._langchain_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    if name in self._funcs:
        event.metadata["is_mcp"] = False
        publish_event(event)
        res = await self._internal_execute(
            name,
            args,
            tool_call_id,
            config,
            state,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    error_msg = f"Tool '{name}' not found."
    event.data["error"] = error_msg
    event.event_type = EventType.ERROR
    event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
    publish_event(event)
    return Message.tool_message(
        content=[
            ErrorBlock(message=error_msg),
            ToolResultBlock(
                call_id=tool_call_id,
                output=error_msg,
                is_error=True,
                status="failed",
            ),
        ],
    )
stream async
stream(name, args, tool_call_id, config, state, callback_manager=Inject[CallbackManager])

Execute a tool with streaming support, yielding incremental results.

Similar to invoke() but designed for tools that can provide streaming responses or when you want to process results as they become available. Currently, most tool providers return complete results, so this method typically yields a single Message with the full result.

Parameters:

Name Type Description Default
name str

The name of the tool to execute.

required
args dict

Dictionary of arguments to pass to the tool function.

required
tool_call_id str

Unique identifier for this tool execution.

required
config dict[str, Any]

Configuration dictionary containing execution context.

required
state AgentState

Current agent state for context-aware tool execution.

required
callback_manager CallbackManager

Manager for executing pre/post execution callbacks.

Inject[CallbackManager]

Yields:

Type Description
AsyncIterator[Message]

Message objects containing tool execution results or status updates.

AsyncIterator[Message]

For most tools, this will yield a single complete result Message.

Example
async for message in tool_node.stream(
    name="data_processor",
    args={"dataset": "large_data.csv"},
    tool_call_id="call_stream123",
    config={"user_id": "user1"},
    state=current_state,
):
    print(f"Received: {message.content}")
    # Process each streamed result
Note

The streaming interface is designed for future expansion where tools may provide true streaming responses. Currently, it provides a consistent async iterator interface over tool results.
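The "single complete result behind an async-iterator interface" contract can be illustrated with a minimal stand-alone sketch (the names here are hypothetical, not PyAgenity APIs):

```python
import asyncio


async def as_stream(result):
    """Wrap a complete, non-streaming tool result in an async iterator,
    matching the contract stream() provides today: one yield, then done."""
    yield result


async def main():
    # Consumers write the same `async for` loop whether the tool streams
    # incrementally or returns a single finished result.
    return [chunk async for chunk in as_stream("tool output")]
```

Because callers only depend on the async-iterator shape, tools can later switch to yielding many incremental chunks without any change on the consuming side.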

Source code in pyagenity/graph/tool_node/base.py
async def stream(  # noqa: PLR0915
    self,
    name: str,
    args: dict,
    tool_call_id: str,
    config: dict[str, t.Any],
    state: AgentState,
    callback_manager: CallbackManager = Inject[CallbackManager],
) -> t.AsyncIterator[Message]:
    """Execute a tool with streaming support, yielding incremental results.

    Similar to invoke() but designed for tools that can provide streaming responses
    or when you want to process results as they become available. Currently,
    most tool providers return complete results, so this method typically yields
    a single Message with the full result.

    Args:
        name: The name of the tool to execute.
        args: Dictionary of arguments to pass to the tool function.
        tool_call_id: Unique identifier for this tool execution.
        config: Configuration dictionary containing execution context.
        state: Current agent state for context-aware tool execution.
        callback_manager: Manager for executing pre/post execution callbacks.

    Yields:
        Message objects containing tool execution results or status updates.
        For most tools, this will yield a single complete result Message.

    Example:
        ```python
        async for message in tool_node.stream(
            name="data_processor",
            args={"dataset": "large_data.csv"},
            tool_call_id="call_stream123",
            config={"user_id": "user1"},
            state=current_state,
        ):
            print(f"Received: {message.content}")
            # Process each streamed result
        ```

    Note:
        The streaming interface is designed for future expansion where tools
        may provide true streaming responses. Currently, it provides a
        consistent async iterator interface over tool results.
    """
    logger.info("Executing tool '%s' with %d arguments", name, len(args))
    logger.debug("Tool arguments: %s", args)
    event = EventModel.default(
        config,
        data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
        content_type=[ContentType.TOOL_CALL],
        event=Event.TOOL_EXECUTION,
    )
    event.node_name = "ToolNode"
    with contextlib.suppress(Exception):
        event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]

    if name in self.mcp_tools:
        event.metadata["function_type"] = "mcp"
        publish_event(event)
        message = await self._mcp_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = message.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield message
        return

    if name in self.composio_tools:
        event.metadata["function_type"] = "composio"
        publish_event(event)
        message = await self._composio_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = message.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield message
        return

    if name in self.langchain_tools:
        event.metadata["function_type"] = "langchain"
        publish_event(event)
        message = await self._langchain_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = message.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield message
        return

    if name in self._funcs:
        event.metadata["function_type"] = "internal"
        publish_event(event)

        result = await self._internal_execute(
            name,
            args,
            tool_call_id,
            config,
            state,
            callback_manager,
        )
        event.data["message"] = result.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=result.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield result
        return

    error_msg = f"Tool '{name}' not found."
    event.data["error"] = error_msg
    event.event_type = EventType.ERROR
    event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
    publish_event(event)

    yield Message.tool_message(
        content=[
            ErrorBlock(message=error_msg),
            ToolResultBlock(
                call_id=tool_call_id,
                output=error_msg,
                is_error=True,
                status="failed",
            ),
        ],
    )

Modules

compiled_graph

Classes:

Name Description
CompiledGraph

A fully compiled and executable graph ready for workflow execution.

Attributes:

Name Type Description
StateT
logger
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
logger module-attribute
logger = getLogger(__name__)
Classes
CompiledGraph

A fully compiled and executable graph ready for workflow execution.

CompiledGraph represents the final executable form of a StateGraph after compilation. It encapsulates all the execution logic, handlers, and services needed to run agent workflows. The graph supports both synchronous and asynchronous execution with comprehensive state management, checkpointing, event publishing, and streaming capabilities.

This class is generic over state types to support custom AgentState subclasses, ensuring type safety throughout the execution process.
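The generic-over-state pattern can be sketched with classic `Generic`/`TypeVar` syntax (the `AgentState` stand-in and `Compiled` wrapper below are illustrative, not the real classes; the actual source uses the newer `class CompiledGraph[StateT: AgentState]` form):

```python
from typing import Generic, TypeVar


class AgentState:  # stand-in for pyagenity's AgentState base class
    pass


StateT = TypeVar("StateT", bound=AgentState)


class Compiled(Generic[StateT]):
    """Holds a state of the exact subclass it was built with."""

    def __init__(self, state: StateT):
        self._state = state

    @property
    def state(self) -> StateT:
        # Type checkers see the concrete subclass, not just AgentState,
        # so custom fields stay visible throughout execution.
        return self._state


class MyState(AgentState):
    def __init__(self, count: int):
        self.count = count
```

With this shape, `Compiled(MyState(1)).state.count` type-checks: the bound ensures only AgentState subclasses are accepted, while the type variable preserves the subclass for callers.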

Key Features:

- Synchronous and asynchronous execution methods
- Real-time streaming with incremental results
- State persistence and checkpointing
- Interrupt and resume capabilities
- Event publishing for monitoring and debugging
- Background task management
- Graceful error handling and recovery

Attributes:

Name Type Description
_state

The initial/template state for graph executions.

_invoke_handler InvokeHandler[StateT]

Handler for non-streaming graph execution.

_stream_handler StreamHandler[StateT]

Handler for streaming graph execution.

_checkpointer BaseCheckpointer[StateT] | None

Optional state persistence backend.

_publisher BasePublisher | None

Optional event publishing backend.

_store BaseStore | None

Optional data storage backend.

_state_graph StateGraph[StateT]

Reference to the source StateGraph.

_interrupt_before list[str]

Nodes where execution should pause before execution.

_interrupt_after list[str]

Nodes where execution should pause after execution.

_task_manager

Manager for background async tasks.

Example
# After building and compiling a StateGraph
compiled = graph.compile()

# Synchronous execution
result = compiled.invoke({"messages": [Message.text_message("Hello")]})

# Asynchronous execution with streaming
async for chunk in compiled.astream({"messages": [message]}):
    print(f"Streamed: {chunk.content}")

# Graceful cleanup
await compiled.aclose()
Note

CompiledGraph instances should be properly closed using aclose() to release resources like database connections, background tasks, and event publishers.
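To guarantee cleanup even when execution raises, aclose() can be placed in a `finally` block. The sketch below uses a fake stand-in object (assumed names, not PyAgenity classes) exposing the same `ainvoke()`/`aclose()` shape:

```python
import asyncio


class FakeCompiled:
    """Stand-in for a compiled graph: ainvoke() runs, aclose() releases."""

    def __init__(self):
        self.closed = False

    async def ainvoke(self, data):
        return {"messages": data["messages"]}

    async def aclose(self):
        self.closed = True
        return {"checkpointer": "closed"}


async def run_and_close(compiled, payload):
    # aclose() runs whether ainvoke() succeeds or raises, so checkpointer
    # connections, publishers, and background tasks are always released.
    try:
        return await compiled.ainvoke(payload)
    finally:
        await compiled.aclose()
```

The same try/finally shape applies to the real CompiledGraph, whose aclose() additionally reports per-resource close status.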

Methods:

Name Description
__init__
aclose

Close the graph and release any resources.

ainvoke

Execute the graph asynchronously.

astop

Request the current graph execution to stop (async).

astream

Execute the graph asynchronously with streaming support.

generate_graph

Generate the graph representation.

invoke

Execute the graph synchronously and return the final results.

stop

Request the current graph execution to stop (sync helper).

stream

Execute the graph synchronously with streaming support.

Source code in pyagenity/graph/compiled_graph.py
class CompiledGraph[StateT: AgentState]:
    """A fully compiled and executable graph ready for workflow execution.

    CompiledGraph represents the final executable form of a StateGraph after compilation.
    It encapsulates all the execution logic, handlers, and services needed to run
    agent workflows. The graph supports both synchronous and asynchronous execution
    with comprehensive state management, checkpointing, event publishing, and
    streaming capabilities.

    This class is generic over state types to support custom AgentState subclasses,
    ensuring type safety throughout the execution process.

    Key Features:
    - Synchronous and asynchronous execution methods
    - Real-time streaming with incremental results
    - State persistence and checkpointing
    - Interrupt and resume capabilities
    - Event publishing for monitoring and debugging
    - Background task management
    - Graceful error handling and recovery

    Attributes:
        _state: The initial/template state for graph executions.
        _invoke_handler: Handler for non-streaming graph execution.
        _stream_handler: Handler for streaming graph execution.
        _checkpointer: Optional state persistence backend.
        _publisher: Optional event publishing backend.
        _store: Optional data storage backend.
        _state_graph: Reference to the source StateGraph.
        _interrupt_before: Nodes where execution should pause before execution.
        _interrupt_after: Nodes where execution should pause after execution.
        _task_manager: Manager for background async tasks.

    Example:
        ```python
        # After building and compiling a StateGraph
        compiled = graph.compile()

        # Synchronous execution
        result = compiled.invoke({"messages": [Message.text_message("Hello")]})

        # Asynchronous execution with streaming
        async for chunk in compiled.astream({"messages": [message]}):
            print(f"Streamed: {chunk.content}")

        # Graceful cleanup
        await compiled.aclose()
        ```

    Note:
        CompiledGraph instances should be properly closed using aclose() to
        release resources like database connections, background tasks, and
        event publishers.
    """

    def __init__(
        self,
        state: StateT,
        checkpointer: BaseCheckpointer[StateT] | None,
        publisher: BasePublisher | None,
        store: BaseStore | None,
        state_graph: StateGraph[StateT],
        interrupt_before: list[str],
        interrupt_after: list[str],
        task_manager: BackgroundTaskManager,
    ):
        logger.info(
            f"Initializing CompiledGraph with nodes: {list(state_graph.nodes.keys())}",
        )

        # Save initial state
        self._state = state

        # create handlers
        self._invoke_handler: InvokeHandler[StateT] = InvokeHandler[StateT](
            nodes=state_graph.nodes,  # type: ignore
            edges=state_graph.edges,  # type: ignore
        )
        self._stream_handler: StreamHandler[StateT] = StreamHandler[StateT](
            nodes=state_graph.nodes,  # type: ignore
            edges=state_graph.edges,  # type: ignore
        )

        self._checkpointer: BaseCheckpointer[StateT] | None = checkpointer
        self._publisher: BasePublisher | None = publisher
        self._store: BaseStore | None = store
        self._state_graph: StateGraph[StateT] = state_graph
        self._interrupt_before: list[str] = interrupt_before
        self._interrupt_after: list[str] = interrupt_after
        # generate task manager
        self._task_manager = task_manager

    def _prepare_config(
        self,
        config: dict[str, Any] | None,
        is_stream: bool = False,
    ) -> dict[str, Any]:
        cfg = config or {}
        if "is_stream" not in cfg:
            cfg["is_stream"] = is_stream
        if "user_id" not in cfg:
            cfg["user_id"] = "test-user-id"  # mock user id
        if "run_id" not in cfg:
            cfg["run_id"] = InjectQ.get_instance().try_get("generated_id") or str(uuid4())

        if "timestamp" not in cfg:
            cfg["timestamp"] = datetime.datetime.now().isoformat()

        return cfg

    def invoke(
        self,
        input_data: dict[str, Any],
        config: dict[str, Any] | None = None,
        response_granularity: ResponseGranularity = ResponseGranularity.LOW,
    ) -> dict[str, Any]:
        """Execute the graph synchronously and return the final results.

        Runs the complete graph workflow from start to finish, handling state
        management, node execution, and result formatting. This method automatically
        detects whether to start a fresh execution or resume from an interrupted state.

        The execution is synchronous but internally uses async operations, making it
        suitable for use in non-async contexts while still benefiting from async
        capabilities for I/O operations.

        Args:
            input_data: Input dictionary for graph execution. For new executions,
                should contain 'messages' key with list of initial messages.
                For resumed executions, can contain additional data to merge.
            config: Optional configuration dictionary containing execution settings:
                - user_id: Identifier for the user/session
                - thread_id: Unique identifier for this execution thread
                - run_id: Unique identifier for this specific run
                - recursion_limit: Maximum steps before stopping (default: 25)
            response_granularity: Level of detail in the response:
                - LOW: Returns only messages (default)
                - PARTIAL: Returns context, summary, and messages
                - FULL: Returns complete state and messages

        Returns:
            Dictionary containing execution results formatted according to the
            specified granularity level. Always includes execution messages
            and may include additional state information.

        Raises:
            ValueError: If input_data is invalid for new execution.
            GraphRecursionError: If execution exceeds recursion limit.
            Various exceptions: Depending on node execution failures.

        Example:
            ```python
            # Basic execution
            result = compiled.invoke({"messages": [Message.text_message("Process this data")]})
            print(result["messages"])  # Final execution messages

            # With configuration and full details
            result = compiled.invoke(
                input_data={"messages": [message]},
                config={"user_id": "user123", "thread_id": "session456", "recursion_limit": 50},
                response_granularity=ResponseGranularity.FULL,
            )
            print(result["state"])  # Complete final state
            ```

        Note:
            This method uses asyncio.run() internally, so it should not be called
            from within an async context. Use ainvoke() instead for async execution.
        """
        logger.info(
            "Starting synchronous graph execution with %d input keys, granularity=%s",
            len(input_data) if input_data else 0,
            response_granularity,
        )
        logger.debug("Input data keys: %s", list(input_data.keys()) if input_data else [])
        # Async Will Handle Event Publish

        try:
            result = asyncio.run(self.ainvoke(input_data, config, response_granularity))
            logger.info("Synchronous graph execution completed successfully")
            return result
        except Exception as e:
            logger.exception("Synchronous graph execution failed: %s", e)
            raise

    async def ainvoke(
        self,
        input_data: dict[str, Any],
        config: dict[str, Any] | None = None,
        response_granularity: ResponseGranularity = ResponseGranularity.LOW,
    ) -> dict[str, Any]:
        """Execute the graph asynchronously.

        Auto-detects whether to start fresh execution or resume from interrupted state
        based on the AgentState's execution metadata.

        Args:
            input_data: Input dict with 'messages' key (for new execution) or
                       additional data for resuming
            config: Configuration dictionary
            response_granularity: Response parsing granularity

        Returns:
            Response dict based on granularity
        """
        cfg = self._prepare_config(config, is_stream=False)

        return await self._invoke_handler.invoke(
            input_data,
            cfg,
            self._state,
            response_granularity,
        )

    def stop(self, config: dict[str, Any]) -> dict[str, Any]:
        """Request the current graph execution to stop (sync helper).

        This sets a stop flag in the checkpointer's thread store keyed by thread_id.
        Handlers periodically check this flag and interrupt execution.
        Returns a small status dict.
        """
        return asyncio.run(self.astop(config))

    async def astop(self, config: dict[str, Any]) -> dict[str, Any]:
        """Request the current graph execution to stop (async).

        Contract:
        - Requires a valid thread_id in config
        - If no active thread or no checkpointer, returns not-running
        - If state exists and is running, set stop_requested flag in thread info
        """
        cfg = self._prepare_config(config, is_stream=bool(config.get("is_stream", False)))
        if not self._checkpointer:
            return {"ok": False, "reason": "no-checkpointer"}

        # Load state to see if this thread is running
        state = await self._checkpointer.aget_state_cache(
            cfg
        ) or await self._checkpointer.aget_state(cfg)
        if not state:
            return {"ok": False, "running": False, "reason": "no-state"}

        running = state.is_running() and not state.is_interrupted()
        # Set stop flag regardless; handlers will act if running
        if running:
            state.execution_meta.stop_current_execution = StopRequestStatus.STOP_REQUESTED
            # update cache
            # Cache update is enough; state will be picked up by running execution
            # As its running, cache will be available immediately
            await self._checkpointer.aput_state_cache(cfg, state)
            # Fixme: consider putting to main state as well
            # await self._checkpointer.aput_state(cfg, state)
            logger.info("Set stop_current_execution flag for thread_id: %s", cfg.get("thread_id"))
            return {"ok": True, "running": running}

        logger.info(
            "No running execution to stop for thread_id: %s (running=%s, interrupted=%s)",
            cfg.get("thread_id"),
            running,
            state.is_interrupted(),
        )
        return {"ok": True, "running": running, "reason": "not-running"}

    def stream(
        self,
        input_data: dict[str, Any],
        config: dict[str, Any] | None = None,
        response_granularity: ResponseGranularity = ResponseGranularity.LOW,
    ) -> Generator[Message]:
        """Execute the graph synchronously with streaming support.

        Yields Message objects containing incremental responses.
        If nodes return streaming responses, yields them directly.
        If nodes return complete responses, simulates streaming by chunking.

        Args:
            input_data: Input dict
            config: Configuration dictionary
            response_granularity: Response parsing granularity

        Yields:
            Message objects with incremental content
        """

        # For sync streaming, we'll use asyncio.run to handle the async implementation
        async def _async_stream():
            async for chunk in self.astream(input_data, config, response_granularity):
                yield chunk

        # Convert async generator to sync iteration with a dedicated event loop
        gen = _async_stream()
        loop = asyncio.new_event_loop()
        policy = asyncio.get_event_loop_policy()
        try:
            previous_loop = policy.get_event_loop()
        except Exception:
            previous_loop = None
        asyncio.set_event_loop(loop)
        logger.info("Synchronous streaming started")

        try:
            while True:
                try:
                    chunk = loop.run_until_complete(gen.__anext__())
                    yield chunk
                except StopAsyncIteration:
                    break
        finally:
            # Attempt to close the async generator cleanly
            with contextlib.suppress(Exception):
                loop.run_until_complete(gen.aclose())  # type: ignore[attr-defined]
            # Restore previous loop if any, then close created loop
            try:
                if previous_loop is not None:
                    asyncio.set_event_loop(previous_loop)
            finally:
                loop.close()
        logger.info("Synchronous streaming completed")

    async def astream(
        self,
        input_data: dict[str, Any],
        config: dict[str, Any] | None = None,
        response_granularity: ResponseGranularity = ResponseGranularity.LOW,
    ) -> AsyncIterator[Message]:
        """Execute the graph asynchronously with streaming support.

        Yields Message objects containing incremental responses.
        If nodes return streaming responses, yields them directly.
        If nodes return complete responses, simulates streaming by chunking.

        Args:
            input_data: Input dict
            config: Configuration dictionary
            response_granularity: Response parsing granularity

        Yields:
            Message objects with incremental content
        """

        cfg = self._prepare_config(config, is_stream=True)

        async for chunk in self._stream_handler.stream(
            input_data,
            cfg,
            self._state,
            response_granularity,
        ):
            yield chunk

    async def aclose(self) -> dict[str, str]:
        """Close the graph and release any resources."""
        # close checkpointer
        stats = {}
        try:
            if self._checkpointer:
                await self._checkpointer.arelease()
                logger.info("Checkpointer closed successfully")
                stats["checkpointer"] = "closed"
        except Exception as e:
            stats["checkpointer"] = f"error: {e}"
            logger.error(f"Error closing checkpointer: {e}")

        # Close Publisher
        try:
            if self._publisher:
                await self._publisher.close()
                logger.info("Publisher closed successfully")
                stats["publisher"] = "closed"
        except Exception as e:
            stats["publisher"] = f"error: {e}"
            logger.error(f"Error closing publisher: {e}")

        # Close Store
        try:
            if self._store:
                await self._store.arelease()
                logger.info("Store closed successfully")
                stats["store"] = "closed"
        except Exception as e:
            stats["store"] = f"error: {e}"
            logger.error(f"Error closing store: {e}")

        # Wait for all background tasks to complete
        try:
            await self._task_manager.wait_for_all()
            logger.info("All background tasks completed successfully")
            stats["background_tasks"] = "completed"
        except Exception as e:
            stats["background_tasks"] = f"error: {e}"
            logger.error(f"Error waiting for background tasks: {e}")

        logger.info(f"Graph close stats: {stats}")
        # Stats let callers see which resources closed cleanly
        return stats

    def generate_graph(self) -> dict[str, Any]:
        """Generate the graph representation.

        Returns:
            A dictionary representing the graph structure.
        """
        graph = {
            "info": {},
            "nodes": [],
            "edges": [],
        }
        # Populate the graph with nodes and edges
        for node_name in self._state_graph.nodes:
            graph["nodes"].append(
                {
                    "id": str(uuid4()),
                    "name": node_name,
                }
            )

        for edge in self._state_graph.edges:
            graph["edges"].append(
                {
                    "id": str(uuid4()),
                    "source": edge.from_node,
                    "target": edge.to_node,
                }
            )

        # Add summary metadata about the graph
        graph["info"] = {
            "node_count": len(graph["nodes"]),
            "edge_count": len(graph["edges"]),
            "checkpointer": self._checkpointer is not None,
            "checkpointer_type": type(self._checkpointer).__name__ if self._checkpointer else None,
            "publisher": self._publisher is not None,
            "store": self._store is not None,
            "interrupt_before": self._interrupt_before,
            "interrupt_after": self._interrupt_after,
            "context_type": self._state_graph._context_manager.__class__.__name__,
            "id_generator": self._state_graph._id_generator.__class__.__name__,
            "id_type": self._state_graph._id_generator.id_type.value,
            "state_type": self._state.__class__.__name__,
            "state_fields": list(self._state.model_dump().keys()),
        }
        return graph
Functions
__init__
__init__(state, checkpointer, publisher, store, state_graph, interrupt_before, interrupt_after, task_manager)
Source code in pyagenity/graph/compiled_graph.py
def __init__(
    self,
    state: StateT,
    checkpointer: BaseCheckpointer[StateT] | None,
    publisher: BasePublisher | None,
    store: BaseStore | None,
    state_graph: StateGraph[StateT],
    interrupt_before: list[str],
    interrupt_after: list[str],
    task_manager: BackgroundTaskManager,
):
    logger.info(
        f"Initializing CompiledGraph with nodes: {list(state_graph.nodes.keys())}",
    )

    # Save initial state
    self._state = state

    # create handlers
    self._invoke_handler: InvokeHandler[StateT] = InvokeHandler[StateT](
        nodes=state_graph.nodes,  # type: ignore
        edges=state_graph.edges,  # type: ignore
    )
    self._stream_handler: StreamHandler[StateT] = StreamHandler[StateT](
        nodes=state_graph.nodes,  # type: ignore
        edges=state_graph.edges,  # type: ignore
    )

    self._checkpointer: BaseCheckpointer[StateT] | None = checkpointer
    self._publisher: BasePublisher | None = publisher
    self._store: BaseStore | None = store
    self._state_graph: StateGraph[StateT] = state_graph
    self._interrupt_before: list[str] = interrupt_before
    self._interrupt_after: list[str] = interrupt_after
    # generate task manager
    self._task_manager = task_manager
aclose async
aclose()

Close the graph and release any resources.

Source code in pyagenity/graph/compiled_graph.py
async def aclose(self) -> dict[str, str]:
    """Close the graph and release any resources."""
    # close checkpointer
    stats = {}
    try:
        if self._checkpointer:
            await self._checkpointer.arelease()
            logger.info("Checkpointer closed successfully")
            stats["checkpointer"] = "closed"
    except Exception as e:
        stats["checkpointer"] = f"error: {e}"
        logger.error(f"Error closing checkpointer: {e}")

    # Close Publisher
    try:
        if self._publisher:
            await self._publisher.close()
            logger.info("Publisher closed successfully")
            stats["publisher"] = "closed"
    except Exception as e:
        stats["publisher"] = f"error: {e}"
        logger.error(f"Error closing publisher: {e}")

    # Close Store
    try:
        if self._store:
            await self._store.arelease()
            logger.info("Store closed successfully")
            stats["store"] = "closed"
    except Exception as e:
        stats["store"] = f"error: {e}"
        logger.error(f"Error closing store: {e}")

    # Wait for all background tasks to complete
    try:
        await self._task_manager.wait_for_all()
        logger.info("All background tasks completed successfully")
        stats["background_tasks"] = "completed"
    except Exception as e:
        stats["background_tasks"] = f"error: {e}"
        logger.error(f"Error waiting for background tasks: {e}")

    logger.info(f"Graph close stats: {stats}")
    # Stats let callers see which resources closed cleanly
    return stats
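The per-resource shutdown pattern above — each close guarded by its own try/except so one failure is recorded without blocking the rest — can be sketched in isolation. The `Resource` class and names below are illustrative stand-ins, not part of PyAgenity:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


class Resource:
    """Stand-in for a checkpointer/publisher/store with an async release."""

    def __init__(self, name: str, fail: bool = False):
        self.name = name
        self.fail = fail

    async def arelease(self) -> None:
        if self.fail:
            raise RuntimeError(f"{self.name} refused to close")


async def close_all(resources: list[Resource]) -> dict[str, str]:
    # Mirrors aclose(): each resource is closed independently, so one
    # failure is recorded in stats but does not prevent closing the others.
    stats: dict[str, str] = {}
    for res in resources:
        try:
            await res.arelease()
            stats[res.name] = "closed"
        except Exception as e:
            stats[res.name] = f"error: {e}"
            logger.error("Error closing %s: %s", res.name, e)
    return stats


stats = asyncio.run(
    close_all([Resource("checkpointer"), Resource("publisher", fail=True)])
)
```

The returned dict has the same shape as `aclose()`'s result, which is why callers can inspect it to decide whether a failed shutdown needs attention.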
ainvoke async
ainvoke(input_data, config=None, response_granularity=ResponseGranularity.LOW)

Execute the graph asynchronously.

Auto-detects whether to start fresh execution or resume from interrupted state based on the AgentState's execution metadata.

Parameters:

Name Type Description Default
input_data dict[str, Any]

Input dict with 'messages' key (for new execution) or additional data for resuming

required
config dict[str, Any] | None

Configuration dictionary

None
response_granularity ResponseGranularity

Response parsing granularity

LOW

Returns:

Type Description
dict[str, Any]

Response dict based on granularity

Source code in pyagenity/graph/compiled_graph.py
async def ainvoke(
    self,
    input_data: dict[str, Any],
    config: dict[str, Any] | None = None,
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
) -> dict[str, Any]:
    """Execute the graph asynchronously.

    Auto-detects whether to start fresh execution or resume from interrupted state
    based on the AgentState's execution metadata.

    Args:
        input_data: Input dict with 'messages' key (for new execution) or
                   additional data for resuming
        config: Configuration dictionary
        response_granularity: Response parsing granularity

    Returns:
        Response dict based on granularity
    """
    cfg = self._prepare_config(config, is_stream=False)

    return await self._invoke_handler.invoke(
        input_data,
        cfg,
        self._state,
        response_granularity,
    )
astop async
astop(config)

Request the current graph execution to stop (async).

Contract: - Requires a valid thread_id in config - If no active thread or no checkpointer, returns not-running - If state exists and is running, set stop_requested flag in thread info

Source code in pyagenity/graph/compiled_graph.py
async def astop(self, config: dict[str, Any]) -> dict[str, Any]:
    """Request the current graph execution to stop (async).

    Contract:
    - Requires a valid thread_id in config
    - If no active thread or no checkpointer, returns not-running
    - If state exists and is running, set stop_requested flag in thread info
    """
    cfg = self._prepare_config(config, is_stream=bool(config.get("is_stream", False)))
    if not self._checkpointer:
        return {"ok": False, "reason": "no-checkpointer"}

    # Load state to see if this thread is running
    state = await self._checkpointer.aget_state_cache(
        cfg
    ) or await self._checkpointer.aget_state(cfg)
    if not state:
        return {"ok": False, "running": False, "reason": "no-state"}

    running = state.is_running() and not state.is_interrupted()
    # Set stop flag regardless; handlers will act if running
    if running:
        state.execution_meta.stop_current_execution = StopRequestStatus.STOP_REQUESTED
        # update cache
        # Cache update is enough; state will be picked up by running execution
        # As its running, cache will be available immediately
        await self._checkpointer.aput_state_cache(cfg, state)
        # Fixme: consider putting to main state as well
        # await self._checkpointer.aput_state(cfg, state)
        logger.info("Set stop_current_execution flag for thread_id: %s", cfg.get("thread_id"))
        return {"ok": True, "running": running}

    logger.info(
        "No running execution to stop for thread_id: %s (running=%s, interrupted=%s)",
        cfg.get("thread_id"),
        running,
        state.is_interrupted(),
    )
    return {"ok": True, "running": running, "reason": "not-running"}
astream async
astream(input_data, config=None, response_granularity=ResponseGranularity.LOW)

Execute the graph asynchronously with streaming support.

Yields Message objects containing incremental responses. If nodes return streaming responses, yields them directly. If nodes return complete responses, simulates streaming by chunking.

Parameters:

Name Type Description Default
input_data dict[str, Any]

Input dict

required
config dict[str, Any] | None

Configuration dictionary

None
response_granularity ResponseGranularity

Response parsing granularity

LOW

Yields:

Type Description
AsyncIterator[Message]

Message objects with incremental content

Source code in pyagenity/graph/compiled_graph.py
async def astream(
    self,
    input_data: dict[str, Any],
    config: dict[str, Any] | None = None,
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
) -> AsyncIterator[Message]:
    """Execute the graph asynchronously with streaming support.

    Yields Message objects containing incremental responses.
    If nodes return streaming responses, yields them directly.
    If nodes return complete responses, simulates streaming by chunking.

    Args:
        input_data: Input dict
        config: Configuration dictionary
        response_granularity: Response parsing granularity

    Yields:
        Message objects with incremental content
    """

    cfg = self._prepare_config(config, is_stream=True)

    async for chunk in self._stream_handler.stream(
        input_data,
        cfg,
        self._state,
        response_granularity,
    ):
        yield chunk
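Consuming `astream` is a standard `async for` loop. The sketch below shows the shape with a hypothetical generator standing in for the real stream handler; it also simulates the "chunk a complete response" behavior described above:

```python
import asyncio
from collections.abc import AsyncIterator


async def fake_astream(text: str) -> AsyncIterator[str]:
    # Stand-in for CompiledGraph.astream(), which yields Message chunks.
    # Here streaming is simulated by chunking a complete response.
    for word in text.split():
        await asyncio.sleep(0)  # yield control, as real I/O would
        yield word


async def main() -> list[str]:
    chunks = []
    async for chunk in fake_astream("hello streaming world"):
        chunks.append(chunk)
    return chunks


chunks = asyncio.run(main())
```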
generate_graph
generate_graph()

Generate the graph representation.

Returns:

Type Description
dict[str, Any]

A dictionary representing the graph structure.

Source code in pyagenity/graph/compiled_graph.py
def generate_graph(self) -> dict[str, Any]:
    """Generate the graph representation.

    Returns:
        A dictionary representing the graph structure.
    """
    graph = {
        "info": {},
        "nodes": [],
        "edges": [],
    }
    # Populate the graph with nodes and edges
    for node_name in self._state_graph.nodes:
        graph["nodes"].append(
            {
                "id": str(uuid4()),
                "name": node_name,
            }
        )

    for edge in self._state_graph.edges:
        graph["edges"].append(
            {
                "id": str(uuid4()),
                "source": edge.from_node,
                "target": edge.to_node,
            }
        )

    # Add summary metadata about the graph
    graph["info"] = {
        "node_count": len(graph["nodes"]),
        "edge_count": len(graph["edges"]),
        "checkpointer": self._checkpointer is not None,
        "checkpointer_type": type(self._checkpointer).__name__ if self._checkpointer else None,
        "publisher": self._publisher is not None,
        "store": self._store is not None,
        "interrupt_before": self._interrupt_before,
        "interrupt_after": self._interrupt_after,
        "context_type": self._state_graph._context_manager.__class__.__name__,
        "id_generator": self._state_graph._id_generator.__class__.__name__,
        "id_type": self._state_graph._id_generator.id_type.value,
        "state_type": self._state.__class__.__name__,
        "state_fields": list(self._state.model_dump().keys()),
    }
    return graph
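The export shape — `nodes` with generated ids, `edges` with `source`/`target`, plus an `info` summary — can be reproduced with a self-contained sketch; the node and edge names here are illustrative:

```python
from uuid import uuid4

nodes = ["start", "process", "end"]
edges = [("start", "process"), ("process", "end")]

graph = {
    "nodes": [{"id": str(uuid4()), "name": name} for name in nodes],
    "edges": [
        {"id": str(uuid4()), "source": src, "target": dst} for src, dst in edges
    ],
}
# Summary block mirroring generate_graph()'s "info" counts.
graph["info"] = {
    "node_count": len(graph["nodes"]),
    "edge_count": len(graph["edges"]),
}
```

Note that ids are generated with `uuid4()` on every call, so they are stable only within a single export; use the `name` field to correlate nodes across exports.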
invoke
invoke(input_data, config=None, response_granularity=ResponseGranularity.LOW)

Execute the graph synchronously and return the final results.

Runs the complete graph workflow from start to finish, handling state management, node execution, and result formatting. This method automatically detects whether to start a fresh execution or resume from an interrupted state.

The execution is synchronous but internally uses async operations, making it suitable for use in non-async contexts while still benefiting from async capabilities for I/O operations.

Parameters:

Name Type Description Default
input_data dict[str, Any]

Input dictionary for graph execution. For new executions, should contain 'messages' key with list of initial messages. For resumed executions, can contain additional data to merge.

required
config dict[str, Any] | None

Optional configuration dictionary containing execution settings: - user_id: Identifier for the user/session - thread_id: Unique identifier for this execution thread - run_id: Unique identifier for this specific run - recursion_limit: Maximum steps before stopping (default: 25)

None
response_granularity ResponseGranularity

Level of detail in the response: - LOW: Returns only messages (default) - PARTIAL: Returns context, summary, and messages - FULL: Returns complete state and messages

LOW

Returns:

Type Description
dict[str, Any]

Dictionary containing execution results formatted according to the

dict[str, Any]

specified granularity level. Always includes execution messages

dict[str, Any]

and may include additional state information.

Raises:

Type Description
ValueError

If input_data is invalid for new execution.

GraphRecursionError

If execution exceeds recursion limit.

Exception

Any exception raised during node execution is propagated.

Example
# Basic execution
result = compiled.invoke({"messages": [Message.text_message("Process this data")]})
print(result["messages"])  # Final execution messages

# With configuration and full details
result = compiled.invoke(
    input_data={"messages": [message]},
    config={"user_id": "user123", "thread_id": "session456", "recursion_limit": 50},
    response_granularity=ResponseGranularity.FULL,
)
print(result["state"])  # Complete final state
Note

This method uses asyncio.run() internally, so it should not be called from within an async context. Use ainvoke() instead for async execution.

Source code in pyagenity/graph/compiled_graph.py
def invoke(
    self,
    input_data: dict[str, Any],
    config: dict[str, Any] | None = None,
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
) -> dict[str, Any]:
    """Execute the graph synchronously and return the final results.

    Runs the complete graph workflow from start to finish, handling state
    management, node execution, and result formatting. This method automatically
    detects whether to start a fresh execution or resume from an interrupted state.

    The execution is synchronous but internally uses async operations, making it
    suitable for use in non-async contexts while still benefiting from async
    capabilities for I/O operations.

    Args:
        input_data: Input dictionary for graph execution. For new executions,
            should contain 'messages' key with list of initial messages.
            For resumed executions, can contain additional data to merge.
        config: Optional configuration dictionary containing execution settings:
            - user_id: Identifier for the user/session
            - thread_id: Unique identifier for this execution thread
            - run_id: Unique identifier for this specific run
            - recursion_limit: Maximum steps before stopping (default: 25)
        response_granularity: Level of detail in the response:
            - LOW: Returns only messages (default)
            - PARTIAL: Returns context, summary, and messages
            - FULL: Returns complete state and messages

    Returns:
        Dictionary containing execution results formatted according to the
        specified granularity level. Always includes execution messages
        and may include additional state information.

    Raises:
        ValueError: If input_data is invalid for new execution.
        GraphRecursionError: If execution exceeds recursion limit.
        Exception: Any exception raised during node execution is propagated.

    Example:
        ```python
        # Basic execution
        result = compiled.invoke({"messages": [Message.text_message("Process this data")]})
        print(result["messages"])  # Final execution messages

        # With configuration and full details
        result = compiled.invoke(
            input_data={"messages": [message]},
            config={"user_id": "user123", "thread_id": "session456", "recursion_limit": 50},
            response_granularity=ResponseGranularity.FULL,
        )
        print(result["state"])  # Complete final state
        ```

    Note:
        This method uses asyncio.run() internally, so it should not be called
        from within an async context. Use ainvoke() instead for async execution.
    """
    logger.info(
        "Starting synchronous graph execution with %d input keys, granularity=%s",
        len(input_data) if input_data else 0,
        response_granularity,
    )
    logger.debug("Input data keys: %s", list(input_data.keys()) if input_data else [])
    # Event publishing is handled inside the async implementation

    try:
        result = asyncio.run(self.ainvoke(input_data, config, response_granularity))
        logger.info("Synchronous graph execution completed successfully")
        return result
    except Exception as e:
        logger.exception("Synchronous graph execution failed: %s", e)
        raise
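The three granularity levels documented above (LOW returns only messages, PARTIAL adds context and summary, FULL returns the complete state) can be sketched with a hypothetical formatter. The enum values mirror the names used in this documentation, but the formatter itself is illustrative — PyAgenity's real response parsing lives in its handlers:

```python
from enum import Enum
from typing import Any


class ResponseGranularity(Enum):
    LOW = "low"
    PARTIAL = "partial"
    FULL = "full"


def format_response(
    state: dict[str, Any], granularity: ResponseGranularity
) -> dict[str, Any]:
    # Hypothetical formatter mirroring the documented granularity levels.
    if granularity is ResponseGranularity.FULL:
        return {"state": state, "messages": state["messages"]}
    if granularity is ResponseGranularity.PARTIAL:
        return {
            "context": state.get("context"),
            "summary": state.get("summary"),
            "messages": state["messages"],
        }
    return {"messages": state["messages"]}


state = {"messages": ["hi"], "context": [], "summary": None}
low = format_response(state, ResponseGranularity.LOW)
full = format_response(state, ResponseGranularity.FULL)
```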
stop
stop(config)

Request the current graph execution to stop (sync helper).

This sets a stop flag in the checkpointer's thread store keyed by thread_id. Handlers periodically check this flag and interrupt execution. Returns a small status dict.

Source code in pyagenity/graph/compiled_graph.py
def stop(self, config: dict[str, Any]) -> dict[str, Any]:
    """Request the current graph execution to stop (sync helper).

    This sets a stop flag in the checkpointer's thread store keyed by thread_id.
    Handlers periodically check this flag and interrupt execution.
    Returns a small status dict.
    """
    return asyncio.run(self.astop(config))
stream
stream(input_data, config=None, response_granularity=ResponseGranularity.LOW)

Execute the graph synchronously with streaming support.

Yields Message objects containing incremental responses. If nodes return streaming responses, yields them directly. If nodes return complete responses, simulates streaming by chunking.

Parameters:

Name Type Description Default
input_data dict[str, Any]

Input dict

required
config dict[str, Any] | None

Configuration dictionary

None
response_granularity ResponseGranularity

Response parsing granularity

LOW

Yields:

Type Description
Generator[Message]

Message objects with incremental content

Source code in pyagenity/graph/compiled_graph.py
def stream(
    self,
    input_data: dict[str, Any],
    config: dict[str, Any] | None = None,
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
) -> Generator[Message]:
    """Execute the graph synchronously with streaming support.

    Yields Message objects containing incremental responses.
    If nodes return streaming responses, yields them directly.
    If nodes return complete responses, simulates streaming by chunking.

    Args:
        input_data: Input dict
        config: Configuration dictionary
        response_granularity: Response parsing granularity

    Yields:
        Message objects with incremental content
    """

    # For sync streaming, we'll use asyncio.run to handle the async implementation
    async def _async_stream():
        async for chunk in self.astream(input_data, config, response_granularity):
            yield chunk

    # Convert async generator to sync iteration with a dedicated event loop
    gen = _async_stream()
    loop = asyncio.new_event_loop()
    policy = asyncio.get_event_loop_policy()
    try:
        previous_loop = policy.get_event_loop()
    except Exception:
        previous_loop = None
    asyncio.set_event_loop(loop)
    logger.info("Synchronous streaming started")

    try:
        while True:
            try:
                chunk = loop.run_until_complete(gen.__anext__())
                yield chunk
            except StopAsyncIteration:
                break
    finally:
        # Attempt to close the async generator cleanly
        with contextlib.suppress(Exception):
            loop.run_until_complete(gen.aclose())  # type: ignore[attr-defined]
        # Restore previous loop if any, then close created loop
        try:
            if previous_loop is not None:
                asyncio.set_event_loop(previous_loop)
        finally:
            loop.close()
    logger.info("Synchronous streaming completed")
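The dedicated-event-loop bridge used by `stream()` is a general pattern for consuming an async generator from synchronous code. Reduced to a self-contained sketch (the generator and function names are illustrative):

```python
import asyncio
import contextlib
from collections.abc import AsyncIterator, Generator


async def aticks(n: int) -> AsyncIterator[int]:
    for i in range(n):
        await asyncio.sleep(0)
        yield i


def sync_iter(agen: AsyncIterator[int]) -> Generator[int, None, None]:
    # Same approach as CompiledGraph.stream(): drive the async generator
    # one step at a time on a private event loop.
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        # Close the generator cleanly, then tear down the private loop.
        with contextlib.suppress(Exception):
            loop.run_until_complete(agen.aclose())  # type: ignore[attr-defined]
        loop.close()


values = list(sync_iter(aticks(3)))
```

As the `invoke` note warns for `asyncio.run()`, this pattern must not be used from inside a running event loop; in async contexts, consume `astream()` directly.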
edge

Graph edge representation and routing logic for PyAgenity workflows.

This module defines the Edge class, which represents connections between nodes in a PyAgenity graph workflow. Edges can be either static (always followed) or conditional (followed only when certain conditions are met), enabling complex routing logic and decision-making within graph execution.

Edges are fundamental building blocks that define the flow of execution through a graph, determining which node should execute next based on the current state and any conditional logic.

Classes:

Name Description
Edge

Represents a connection between two nodes in a graph workflow.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
Edge

Represents a connection between two nodes in a graph workflow.

An Edge defines the relationship and routing logic between nodes, specifying how execution should flow from one node to another. Edges can be either static (unconditional) or conditional based on runtime state evaluation.

Edges support complex routing scenarios including: - Simple static connections between nodes - Conditional routing based on state evaluation - Dynamic routing with multiple possible destinations - Decision trees and branching logic

Attributes:

Name Type Description
from_node

Name of the source node where execution originates.

to_node

Name of the destination node where execution continues.

condition

Optional callable that determines if this edge should be followed. If None, the edge is always followed (static edge).

condition_result str | None

Optional value to match against condition result for mapped conditional edges.

Example
# Static edge - always followed
static_edge = Edge("start", "process")


# Conditional edge - followed only if condition returns True
def needs_approval(state):
    return state.data.get("requires_approval", False)


conditional_edge = Edge("process", "approval", condition=needs_approval)


# Mapped conditional edge - follows based on specific condition result
def get_priority(state):
    return state.data.get("priority", "normal")


high_priority_edge = Edge("triage", "urgent", condition=get_priority)
high_priority_edge.condition_result = "high"

Methods:

Name Description
__init__

Initialize a new Edge with source, destination, and optional condition.

Source code in pyagenity/graph/edge.py
class Edge:
    """Represents a connection between two nodes in a graph workflow.

    An Edge defines the relationship and routing logic between nodes, specifying
    how execution should flow from one node to another. Edges can be either
    static (unconditional) or conditional based on runtime state evaluation.

    Edges support complex routing scenarios including:
    - Simple static connections between nodes
    - Conditional routing based on state evaluation
    - Dynamic routing with multiple possible destinations
    - Decision trees and branching logic

    Attributes:
        from_node: Name of the source node where execution originates.
        to_node: Name of the destination node where execution continues.
        condition: Optional callable that determines if this edge should be
            followed. If None, the edge is always followed (static edge).
        condition_result: Optional value to match against condition result
            for mapped conditional edges.

    Example:
        ```python
        # Static edge - always followed
        static_edge = Edge("start", "process")


        # Conditional edge - followed only if condition returns True
        def needs_approval(state):
            return state.data.get("requires_approval", False)


        conditional_edge = Edge("process", "approval", condition=needs_approval)


        # Mapped conditional edge - follows based on specific condition result
        def get_priority(state):
            return state.data.get("priority", "normal")


        high_priority_edge = Edge("triage", "urgent", condition=get_priority)
        high_priority_edge.condition_result = "high"
        ```
    """

    def __init__(
        self,
        from_node: str,
        to_node: str,
        condition: Callable | None = None,
    ):
        """Initialize a new Edge with source, destination, and optional condition.

        Args:
            from_node: Name of the source node. Must match a node name in the graph.
            to_node: Name of the destination node. Must match a node name in the graph
                or be a special constant like END.
            condition: Optional callable that takes an AgentState as argument and
                returns a value to determine if this edge should be followed.
                If None, this is a static edge that's always followed.

        Note:
            The condition function should be deterministic and side-effect free
            for predictable execution behavior. It receives the current AgentState
            and should return a boolean (for simple conditions) or a string/value
            (for mapped conditional routing).
        """
        logger.debug(
            "Creating edge from '%s' to '%s' with condition=%s",
            from_node,
            to_node,
            "yes" if condition else "no",
        )
        self.from_node = from_node
        self.to_node = to_node
        self.condition = condition
        self.condition_result: str | None = None
Attributes
condition instance-attribute
condition = condition
condition_result instance-attribute
condition_result = None
from_node instance-attribute
from_node = from_node
to_node instance-attribute
to_node = to_node
Functions
__init__
__init__(from_node, to_node, condition=None)

Initialize a new Edge with source, destination, and optional condition.

Parameters:

Name Type Description Default
from_node str

Name of the source node. Must match a node name in the graph.

required
to_node str

Name of the destination node. Must match a node name in the graph or be a special constant like END.

required
condition Callable | None

Optional callable that takes an AgentState as argument and returns a value to determine if this edge should be followed. If None, this is a static edge that's always followed.

None
Note

The condition function should be deterministic and side-effect free for predictable execution behavior. It receives the current AgentState and should return a boolean (for simple conditions) or a string/value (for mapped conditional routing).

Source code in pyagenity/graph/edge.py
def __init__(
    self,
    from_node: str,
    to_node: str,
    condition: Callable | None = None,
):
    """Initialize a new Edge with source, destination, and optional condition.

    Args:
        from_node: Name of the source node. Must match a node name in the graph.
        to_node: Name of the destination node. Must match a node name in the graph
            or be a special constant like END.
        condition: Optional callable that takes an AgentState as argument and
            returns a value to determine if this edge should be followed.
            If None, this is a static edge that's always followed.

    Note:
        The condition function should be deterministic and side-effect free
        for predictable execution behavior. It receives the current AgentState
        and should return a boolean (for simple conditions) or a string/value
        (for mapped conditional routing).
    """
    logger.debug(
        "Creating edge from '%s' to '%s' with condition=%s",
        from_node,
        to_node,
        "yes" if condition else "no",
    )
    self.from_node = from_node
    self.to_node = to_node
    self.condition = condition
    self.condition_result: str | None = None
node

Node execution and management for PyAgenity graph workflows.

This module defines the Node class, which represents executable units within a PyAgenity graph workflow. Nodes encapsulate functions or ToolNode instances that perform specific tasks, handle dependency injection, manage execution context, and support both synchronous and streaming execution modes.

Nodes are the fundamental building blocks of graph workflows, responsible for processing state, executing business logic, and producing outputs that drive the workflow forward. They integrate seamlessly with PyAgenity's dependency injection system and callback management framework.
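The two execution modes can be illustrated with a framework-free sketch (plain Python, not PyAgenity's actual handler machinery): a synchronous node body returns a single final result, while a streaming body is an async generator yielding incremental results.

```python
import asyncio

def sync_node(state, config):
    # Synchronous mode: produce one final result.
    return {"result": state["value"] * 2}

async def streaming_node(state, config):
    # Streaming mode: yield incremental outputs as they become available.
    for chunk in ("partial", "final"):
        yield {"chunk": chunk}

async def demo():
    streamed = [item async for item in streaming_node({}, {})]
    return sync_node({"value": 21}, {}), streamed

final, streamed = asyncio.run(demo())
```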

Classes:

Name Description
Node

Represents a node in the graph workflow.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
Node

Represents a node in the graph workflow.

A Node encapsulates a function or ToolNode that can be executed as part of a graph workflow. It handles dependency injection, parameter mapping, and execution context management.

The Node class supports both regular callable functions and ToolNode instances for handling tool-based operations. It automatically injects dependencies based on function signatures and provides legacy parameter support.

Attributes:

Name Type Description
name str

Unique identifier for the node within the graph.

func Union[Callable, ToolNode]

The function or ToolNode to execute.

Example

>>> def my_function(state, config):
...     return {"result": "processed"}
>>> node = Node("processor", my_function)
>>> result = await node.execute(config, state)

Methods:

Name Description
__init__

Initialize a new Node instance with function and dependencies.

execute

Execute the node function with comprehensive context and callback support.

stream

Execute the node function with streaming output support.

Source code in pyagenity/graph/node.py
class Node:
    """Represents a node in the graph workflow.

    A Node encapsulates a function or ToolNode that can be executed as part of
    a graph workflow. It handles dependency injection, parameter mapping, and
    execution context management.

    The Node class supports both regular callable functions and ToolNode instances
    for handling tool-based operations. It automatically injects dependencies
    based on function signatures and provides legacy parameter support.

    Attributes:
        name (str): Unique identifier for the node within the graph.
        func (Union[Callable, ToolNode]): The function or ToolNode to execute.

    Example:
        >>> def my_function(state, config):
        ...     return {"result": "processed"}
        >>> node = Node("processor", my_function)
        >>> result = await node.execute(config, state)
    """

    def __init__(
        self,
        name: str,
        func: Union[Callable, "ToolNode"],
        publisher: BasePublisher | None = Inject[BasePublisher],
    ):
        """Initialize a new Node instance with function and dependencies.

        Args:
            name: Unique identifier for the node within the graph. This name
                is used for routing, logging, and referencing the node in
                graph configuration.
            func: The function or ToolNode to execute when this node is called.
                Functions should accept at least 'state' and 'config' parameters.
                ToolNode instances handle tool-based operations and provide
                their own execution logic.
            publisher: Optional event publisher for execution monitoring.
                Injected via dependency injection if not explicitly provided.
                Used for publishing node execution events and status updates.

        Note:
            The function signature is automatically analyzed to determine
            required parameters and dependency injection points. Parameters
            matching injectable service names will be automatically provided
            by the framework during execution.
        """
        logger.debug(
            "Initializing node '%s' with func=%s",
            name,
            getattr(func, "__name__", type(func).__name__),
        )
        self.name = name
        self.func = func
        self.publisher = publisher
        self.invoke_handler = InvokeNodeHandler(
            name,
            func,
        )

        self.stream_handler = StreamNodeHandler(
            name,
            func,
        )

    async def execute(
        self,
        config: dict[str, Any],
        state: AgentState,
        callback_mgr: CallbackManager = Inject[CallbackManager],
    ) -> dict[str, Any] | list[Message]:
        """Execute the node function with comprehensive context and callback support.

        Executes the node's function or ToolNode with full dependency injection,
        callback hook execution, and error handling. This method provides the
        complete execution environment including state access, configuration,
        and injected services.

        Args:
            config: Configuration dictionary containing execution context,
                user settings, thread identification, and runtime parameters.
            state: Current AgentState providing workflow context, message history,
                and shared state information accessible to the node function.
            callback_mgr: Callback manager for executing pre/post execution hooks.
                Injected via dependency injection if not explicitly provided.

        Returns:
            Either a dictionary containing updated state and execution results,
            or a list of Message objects representing the node's output.
            The return type depends on the node function's implementation.

        Raises:
            Various exceptions depending on node function behavior. All exceptions
            are handled by the callback manager's error handling hooks before
            being propagated.

        Example:
            ```python
            # Node function that returns messages
            def process_data(state, config):
                result = process(state.data)
                return [Message.text_message(f"Processed: {result}")]


            node = Node("processor", process_data)
            messages = await node.execute(config, state)
            ```

        Note:
            The node function receives dependency-injected parameters based on
            its signature. Common injectable parameters include 'state', 'config',
            'context_manager', 'publisher', and other framework services.
        """
        return await self.invoke_handler.invoke(
            config,
            state,
            callback_mgr,
        )

    async def stream(
        self,
        config: dict[str, Any],
        state: AgentState,
        callback_mgr: CallbackManager = Inject[CallbackManager],
    ) -> AsyncIterable[dict[str, Any] | Message]:
        """Execute the node function with streaming output support.

        Similar to execute() but designed for streaming scenarios where the node
        function can produce incremental results. This method provides an async
        iterator interface over the node's outputs, allowing for real-time
        processing and response streaming.

        Args:
            config: Configuration dictionary with execution context and settings.
            state: Current AgentState providing workflow context and shared state.
            callback_mgr: Callback manager for pre/post execution hook handling.

        Yields:
            Dictionary objects or Message instances representing incremental
            outputs from the node function. The exact type and frequency of
            yields depends on the node function's streaming implementation.

        Example:
            ```python
            async def streaming_processor(state, config):
                for item in large_dataset:
                    result = process_item(item)
                    yield Message.text_message(f"Processed item: {result}")


            node = Node("stream_processor", streaming_processor)
            async for output in node.stream(config, state):
                print(f"Streamed: {output.content}")
            ```

        Note:
            Not all node functions support streaming. For non-streaming functions,
            this method will yield a single result equivalent to calling execute().
            The streaming capability is determined by the node function's implementation.
        """
        result = self.stream_handler.stream(
            config,
            state,
            callback_mgr,
        )

        async for item in result:
            yield item
Attributes
func instance-attribute
func = func
invoke_handler instance-attribute
invoke_handler = InvokeNodeHandler(name, func)
name instance-attribute
name = name
publisher instance-attribute
publisher = publisher
stream_handler instance-attribute
stream_handler = StreamNodeHandler(name, func)
Functions
__init__
__init__(name, func, publisher=Inject[BasePublisher])

Initialize a new Node instance with function and dependencies.

Parameters:

Name Type Description Default
name str

Unique identifier for the node within the graph. This name is used for routing, logging, and referencing the node in graph configuration.

required
func Union[Callable, ToolNode]

The function or ToolNode to execute when this node is called. Functions should accept at least 'state' and 'config' parameters. ToolNode instances handle tool-based operations and provide their own execution logic.

required
publisher BasePublisher | None

Optional event publisher for execution monitoring. Injected via dependency injection if not explicitly provided. Used for publishing node execution events and status updates.

Inject[BasePublisher]
Note

The function signature is automatically analyzed to determine required parameters and dependency injection points. Parameters matching injectable service names will be automatically provided by the framework during execution.
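The mechanism can be sketched in plain Python with `inspect` (this mirrors the idea only, not PyAgenity's actual `Inject` implementation; the service registry and its contents are hypothetical):

```python
import inspect

# Hypothetical service registry; a real framework binds actual instances.
SERVICES = {"publisher": "console-publisher"}

def call_with_injection(func, state, config):
    # Inspect the node function's signature and supply any parameter
    # whose name matches a registered service.
    params = inspect.signature(func).parameters
    extras = {name: SERVICES[name] for name in params if name in SERVICES}
    return func(state=state, config=config, **extras)

def my_node(state, config, publisher):
    return {"published_via": publisher}

result = call_with_injection(my_node, state={}, config={})
```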

Source code in pyagenity/graph/node.py
def __init__(
    self,
    name: str,
    func: Union[Callable, "ToolNode"],
    publisher: BasePublisher | None = Inject[BasePublisher],
):
    """Initialize a new Node instance with function and dependencies.

    Args:
        name: Unique identifier for the node within the graph. This name
            is used for routing, logging, and referencing the node in
            graph configuration.
        func: The function or ToolNode to execute when this node is called.
            Functions should accept at least 'state' and 'config' parameters.
            ToolNode instances handle tool-based operations and provide
            their own execution logic.
        publisher: Optional event publisher for execution monitoring.
            Injected via dependency injection if not explicitly provided.
            Used for publishing node execution events and status updates.

    Note:
        The function signature is automatically analyzed to determine
        required parameters and dependency injection points. Parameters
        matching injectable service names will be automatically provided
        by the framework during execution.
    """
    logger.debug(
        "Initializing node '%s' with func=%s",
        name,
        getattr(func, "__name__", type(func).__name__),
    )
    self.name = name
    self.func = func
    self.publisher = publisher
    self.invoke_handler = InvokeNodeHandler(
        name,
        func,
    )

    self.stream_handler = StreamNodeHandler(
        name,
        func,
    )
execute async
execute(config, state, callback_mgr=Inject[CallbackManager])

Execute the node function with comprehensive context and callback support.

Executes the node's function or ToolNode with full dependency injection, callback hook execution, and error handling. This method provides the complete execution environment including state access, configuration, and injected services.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dictionary containing execution context, user settings, thread identification, and runtime parameters.

required
state AgentState

Current AgentState providing workflow context, message history, and shared state information accessible to the node function.

required
callback_mgr CallbackManager

Callback manager for executing pre/post execution hooks. Injected via dependency injection if not explicitly provided.

Inject[CallbackManager]

Returns:

Type Description
dict[str, Any] | list[Message]

Either a dictionary containing updated state and execution results, or a list of Message objects representing the node's output. The return type depends on the node function's implementation.

Example
# Node function that returns messages
def process_data(state, config):
    result = process(state.data)
    return [Message.text_message(f"Processed: {result}")]


node = Node("processor", process_data)
messages = await node.execute(config, state)
Note

The node function receives dependency-injected parameters based on its signature. Common injectable parameters include 'state', 'config', 'context_manager', 'publisher', and other framework services.

Source code in pyagenity/graph/node.py
async def execute(
    self,
    config: dict[str, Any],
    state: AgentState,
    callback_mgr: CallbackManager = Inject[CallbackManager],
) -> dict[str, Any] | list[Message]:
    """Execute the node function with comprehensive context and callback support.

    Executes the node's function or ToolNode with full dependency injection,
    callback hook execution, and error handling. This method provides the
    complete execution environment including state access, configuration,
    and injected services.

    Args:
        config: Configuration dictionary containing execution context,
            user settings, thread identification, and runtime parameters.
        state: Current AgentState providing workflow context, message history,
            and shared state information accessible to the node function.
        callback_mgr: Callback manager for executing pre/post execution hooks.
            Injected via dependency injection if not explicitly provided.

    Returns:
        Either a dictionary containing updated state and execution results,
        or a list of Message objects representing the node's output.
        The return type depends on the node function's implementation.

    Raises:
        Various exceptions depending on node function behavior. All exceptions
        are handled by the callback manager's error handling hooks before
        being propagated.

    Example:
        ```python
        # Node function that returns messages
        def process_data(state, config):
            result = process(state.data)
            return [Message.text_message(f"Processed: {result}")]


        node = Node("processor", process_data)
        messages = await node.execute(config, state)
        ```

    Note:
        The node function receives dependency-injected parameters based on
        its signature. Common injectable parameters include 'state', 'config',
        'context_manager', 'publisher', and other framework services.
    """
    return await self.invoke_handler.invoke(
        config,
        state,
        callback_mgr,
    )
stream async
stream(config, state, callback_mgr=Inject[CallbackManager])

Execute the node function with streaming output support.

Similar to execute() but designed for streaming scenarios where the node function can produce incremental results. This method provides an async iterator interface over the node's outputs, allowing for real-time processing and response streaming.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dictionary with execution context and settings.

required
state AgentState

Current AgentState providing workflow context and shared state.

required
callback_mgr CallbackManager

Callback manager for pre/post execution hook handling.

Inject[CallbackManager]

Yields:

Type Description
AsyncIterable[dict[str, Any] | Message]

Dictionary objects or Message instances representing incremental outputs from the node function. The exact type and frequency of yields depends on the node function's streaming implementation.

Example
async def streaming_processor(state, config):
    for item in large_dataset:
        result = process_item(item)
        yield Message.text_message(f"Processed item: {result}")


node = Node("stream_processor", streaming_processor)
async for output in node.stream(config, state):
    print(f"Streamed: {output.content}")
Note

Not all node functions support streaming. For non-streaming functions, this method will yield a single result equivalent to calling execute(). The streaming capability is determined by the node function's implementation.

Source code in pyagenity/graph/node.py
async def stream(
    self,
    config: dict[str, Any],
    state: AgentState,
    callback_mgr: CallbackManager = Inject[CallbackManager],
) -> AsyncIterable[dict[str, Any] | Message]:
    """Execute the node function with streaming output support.

    Similar to execute() but designed for streaming scenarios where the node
    function can produce incremental results. This method provides an async
    iterator interface over the node's outputs, allowing for real-time
    processing and response streaming.

    Args:
        config: Configuration dictionary with execution context and settings.
        state: Current AgentState providing workflow context and shared state.
        callback_mgr: Callback manager for pre/post execution hook handling.

    Yields:
        Dictionary objects or Message instances representing incremental
        outputs from the node function. The exact type and frequency of
        yields depends on the node function's streaming implementation.

    Example:
        ```python
        async def streaming_processor(state, config):
            for item in large_dataset:
                result = process_item(item)
                yield Message.text_message(f"Processed item: {result}")


        node = Node("stream_processor", streaming_processor)
        async for output in node.stream(config, state):
            print(f"Streamed: {output.content}")
        ```

    Note:
        Not all node functions support streaming. For non-streaming functions,
        this method will yield a single result equivalent to calling execute().
        The streaming capability is determined by the node function's implementation.
    """
    result = self.stream_handler.stream(
        config,
        state,
        callback_mgr,
    )

    async for item in result:
        yield item
state_graph

Classes:

Name Description
StateGraph

Main graph class for orchestrating multi-agent workflows.

Attributes:

Name Type Description
StateT
logger
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
logger module-attribute
logger = getLogger(__name__)
Classes
StateGraph

Main graph class for orchestrating multi-agent workflows.

This class provides the core functionality for building and managing stateful agent workflows. It is similar to LangGraph's StateGraph, with added support for dependency injection.

The graph is generic over state types to support custom AgentState subclasses, allowing for type-safe state management throughout the workflow execution.

Attributes:

Name Type Description
state StateT

The current state of the graph workflow.

nodes dict[str, Node]

Collection of nodes in the graph.

edges list[Edge]

Collection of edges connecting nodes.

entry_point str | None

Name of the starting node for execution.

context_manager BaseContextManager[StateT] | None

Optional context manager for handling cross-node state operations.

dependency_container DependencyContainer

Container for managing dependencies that can be injected into node functions.

compiled bool

Whether the graph has been compiled for execution.

Example

>>> graph = StateGraph()
>>> graph.add_node("process", process_function)
>>> graph.add_edge(START, "process")
>>> graph.add_edge("process", END)
>>> compiled = graph.compile()
>>> result = compiled.invoke({"input": "data"})

Methods:

Name Description
__init__

Initialize a new StateGraph instance.

add_conditional_edges

Add conditional routing between nodes based on runtime evaluation.

add_edge

Add a static edge between two nodes.

add_node

Add a node to the graph.

compile

Compile the graph for execution.

set_entry_point

Set the entry point for the graph.

Source code in pyagenity/graph/state_graph.py
class StateGraph[StateT: AgentState]:
    """Main graph class for orchestrating multi-agent workflows.

    This class provides the core functionality for building and managing stateful
    agent workflows. It is similar to LangGraph's StateGraph, with added
    support for dependency injection.

    The graph is generic over state types to support custom AgentState subclasses,
    allowing for type-safe state management throughout the workflow execution.

    Attributes:
        state (StateT): The current state of the graph workflow.
        nodes (dict[str, Node]): Collection of nodes in the graph.
        edges (list[Edge]): Collection of edges connecting nodes.
        entry_point (str | None): Name of the starting node for execution.
        context_manager (BaseContextManager[StateT] | None): Optional context manager
            for handling cross-node state operations.
        dependency_container (DependencyContainer): Container for managing
            dependencies that can be injected into node functions.
        compiled (bool): Whether the graph has been compiled for execution.

    Example:
        >>> graph = StateGraph()
        >>> graph.add_node("process", process_function)
        >>> graph.add_edge(START, "process")
        >>> graph.add_edge("process", END)
        >>> compiled = graph.compile()
        >>> result = compiled.invoke({"input": "data"})
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
        thread_name_generator: Callable[[], str] | None = None,
    ):
        """Initialize a new StateGraph instance.

        Args:
            state: Initial state for the graph. If None, a default AgentState
                will be created.
            context_manager: Optional context manager for handling cross-node
                state operations and advanced state management patterns.
            publisher: Publisher for emitting events during execution.
            id_generator: Generator for message, thread, and run IDs.
                Defaults to DefaultIDGenerator.
            container: Dependency injection container. If None, the global
                InjectQ singleton instance is used.
            thread_name_generator: Optional callable that generates thread
                names. If None, a default generator is used.

        Note:
            START and END nodes are automatically added to the graph upon
            initialization and accept the full node signature including
            dependencies.

        Example:
            # Basic usage with default AgentState
            >>> graph = StateGraph()

            # With custom state
            >>> custom_state = MyCustomState()
            >>> graph = StateGraph(custom_state)

            # Or using type hints for clarity
            >>> graph = StateGraph[MyCustomState](MyCustomState())
        """
        logger.info("Initializing StateGraph")
        logger.debug(
            "StateGraph init with state=%s, context_manager=%s",
            type(state).__name__ if state else "default AgentState",
            type(context_manager).__name__ if context_manager else None,
        )

        # State handling
        self._state: StateT = state if state else AgentState()  # type: ignore[assignment]

        # Graph structure
        self.nodes: dict[str, Node] = {}
        self.edges: list[Edge] = []
        self.entry_point: str | None = None

        # Services
        self._publisher: BasePublisher | None = publisher
        self._id_generator: BaseIDGenerator = id_generator
        self._context_manager: BaseContextManager[StateT] | None = context_manager
        self.thread_name_generator = thread_name_generator
        # save container for dependency injection
        # if any container is passed then we will activate that
        # otherwise we can skip it and use the default one
        if container is None:
            self._container = InjectQ.get_instance()
            logger.debug("No container provided, using global singleton instance")
        else:
            logger.debug("Using provided dependency container instance")
            self._container = container
            self._container.activate()

        # Register task_manager, for async tasks
        # This will be used to run background tasks
        self._task_manager = BackgroundTaskManager()

        # now setup the graph
        self._setup()

        # Add START and END nodes (accept full node signature including dependencies)
        logger.debug("Adding default START and END nodes")
        self.nodes[START] = Node(START, lambda state, config, **deps: state, self._publisher)  # type: ignore
        self.nodes[END] = Node(END, lambda state, config, **deps: state, self._publisher)
        logger.debug("StateGraph initialized with %d nodes", len(self.nodes))

    def _setup(self):
        """Setup the graph before compilation.

        This method can be used to perform any necessary setup or validation
        before compiling the graph for execution.
        """
        logger.info("Setting up StateGraph before compilation")
        # Placeholder for any setup logic needed before compilation
        # register dependencies

        # register state and context manager as singletons (these are nullable)
        self._container.bind_instance(
            BaseContextManager,
            self._context_manager,
            allow_none=True,
            allow_concrete=True,
        )
        self._container.bind_instance(
            BasePublisher,
            self._publisher,
            allow_none=True,
            allow_concrete=True,
        )

        # register id generator as factory
        self._container.bind_instance(
            BaseIDGenerator,
            self._id_generator,
            allow_concrete=True,
        )
        self._container.bind("generated_id_type", self._id_generator.id_type)
        # Allow async method also
        self._container.bind_factory(
            "generated_id",
            lambda: self._id_generator.generate(),
        )

        # Attach Thread name generator if provided
        if self.thread_name_generator is None:
            self.thread_name_generator = generate_dummy_thread_name

        generator = self.thread_name_generator or generate_dummy_thread_name

        self._container.bind_factory(
            "generated_thread_name",
            lambda: generator(),
        )

        # Save BackgroundTaskManager
        self._container.bind_instance(
            BackgroundTaskManager,
            self._task_manager,
            allow_concrete=False,
        )

    def add_node(
        self,
        name_or_func: str | Callable,
        func: Union[Callable, "ToolNode", None] = None,
    ) -> "StateGraph":
        """Add a node to the graph.

        This method supports two calling patterns:
        1. Pass a callable as the first argument (name inferred from function name)
        2. Pass a name string and callable/ToolNode as separate arguments

        Args:
            name_or_func: Either the node name (str) or a callable function.
                If callable, the function name will be used as the node name.
            func: The function or ToolNode to execute. Required if name_or_func
                is a string, ignored if name_or_func is callable.

        Returns:
            StateGraph: The graph instance for method chaining.

        Raises:
            ValueError: If invalid arguments are provided.

        Example:
            >>> # Method 1: Function name inferred
            >>> graph.add_node(my_function)
            >>> # Method 2: Explicit name and function
            >>> graph.add_node("process", my_function)
        """
        if callable(name_or_func) and func is None:
            # Function passed as first argument
            name = name_or_func.__name__
            func = name_or_func
            logger.debug("Adding node '%s' with inferred name from function", name)
        elif isinstance(name_or_func, str) and (callable(func) or isinstance(func, ToolNode)):
            # Name and function passed separately
            name = name_or_func
            logger.debug(
                "Adding node '%s' with explicit name and %s",
                name,
                "ToolNode" if isinstance(func, ToolNode) else "callable",
            )
        else:
            error_msg = "Invalid arguments for add_node"
            logger.error(error_msg)
            raise ValueError(error_msg)

        self.nodes[name] = Node(name, func)
        logger.info("Added node '%s' to graph (total nodes: %d)", name, len(self.nodes))
        return self

    def add_edge(
        self,
        from_node: str,
        to_node: str,
    ) -> "StateGraph":
        """Add a static edge between two nodes.

        Creates a direct connection from one node to another. If the source
        node is START, the target node becomes the entry point for the graph.

        Args:
            from_node: Name of the source node.
            to_node: Name of the target node.

        Returns:
            StateGraph: The graph instance for method chaining.

        Example:
            >>> graph.add_edge("node1", "node2")
            >>> graph.add_edge(START, "entry_node")  # Sets entry point
        """
        logger.debug("Adding edge from '%s' to '%s'", from_node, to_node)
        # Set entry point if edge is from START
        if from_node == START:
            self.entry_point = to_node
            logger.info("Set entry point to '%s'", to_node)
        self.edges.append(Edge(from_node, to_node))
        logger.debug("Added edge (total edges: %d)", len(self.edges))
        return self

    def add_conditional_edges(
        self,
        from_node: str,
        condition: Callable,
        path_map: dict[str, str] | None = None,
    ) -> "StateGraph":
        """Add conditional routing between nodes based on runtime evaluation.

        Creates dynamic routing logic where the next node is determined by evaluating
        a condition function against the current state. This enables complex branching
        logic, decision trees, and adaptive workflow routing.

        Args:
            from_node: Name of the source node where the condition is evaluated.
            condition: Callable function that takes the current AgentState and returns
                a value used for routing decisions. Should be deterministic and
                side-effect free.
            path_map: Optional dictionary mapping condition results to destination nodes.
                If provided, the condition's return value is looked up in this mapping.
                If None, the condition should return the destination node name directly.

        Returns:
            StateGraph: The graph instance for method chaining.

        Raises:
            ValueError: If the condition function or path_map configuration is invalid.

        Example:
            ```python
            # Direct routing - condition returns node name
            def route_by_priority(state):
                priority = state.data.get("priority", "normal")
                return "urgent_handler" if priority == "high" else "normal_handler"


            graph.add_conditional_edges("classifier", route_by_priority)


            # Mapped routing - condition result mapped to nodes
            def get_category(state):
                return state.data.get("category", "default")


            category_map = {
                "finance": "finance_processor",
                "legal": "legal_processor",
                "default": "general_processor",
            }
            graph.add_conditional_edges("categorizer", get_category, category_map)
            ```

        Note:
            The condition function receives the current AgentState and should return
            consistent results for the same state. If using path_map, ensure the
            condition's return values match the map keys exactly.
        """
        """Add conditional edges from a node based on a condition function.

        Creates edges that are traversed based on the result of a condition
        function. The condition function receives the current state and should
        return a value that determines which edge to follow.

        Args:
            from_node: Name of the source node.
            condition: Function that evaluates the current state and returns
                a value to determine the next node.
            path_map: Optional mapping from condition results to target nodes.
                If provided, creates multiple conditional edges. If None,
                creates a single conditional edge.

        Returns:
            StateGraph: The graph instance for method chaining.

        Example:
            >>> def route_condition(state):
            ...     return "success" if state.success else "failure"
            >>> graph.add_conditional_edges(
            ...     "processor",
            ...     route_condition,
            ...     {"success": "next_step", "failure": "error_handler"},
            ... )
        """
        # Create edges based on possible returns from condition function
        logger.debug(
            "Node '%s' adding conditional edges with path_map: %s",
            from_node,
            path_map,
        )
        if path_map:
            for condition_result, target_node in path_map.items():
                edge = Edge(from_node, target_node, condition)
                edge.condition_result = condition_result
                self.edges.append(edge)
        else:
            # Single conditional edge
            logger.debug("Node '%s' adding single conditional edge", from_node)
            self.edges.append(Edge(from_node, "", condition))
        return self

    def set_entry_point(self, node_name: str) -> "StateGraph":
        """Set the entry point for the graph."""
        self.entry_point = node_name
        self.add_edge(START, node_name)
        logger.info("Set entry point to '%s'", node_name)
        return self

    def compile(
        self,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> "CompiledGraph[StateT]":
        """Compile the graph for execution.

        Args:
            checkpointer: Checkpointer for state persistence. Defaults to an
                InMemoryCheckpointer when omitted.
            store: Store for additional data
            interrupt_before: List of node names to interrupt before execution
            interrupt_after: List of node names to interrupt after execution
            callback_manager: Callback manager for executing hooks

        Returns:
            CompiledGraph: The executable graph instance.

        Raises:
            GraphError: If no entry point is set or graph validation fails.
        """
        logger.info(
            "Compiling graph with %d nodes, %d edges, entry_point='%s'",
            len(self.nodes),
            len(self.edges),
            self.entry_point,
        )
        logger.debug(
            "Compile options: interrupt_before=%s, interrupt_after=%s",
            interrupt_before,
            interrupt_after,
        )

        if not self.entry_point:
            error_msg = "No entry point set. Use set_entry_point() or add an edge from START."
            logger.error(error_msg)
            raise GraphError(error_msg)

        # Validate graph structure
        logger.debug("Validating graph structure")
        self._validate_graph()
        logger.debug("Graph structure validated successfully")

        # Validate interrupt node names
        interrupt_before = interrupt_before or []
        interrupt_after = interrupt_after or []

        all_interrupt_nodes = set(interrupt_before + interrupt_after)
        invalid_nodes = all_interrupt_nodes - set(self.nodes.keys())
        if invalid_nodes:
            error_msg = f"Invalid interrupt nodes: {invalid_nodes}. Must be existing node names."
            logger.error(error_msg)
            raise GraphError(error_msg)

        self.compiled = True
        logger.info("Graph compilation completed successfully")
        # Import here to avoid circular import at module import time
        # Now update Checkpointer
        if checkpointer is None:
            from pyagenity.checkpointer import InMemoryCheckpointer

            checkpointer = InMemoryCheckpointer[StateT]()
            logger.debug("No checkpointer provided, using InMemoryCheckpointer")

        # Import the CompiledGraph class
        from .compiled_graph import CompiledGraph

        # Setup dependencies
        self._container.bind_instance(
            BaseCheckpointer,
            checkpointer,
            allow_concrete=True,
        )  # not null as we set default
        self._container.bind_instance(
            BaseStore,
            store,
            allow_none=True,
            allow_concrete=True,
        )
        self._container.bind_instance(
            CallbackManager,
            callback_manager,
            allow_concrete=True,
        )  # not null as we set default
        self._container.bind("interrupt_before", interrupt_before)
        self._container.bind("interrupt_after", interrupt_after)
        self._container.bind_instance(StateGraph, self)

        app = CompiledGraph(
            state=self._state,
            interrupt_after=interrupt_after,
            interrupt_before=interrupt_before,
            state_graph=self,
            checkpointer=checkpointer,
            publisher=self._publisher,
            store=store,
            task_manager=self._task_manager,
        )

        self._container.bind(CompiledGraph, app)
        # Compile the Graph, so it will optimize the dependency graph
        self._container.compile()
        return app

    def _validate_graph(self):
        """Validate the graph structure."""
        # Check for orphaned nodes
        connected_nodes = set()
        for edge in self.edges:
            connected_nodes.add(edge.from_node)
            connected_nodes.add(edge.to_node)

        all_nodes = set(self.nodes.keys())
        orphaned = all_nodes - connected_nodes
        if orphaned - {START, END}:  # START and END can be orphaned
            logger.error("Orphaned nodes detected: %s", orphaned - {START, END})
            raise GraphError(f"Orphaned nodes detected: {orphaned - {START, END}}")

        # Check that all edge targets exist
        for edge in self.edges:
            if edge.to_node and edge.to_node not in self.nodes:
                logger.error("Edge '%s' targets non-existent node: %s", edge, edge.to_node)
                raise GraphError(f"Edge targets non-existent node: {edge.to_node}")
Attributes
edges instance-attribute
edges = []
entry_point instance-attribute
entry_point = None
nodes instance-attribute
nodes = {}
thread_name_generator instance-attribute
thread_name_generator = thread_name_generator
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None, thread_name_generator=None)

Initialize a new StateGraph instance.

Parameters:

state (StateT | None): Initial state for the graph. If None, a default AgentState will be created. Default: None.
context_manager (BaseContextManager[StateT] | None): Optional context manager for handling cross-node state operations and advanced state management patterns. Default: None.
publisher (BasePublisher | None): Publisher for emitting events during execution. Default: None.
id_generator (BaseIDGenerator): Generator for IDs used by the graph. Default: DefaultIDGenerator().
container (InjectQ | None): Dependency injection container for values injected into node functions. If None, the global InjectQ singleton is used. Default: None.
thread_name_generator (Callable[[], str] | None): Optional callable that generates thread names. Default: None.
Note

START and END nodes are automatically added to the graph upon initialization and accept the full node signature including dependencies.

Example
# Basic usage with default AgentState
graph = StateGraph()

# With custom state
custom_state = MyCustomState()
graph = StateGraph(custom_state)

# Or using type hints for clarity
graph = StateGraph[MyCustomState](MyCustomState())

Source code in pyagenity/graph/state_graph.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
    thread_name_generator: Callable[[], str] | None = None,
):
    """Initialize a new StateGraph instance.

    Args:
        state: Initial state for the graph. If None, a default AgentState
            will be created.
        context_manager: Optional context manager for handling cross-node
            state operations and advanced state management patterns.
        container: Dependency injection container for values injected into
            node functions. If None, the global InjectQ singleton is used.
        publisher: Publisher for emitting events during execution.
        id_generator: Generator for IDs used by the graph. Defaults to
            DefaultIDGenerator().
        thread_name_generator: Optional callable that generates thread names.

    Note:
        START and END nodes are automatically added to the graph upon
        initialization and accept the full node signature including
        dependencies.

    Example:
        # Basic usage with default AgentState
        >>> graph = StateGraph()

        # With custom state
        >>> custom_state = MyCustomState()
        >>> graph = StateGraph(custom_state)

        # Or using type hints for clarity
        >>> graph = StateGraph[MyCustomState](MyCustomState())
    """
    logger.info("Initializing StateGraph")
    logger.debug(
        "StateGraph init with state=%s, context_manager=%s",
        type(state).__name__ if state else "default AgentState",
        type(context_manager).__name__ if context_manager else None,
    )

    # State handling
    self._state: StateT = state if state else AgentState()  # type: ignore[assignment]

    # Graph structure
    self.nodes: dict[str, Node] = {}
    self.edges: list[Edge] = []
    self.entry_point: str | None = None

    # Services
    self._publisher: BasePublisher | None = publisher
    self._id_generator: BaseIDGenerator = id_generator
    self._context_manager: BaseContextManager[StateT] | None = context_manager
    self.thread_name_generator = thread_name_generator
    # save container for dependency injection
    # if any container is passed then we will activate that
    # otherwise we can skip it and use the default one
    if container is None:
        self._container = InjectQ.get_instance()
        logger.debug("No container provided, using global singleton instance")
    else:
        logger.debug("Using provided dependency container instance")
        self._container = container
        self._container.activate()

    # Register task_manager, for async tasks
    # This will be used to run background tasks
    self._task_manager = BackgroundTaskManager()

    # now setup the graph
    self._setup()

    # Add START and END nodes (accept full node signature including dependencies)
    logger.debug("Adding default START and END nodes")
    self.nodes[START] = Node(START, lambda state, config, **deps: state, self._publisher)  # type: ignore
    self.nodes[END] = Node(END, lambda state, config, **deps: state, self._publisher)
    logger.debug("StateGraph initialized with %d nodes", len(self.nodes))
add_conditional_edges
add_conditional_edges(from_node, condition, path_map=None)

Add conditional routing between nodes based on runtime evaluation.

Creates dynamic routing logic where the next node is determined by evaluating a condition function against the current state. This enables complex branching logic, decision trees, and adaptive workflow routing.

Parameters:

from_node (str): Name of the source node where the condition is evaluated. Required.
condition (Callable): Callable function that takes the current AgentState and returns a value used for routing decisions. Should be deterministic and side-effect free. Required.
path_map (dict[str, str] | None): Optional dictionary mapping condition results to destination nodes. If provided, the condition's return value is looked up in this mapping. If None, the condition should return the destination node name directly. Default: None.

Returns:

StateGraph: The graph instance for method chaining.

Raises:

ValueError: If the condition function or path_map configuration is invalid.

Example
# Direct routing - condition returns node name
def route_by_priority(state):
    priority = state.data.get("priority", "normal")
    return "urgent_handler" if priority == "high" else "normal_handler"


graph.add_conditional_edges("classifier", route_by_priority)


# Mapped routing - condition result mapped to nodes
def get_category(state):
    return state.data.get("category", "default")


category_map = {
    "finance": "finance_processor",
    "legal": "legal_processor",
    "default": "general_processor",
}
graph.add_conditional_edges("categorizer", get_category, category_map)
Note

The condition function receives the current AgentState and should return consistent results for the same state. If using path_map, ensure the condition's return values match the map keys exactly.
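The path_map lookup described above reduces to a dictionary access on the condition's return value. A standalone sketch of that dispatch (plain Python, no PyAgenity imports; names are illustrative only):

```python
# Sketch of how a path_map resolves a condition result to the next node.
# PyAgenity performs an equivalent lookup internally when routing.
def get_category(state: dict) -> str:
    return state.get("category", "default")


category_map = {
    "finance": "finance_processor",
    "legal": "legal_processor",
    "default": "general_processor",
}


def resolve_next(state: dict) -> str:
    result = get_category(state)
    # Condition results must match path_map keys exactly,
    # otherwise the lookup fails.
    return category_map[result]


print(resolve_next({"category": "legal"}))  # legal_processor
print(resolve_next({}))                     # general_processor
```

This is why the condition should be deterministic: the same state must always resolve to the same key.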

Source code in pyagenity/graph/state_graph.py
def add_conditional_edges(
    self,
    from_node: str,
    condition: Callable,
    path_map: dict[str, str] | None = None,
) -> "StateGraph":
    """Add conditional routing between nodes based on runtime evaluation.

    Creates dynamic routing logic where the next node is determined by evaluating
    a condition function against the current state. This enables complex branching
    logic, decision trees, and adaptive workflow routing.

    Args:
        from_node: Name of the source node where the condition is evaluated.
        condition: Callable function that takes the current AgentState and returns
            a value used for routing decisions. Should be deterministic and
            side-effect free.
        path_map: Optional dictionary mapping condition results to destination nodes.
            If provided, the condition's return value is looked up in this mapping.
            If None, the condition should return the destination node name directly.

    Returns:
        StateGraph: The graph instance for method chaining.

    Raises:
        ValueError: If the condition function or path_map configuration is invalid.

    Example:
        ```python
        # Direct routing - condition returns node name
        def route_by_priority(state):
            priority = state.data.get("priority", "normal")
            return "urgent_handler" if priority == "high" else "normal_handler"


        graph.add_conditional_edges("classifier", route_by_priority)


        # Mapped routing - condition result mapped to nodes
        def get_category(state):
            return state.data.get("category", "default")


        category_map = {
            "finance": "finance_processor",
            "legal": "legal_processor",
            "default": "general_processor",
        }
        graph.add_conditional_edges("categorizer", get_category, category_map)
        ```

    Note:
        The condition function receives the current AgentState and should return
        consistent results for the same state. If using path_map, ensure the
        condition's return values match the map keys exactly.
    """
    """Add conditional edges from a node based on a condition function.

    Creates edges that are traversed based on the result of a condition
    function. The condition function receives the current state and should
    return a value that determines which edge to follow.

    Args:
        from_node: Name of the source node.
        condition: Function that evaluates the current state and returns
            a value to determine the next node.
        path_map: Optional mapping from condition results to target nodes.
            If provided, creates multiple conditional edges. If None,
            creates a single conditional edge.

    Returns:
        StateGraph: The graph instance for method chaining.

    Example:
        >>> def route_condition(state):
        ...     return "success" if state.success else "failure"
        >>> graph.add_conditional_edges(
        ...     "processor",
        ...     route_condition,
        ...     {"success": "next_step", "failure": "error_handler"},
        ... )
    """
    # Create edges based on possible returns from condition function
    logger.debug(
        "Node '%s' adding conditional edges with path_map: %s",
        from_node,
        path_map,
    )
    if path_map:
        for condition_result, target_node in path_map.items():
            edge = Edge(from_node, target_node, condition)
            edge.condition_result = condition_result
            self.edges.append(edge)
    else:
        # Single conditional edge
        logger.debug("Node '%s' adding single conditional edge", from_node)
        self.edges.append(Edge(from_node, "", condition))
    return self
add_edge
add_edge(from_node, to_node)

Add a static edge between two nodes.

Creates a direct connection from one node to another. If the source node is START, the target node becomes the entry point for the graph.

Parameters:

from_node (str): Name of the source node. Required.
to_node (str): Name of the target node. Required.

Returns:

StateGraph: The graph instance for method chaining.

Example
graph.add_edge("node1", "node2")
graph.add_edge(START, "entry_node")  # Sets entry point

Source code in pyagenity/graph/state_graph.py
def add_edge(
    self,
    from_node: str,
    to_node: str,
) -> "StateGraph":
    """Add a static edge between two nodes.

    Creates a direct connection from one node to another. If the source
    node is START, the target node becomes the entry point for the graph.

    Args:
        from_node: Name of the source node.
        to_node: Name of the target node.

    Returns:
        StateGraph: The graph instance for method chaining.

    Example:
        >>> graph.add_edge("node1", "node2")
        >>> graph.add_edge(START, "entry_node")  # Sets entry point
    """
    logger.debug("Adding edge from '%s' to '%s'", from_node, to_node)
    # Set entry point if edge is from START
    if from_node == START:
        self.entry_point = to_node
        logger.info("Set entry point to '%s'", to_node)
    self.edges.append(Edge(from_node, to_node))
    logger.debug("Added edge (total edges: %d)", len(self.edges))
    return self
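The START handling above can be illustrated standalone: an edge from START implicitly sets the entry point. In this sketch, `MiniGraph` and `START` are hypothetical stand-ins, not PyAgenity's actual classes:

```python
# Sketch of add_edge's START handling: an edge from START implicitly
# sets the graph's entry point. START stands in for PyAgenity's constant.
START = "__start__"


class MiniGraph:
    def __init__(self):
        self.edges: list[tuple[str, str]] = []
        self.entry_point: str | None = None

    def add_edge(self, from_node: str, to_node: str) -> "MiniGraph":
        if from_node == START:
            self.entry_point = to_node  # mirror StateGraph's behavior
        self.edges.append((from_node, to_node))
        return self  # enable method chaining


g = MiniGraph().add_edge(START, "entry_node").add_edge("entry_node", "done")
print(g.entry_point)  # entry_node
```

Returning `self` is what makes the fluent chaining style shown in the examples possible.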
add_node
add_node(name_or_func, func=None)

Add a node to the graph.

This method supports two calling patterns: 1. Pass a callable as the first argument (name inferred from function name) 2. Pass a name string and callable/ToolNode as separate arguments

Parameters:

name_or_func (str | Callable): Either the node name (str) or a callable function. If callable, the function name will be used as the node name. Required.
func (Union[Callable, ToolNode, None]): The function or ToolNode to execute. Required if name_or_func is a string, ignored if name_or_func is callable. Default: None.

Returns:

StateGraph: The graph instance for method chaining.

Raises:

ValueError: If invalid arguments are provided.

Example
# Method 1: Function name inferred
graph.add_node(my_function)

# Method 2: Explicit name and function
graph.add_node("process", my_function)

Source code in pyagenity/graph/state_graph.py
def add_node(
    self,
    name_or_func: str | Callable,
    func: Union[Callable, "ToolNode", None] = None,
) -> "StateGraph":
    """Add a node to the graph.

    This method supports two calling patterns:
    1. Pass a callable as the first argument (name inferred from function name)
    2. Pass a name string and callable/ToolNode as separate arguments

    Args:
        name_or_func: Either the node name (str) or a callable function.
            If callable, the function name will be used as the node name.
        func: The function or ToolNode to execute. Required if name_or_func
            is a string, ignored if name_or_func is callable.

    Returns:
        StateGraph: The graph instance for method chaining.

    Raises:
        ValueError: If invalid arguments are provided.

    Example:
        >>> # Method 1: Function name inferred
        >>> graph.add_node(my_function)
        >>> # Method 2: Explicit name and function
        >>> graph.add_node("process", my_function)
    """
    if callable(name_or_func) and func is None:
        # Function passed as first argument
        name = name_or_func.__name__
        func = name_or_func
        logger.debug("Adding node '%s' with inferred name from function", name)
    elif isinstance(name_or_func, str) and (callable(func) or isinstance(func, ToolNode)):
        # Name and function passed separately
        name = name_or_func
        logger.debug(
            "Adding node '%s' with explicit name and %s",
            name,
            "ToolNode" if isinstance(func, ToolNode) else "callable",
        )
    else:
        error_msg = "Invalid arguments for add_node"
        logger.error(error_msg)
        raise ValueError(error_msg)

    self.nodes[name] = Node(name, func)
    logger.info("Added node '%s' to graph (total nodes: %d)", name, len(self.nodes))
    return self
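The name-inference branch above relies on the function's `__name__`. A standalone illustration of both calling patterns (here `NodeRegistry` is a hypothetical stand-in for StateGraph, not part of PyAgenity):

```python
# Standalone illustration of add_node's two calling patterns.
# NodeRegistry is a hypothetical stand-in, not PyAgenity's StateGraph.
from typing import Callable


class NodeRegistry:
    def __init__(self):
        self.nodes: dict[str, Callable] = {}

    def add_node(self, name_or_func, func=None) -> "NodeRegistry":
        if callable(name_or_func) and func is None:
            # Pattern 1: infer the node name from the function itself
            name, func = name_or_func.__name__, name_or_func
        elif isinstance(name_or_func, str) and callable(func):
            # Pattern 2: explicit name plus callable
            name = name_or_func
        else:
            raise ValueError("Invalid arguments for add_node")
        self.nodes[name] = func
        return self


def my_function(state):
    return state


reg = NodeRegistry().add_node(my_function).add_node("process", my_function)
print(sorted(reg.nodes))  # ['my_function', 'process']
```

Note that pattern 1 only fires when `func` is omitted; passing a callable for both arguments raises `ValueError`.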
compile
compile(checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())

Compile the graph for execution.

Parameters:

checkpointer (BaseCheckpointer[StateT] | None): Checkpointer for state persistence. Default: None (an InMemoryCheckpointer is created when omitted).
store (BaseStore | None): Store for additional data. Default: None.
interrupt_before (list[str] | None): List of node names to interrupt before execution. Default: None.
interrupt_after (list[str] | None): List of node names to interrupt after execution. Default: None.
callback_manager (CallbackManager): Callback manager for executing hooks. Default: CallbackManager().
Source code in pyagenity/graph/state_graph.py
def compile(
    self,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> "CompiledGraph[StateT]":
    """Compile the graph for execution.

    Args:
        checkpointer: Checkpointer for state persistence. Defaults to an
            InMemoryCheckpointer when omitted.
        store: Store for additional data
        interrupt_before: List of node names to interrupt before execution
        interrupt_after: List of node names to interrupt after execution
        callback_manager: Callback manager for executing hooks

    Returns:
        CompiledGraph: The executable graph instance.

    Raises:
        GraphError: If no entry point is set or graph validation fails.
    """
    logger.info(
        "Compiling graph with %d nodes, %d edges, entry_point='%s'",
        len(self.nodes),
        len(self.edges),
        self.entry_point,
    )
    logger.debug(
        "Compile options: interrupt_before=%s, interrupt_after=%s",
        interrupt_before,
        interrupt_after,
    )

    if not self.entry_point:
        error_msg = "No entry point set. Use set_entry_point() or add an edge from START."
        logger.error(error_msg)
        raise GraphError(error_msg)

    # Validate graph structure
    logger.debug("Validating graph structure")
    self._validate_graph()
    logger.debug("Graph structure validated successfully")

    # Validate interrupt node names
    interrupt_before = interrupt_before or []
    interrupt_after = interrupt_after or []

    all_interrupt_nodes = set(interrupt_before + interrupt_after)
    invalid_nodes = all_interrupt_nodes - set(self.nodes.keys())
    if invalid_nodes:
        error_msg = f"Invalid interrupt nodes: {invalid_nodes}. Must be existing node names."
        logger.error(error_msg)
        raise GraphError(error_msg)

    self.compiled = True
    logger.info("Graph compilation completed successfully")
    # Import here to avoid circular import at module import time
    # Now update Checkpointer
    if checkpointer is None:
        from pyagenity.checkpointer import InMemoryCheckpointer

        checkpointer = InMemoryCheckpointer[StateT]()
        logger.debug("No checkpointer provided, using InMemoryCheckpointer")

    # Import the CompiledGraph class
    from .compiled_graph import CompiledGraph

    # Setup dependencies
    self._container.bind_instance(
        BaseCheckpointer,
        checkpointer,
        allow_concrete=True,
    )  # not null as we set default
    self._container.bind_instance(
        BaseStore,
        store,
        allow_none=True,
        allow_concrete=True,
    )
    self._container.bind_instance(
        CallbackManager,
        callback_manager,
        allow_concrete=True,
    )  # not null as we set default
    self._container.bind("interrupt_before", interrupt_before)
    self._container.bind("interrupt_after", interrupt_after)
    self._container.bind_instance(StateGraph, self)

    app = CompiledGraph(
        state=self._state,
        interrupt_after=interrupt_after,
        interrupt_before=interrupt_before,
        state_graph=self,
        checkpointer=checkpointer,
        publisher=self._publisher,
        store=store,
        task_manager=self._task_manager,
    )

    self._container.bind(CompiledGraph, app)
    # Compile the Graph, so it will optimize the dependency graph
    self._container.compile()
    return app
set_entry_point
set_entry_point(node_name)

Set the entry point for the graph.

Source code in pyagenity/graph/state_graph.py
def set_entry_point(self, node_name: str) -> "StateGraph":
    """Set the entry point for the graph."""
    self.entry_point = node_name
    self.add_edge(START, node_name)
    logger.info("Set entry point to '%s'", node_name)
    return self
Functions
tool_node

ToolNode package.

This package provides a modularized implementation of ToolNode. Public API:

  • ToolNode
  • HAS_FASTMCP, HAS_MCP

Backwards-compatible import path: from pyagenity.graph.tool_node import ToolNode

Modules:

Name Description
base

Tool execution node for PyAgenity graph workflows.

constants

Constants for ToolNode package.

deps

Dependency flags and optional imports for ToolNode.

executors

Executors for different tool providers and local functions.

schema

Schema utilities and local tool description building for ToolNode.

Classes:

Name Description
ToolNode

A unified registry and executor for callable functions from various tool providers.

Attributes:

Name Type Description
HAS_FASTMCP
HAS_MCP
Attributes
HAS_FASTMCP module-attribute
HAS_FASTMCP = True
HAS_MCP module-attribute
HAS_MCP = True
__all__ module-attribute
__all__ = ['HAS_FASTMCP', 'HAS_MCP', 'ToolNode']
Classes
ToolNode

Bases: SchemaMixin, LocalExecMixin, MCPMixin, ComposioMixin, LangChainMixin, KwargsResolverMixin

A unified registry and executor for callable functions from various tool providers.

ToolNode serves as the central hub for managing and executing tools from multiple sources: - Local Python functions - MCP (Model Context Protocol) tools - Composio adapter tools - LangChain tools

The class uses a mixin-based architecture to separate concerns and maintain clean integration with different tool providers. It provides both synchronous and asynchronous execution methods with comprehensive event publishing and error handling.

Attributes:

Name Type Description
_funcs dict[str, Callable]

Dictionary mapping function names to callable functions.

_client Client | None

Optional MCP client for remote tool execution.

_composio ComposioAdapter | None

Optional Composio adapter for external integrations.

_langchain Any | None

Optional LangChain adapter for LangChain tools.

mcp_tools list[str]

List of available MCP tool names.

composio_tools list[str]

List of available Composio tool names.

langchain_tools list[str]

List of available LangChain tool names.

Example
# Define local tools
def weather_tool(location: str) -> str:
    return f"Weather in {location}: Sunny, 25°C"


def calculator(a: int, b: int) -> int:
    return a + b


# Create ToolNode with local functions
tools = ToolNode([weather_tool, calculator])

# Execute a tool
result = await tools.invoke(
    name="weather_tool",
    args={"location": "New York"},
    tool_call_id="call_123",
    config={"user_id": "user1"},
    state=agent_state,
)

Methods:

Name Description
__init__

Initialize ToolNode with functions and optional tool adapters.

all_tools

Get all available tools from all configured providers.

all_tools_sync

Synchronously get all available tools from all configured providers.

get_local_tool

Generate OpenAI-compatible tool definitions for all registered local functions.

invoke

Execute a specific tool by name with the provided arguments.

stream

Execute a tool with streaming support, yielding incremental results.

Source code in pyagenity/graph/tool_node/base.py
class ToolNode(
    SchemaMixin,
    LocalExecMixin,
    MCPMixin,
    ComposioMixin,
    LangChainMixin,
    KwargsResolverMixin,
):
    """A unified registry and executor for callable functions from various tool providers.

    ToolNode serves as the central hub for managing and executing tools from multiple sources:
    - Local Python functions
    - MCP (Model Context Protocol) tools
    - Composio adapter tools
    - LangChain tools

    The class uses a mixin-based architecture to separate concerns and maintain clean
    integration with different tool providers. It provides both synchronous and asynchronous
    execution methods with comprehensive event publishing and error handling.

    Attributes:
        _funcs: Dictionary mapping function names to callable functions.
        _client: Optional MCP client for remote tool execution.
        _composio: Optional Composio adapter for external integrations.
        _langchain: Optional LangChain adapter for LangChain tools.
        mcp_tools: List of available MCP tool names.
        composio_tools: List of available Composio tool names.
        langchain_tools: List of available LangChain tool names.

    Example:
        ```python
        # Define local tools
        def weather_tool(location: str) -> str:
            return f"Weather in {location}: Sunny, 25°C"


        def calculator(a: int, b: int) -> int:
            return a + b


        # Create ToolNode with local functions
        tools = ToolNode([weather_tool, calculator])

        # Execute a tool
        result = await tools.invoke(
            name="weather_tool",
            args={"location": "New York"},
            tool_call_id="call_123",
            config={"user_id": "user1"},
            state=agent_state,
        )
        ```
    """

    def __init__(
        self,
        functions: t.Iterable[t.Callable],
        client: deps.Client | None = None,  # type: ignore
        composio_adapter: ComposioAdapter | None = None,
        langchain_adapter: t.Any | None = None,
    ) -> None:
        """Initialize ToolNode with functions and optional tool adapters.

        Args:
            functions: Iterable of callable functions to register as tools. Each function
                will be registered with its `__name__` as the tool identifier.
            client: Optional MCP (Model Context Protocol) client for remote tool access.
                Requires 'fastmcp' and 'mcp' packages to be installed.
            composio_adapter: Optional Composio adapter for external integrations and
                third-party API access.
            langchain_adapter: Optional LangChain adapter for accessing LangChain tools
                and integrations.

        Raises:
            ImportError: If MCP client is provided but required packages are not installed.
            TypeError: If any item in functions is not callable.

        Note:
            When using MCP client functionality, ensure you have installed the required
            dependencies with: `pip install pyagenity[mcp]`
        """
        functions = list(functions)  # materialize once so one-shot iterators are not exhausted
        logger.info("Initializing ToolNode with %d functions", len(functions))

        if client is not None:
            # Read flags dynamically so tests can patch pyagenity.graph.tool_node.HAS_*
            mod = sys.modules.get("pyagenity.graph.tool_node")
            has_fastmcp = getattr(mod, "HAS_FASTMCP", deps.HAS_FASTMCP) if mod else deps.HAS_FASTMCP
            has_mcp = getattr(mod, "HAS_MCP", deps.HAS_MCP) if mod else deps.HAS_MCP

            if not has_fastmcp or not has_mcp:
                raise ImportError(
                    "MCP client functionality requires 'fastmcp' and 'mcp' packages. "
                    "Install with: pip install pyagenity[mcp]"
                )
            logger.debug("ToolNode initialized with MCP client")

        self._funcs: dict[str, t.Callable] = {}
        self._client: deps.Client | None = client  # type: ignore
        self._composio: ComposioAdapter | None = composio_adapter
        self._langchain: t.Any | None = langchain_adapter

        for fn in functions:
            if not callable(fn):
                raise TypeError("ToolNode only accepts callables")
            self._funcs[fn.__name__] = fn

        self.mcp_tools: list[str] = []
        self.composio_tools: list[str] = []
        self.langchain_tools: list[str] = []

    async def _all_tools_async(self) -> list[dict]:
        tools: list[dict] = self.get_local_tool()
        tools.extend(await self._get_mcp_tool())
        tools.extend(await self._get_composio_tools())
        tools.extend(await self._get_langchain_tools())
        return tools

    async def all_tools(self) -> list[dict]:
        """Get all available tools from all configured providers.

        Retrieves and combines tool definitions from local functions, MCP client,
        Composio adapter, and LangChain adapter. Each tool definition includes
        the function schema with parameters and descriptions.

        Returns:
            List of tool definitions in OpenAI function calling format. Each dict
            contains 'type': 'function' and 'function' with name, description,
            and parameters schema.

        Example:
            ```python
            tools = await tool_node.all_tools()
            # Returns:
            # [
            #   {
            #     "type": "function",
            #     "function": {
            #       "name": "weather_tool",
            #       "description": "Get weather information for a location",
            #       "parameters": {
            #         "type": "object",
            #         "properties": {
            #           "location": {"type": "string"}
            #         },
            #         "required": ["location"]
            #       }
            #     }
            #   }
            # ]
            ```
        """
        return await self._all_tools_async()

    def all_tools_sync(self) -> list[dict]:
        """Synchronously get all available tools from all configured providers.

        This is a synchronous wrapper around the async all_tools() method.
        It uses asyncio.run() to handle async operations from MCP, Composio,
        and LangChain adapters.

        Returns:
            List of tool definitions in OpenAI function calling format.

        Note:
            Prefer using the async `all_tools()` method when possible, especially
            in async contexts, to avoid potential event loop issues.
        """
        tools: list[dict] = self.get_local_tool()
        if self._client:
            result = asyncio.run(self._get_mcp_tool())
            if result:
                tools.extend(result)
        comp = asyncio.run(self._get_composio_tools())
        if comp:
            tools.extend(comp)
        lc = asyncio.run(self._get_langchain_tools())
        if lc:
            tools.extend(lc)
        return tools

    async def invoke(  # noqa: PLR0915
        self,
        name: str,
        args: dict,
        tool_call_id: str,
        config: dict[str, t.Any],
        state: AgentState,
        callback_manager: CallbackManager = Inject[CallbackManager],
    ) -> t.Any:
        """Execute a specific tool by name with the provided arguments.

        This method handles tool execution across all configured providers (local,
        MCP, Composio, LangChain) with comprehensive error handling, event publishing,
        and callback management.

        Args:
            name: The name of the tool to execute.
            args: Dictionary of arguments to pass to the tool function.
            tool_call_id: Unique identifier for this tool execution, used for
                tracking and result correlation.
            config: Configuration dictionary containing execution context and
                user-specific settings.
            state: Current agent state for context-aware tool execution.
            callback_manager: Manager for executing pre/post execution callbacks.
                Injected via dependency injection if not provided.

        Returns:
            Message object containing tool execution results, either successful
            output or error information with appropriate status indicators.

        Raises:
            The method handles all exceptions internally and returns error Messages
            rather than raising exceptions, ensuring robust execution flow.

        Example:
            ```python
            result = await tool_node.invoke(
                name="weather_tool",
                args={"location": "Paris", "units": "metric"},
                tool_call_id="call_abc123",
                config={"user_id": "user1", "session_id": "session1"},
                state=current_agent_state,
            )

            # result is a Message with tool execution results
            print(result.content)  # Tool output or error information
            ```

        Note:
            The method publishes execution events throughout the process for
            monitoring and debugging purposes. Tool execution is routed based
            on tool provider precedence: MCP → Composio → LangChain → Local.
        """
        logger.info("Executing tool '%s' with %d arguments", name, len(args))
        logger.debug("Tool arguments: %s", args)

        event = EventModel.default(
            config,
            data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
            content_type=[ContentType.TOOL_CALL],
            event=Event.TOOL_EXECUTION,
        )
        event.node_name = name
        # Attach structured tool call block
        with contextlib.suppress(Exception):
            event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]
        publish_event(event)

        if name in self.mcp_tools:
            event.metadata["is_mcp"] = True
            publish_event(event)
            res = await self._mcp_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            # Attach tool result block mirroring the tool output
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        if name in self.composio_tools:
            event.metadata["is_composio"] = True
            publish_event(event)
            res = await self._composio_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        if name in self.langchain_tools:
            event.metadata["is_langchain"] = True
            publish_event(event)
            res = await self._langchain_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        if name in self._funcs:
            event.metadata["is_mcp"] = False
            publish_event(event)
            res = await self._internal_execute(
                name,
                args,
                tool_call_id,
                config,
                state,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        error_msg = f"Tool '{name}' not found."
        event.data["error"] = error_msg
        event.event_type = EventType.ERROR
        event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
        publish_event(event)
        return Message.tool_message(
            content=[
                ErrorBlock(message=error_msg),
                ToolResultBlock(
                    call_id=tool_call_id,
                    output=error_msg,
                    is_error=True,
                    status="failed",
                ),
            ],
        )

    async def stream(  # noqa: PLR0915
        self,
        name: str,
        args: dict,
        tool_call_id: str,
        config: dict[str, t.Any],
        state: AgentState,
        callback_manager: CallbackManager = Inject[CallbackManager],
    ) -> t.AsyncIterator[Message]:
        """Execute a tool with streaming support, yielding incremental results.

        Similar to invoke() but designed for tools that can provide streaming responses
        or when you want to process results as they become available. Currently,
        most tool providers return complete results, so this method typically yields
        a single Message with the full result.

        Args:
            name: The name of the tool to execute.
            args: Dictionary of arguments to pass to the tool function.
            tool_call_id: Unique identifier for this tool execution.
            config: Configuration dictionary containing execution context.
            state: Current agent state for context-aware tool execution.
            callback_manager: Manager for executing pre/post execution callbacks.

        Yields:
            Message objects containing tool execution results or status updates.
            For most tools, this will yield a single complete result Message.

        Example:
            ```python
            async for message in tool_node.stream(
                name="data_processor",
                args={"dataset": "large_data.csv"},
                tool_call_id="call_stream123",
                config={"user_id": "user1"},
                state=current_state,
            ):
                print(f"Received: {message.content}")
                # Process each streamed result
            ```

        Note:
            The streaming interface is designed for future expansion where tools
            may provide true streaming responses. Currently, it provides a
            consistent async iterator interface over tool results.
        """
        logger.info("Executing tool '%s' with %d arguments", name, len(args))
        logger.debug("Tool arguments: %s", args)
        event = EventModel.default(
            config,
            data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
            content_type=[ContentType.TOOL_CALL],
            event=Event.TOOL_EXECUTION,
        )
        event.node_name = "ToolNode"
        with contextlib.suppress(Exception):
            event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]

        if name in self.mcp_tools:
            event.metadata["function_type"] = "mcp"
            publish_event(event)
            message = await self._mcp_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = message.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield message
            return

        if name in self.composio_tools:
            event.metadata["function_type"] = "composio"
            publish_event(event)
            message = await self._composio_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = message.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield message
            return

        if name in self.langchain_tools:
            event.metadata["function_type"] = "langchain"
            publish_event(event)
            message = await self._langchain_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = message.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield message
            return

        if name in self._funcs:
            event.metadata["function_type"] = "internal"
            publish_event(event)

            result = await self._internal_execute(
                name,
                args,
                tool_call_id,
                config,
                state,
                callback_manager,
            )
            event.data["message"] = result.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=result.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield result
            return

        error_msg = f"Tool '{name}' not found."
        event.data["error"] = error_msg
        event.event_type = EventType.ERROR
        event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
        publish_event(event)

        yield Message.tool_message(
            content=[
                ErrorBlock(message=error_msg),
                ToolResultBlock(
                    call_id=tool_call_id,
                    output=error_msg,
                    is_error=True,
                    status="failed",
                ),
            ],
        )
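Routing in both `invoke()` and `stream()` is a sequence of membership checks, so the first provider registry containing the name wins, in the order MCP, then Composio, then LangChain, then local functions. A self-contained sketch of that precedence, using hypothetical provider sets rather than the real ToolNode internals:

```python
# Illustrative routing sketch (hypothetical provider sets, not the actual
# ToolNode attributes): the first registry containing the name wins, in the
# same precedence order invoke() uses: MCP -> Composio -> LangChain -> local.
def route(name, mcp_tools, composio_tools, langchain_tools, local_funcs):
    if name in mcp_tools:
        return "mcp"
    if name in composio_tools:
        return "composio"
    if name in langchain_tools:
        return "langchain"
    if name in local_funcs:
        return "local"
    return "not_found"  # invoke() returns an error Message in this case


print(route("search", {"search"}, set(), set(), {}))  # mcp
print(route("add", set(), set(), set(), {"add": lambda a, b: a + b}))  # local
print(route("ghost", set(), set(), set(), {}))  # not_found
```

A consequence of this order is that a local function is shadowed if an MCP, Composio, or LangChain tool registers the same name.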
Attributes
composio_tools instance-attribute
composio_tools = []
langchain_tools instance-attribute
langchain_tools = []
mcp_tools instance-attribute
mcp_tools = []
Functions
__init__
__init__(functions, client=None, composio_adapter=None, langchain_adapter=None)

Initialize ToolNode with functions and optional tool adapters.

Parameters:

Name Type Description Default
functions Iterable[Callable]

Iterable of callable functions to register as tools. Each function will be registered with its __name__ as the tool identifier.

required
client Client | None

Optional MCP (Model Context Protocol) client for remote tool access. Requires 'fastmcp' and 'mcp' packages to be installed.

None
composio_adapter ComposioAdapter | None

Optional Composio adapter for external integrations and third-party API access.

None
langchain_adapter Any | None

Optional LangChain adapter for accessing LangChain tools and integrations.

None

Raises:

Type Description
ImportError

If MCP client is provided but required packages are not installed.

TypeError

If any item in functions is not callable.

Note

When using MCP client functionality, ensure you have installed the required dependencies with: pip install pyagenity[mcp]

Source code in pyagenity/graph/tool_node/base.py
def __init__(
    self,
    functions: t.Iterable[t.Callable],
    client: deps.Client | None = None,  # type: ignore
    composio_adapter: ComposioAdapter | None = None,
    langchain_adapter: t.Any | None = None,
) -> None:
    """Initialize ToolNode with functions and optional tool adapters.

    Args:
        functions: Iterable of callable functions to register as tools. Each function
            will be registered with its `__name__` as the tool identifier.
        client: Optional MCP (Model Context Protocol) client for remote tool access.
            Requires 'fastmcp' and 'mcp' packages to be installed.
        composio_adapter: Optional Composio adapter for external integrations and
            third-party API access.
        langchain_adapter: Optional LangChain adapter for accessing LangChain tools
            and integrations.

    Raises:
        ImportError: If MCP client is provided but required packages are not installed.
        TypeError: If any item in functions is not callable.

    Note:
        When using MCP client functionality, ensure you have installed the required
        dependencies with: `pip install pyagenity[mcp]`
    """
    functions = list(functions)  # materialize once so one-shot iterators are not exhausted
    logger.info("Initializing ToolNode with %d functions", len(functions))

    if client is not None:
        # Read flags dynamically so tests can patch pyagenity.graph.tool_node.HAS_*
        mod = sys.modules.get("pyagenity.graph.tool_node")
        has_fastmcp = getattr(mod, "HAS_FASTMCP", deps.HAS_FASTMCP) if mod else deps.HAS_FASTMCP
        has_mcp = getattr(mod, "HAS_MCP", deps.HAS_MCP) if mod else deps.HAS_MCP

        if not has_fastmcp or not has_mcp:
            raise ImportError(
                "MCP client functionality requires 'fastmcp' and 'mcp' packages. "
                "Install with: pip install pyagenity[mcp]"
            )
        logger.debug("ToolNode initialized with MCP client")

    self._funcs: dict[str, t.Callable] = {}
    self._client: deps.Client | None = client  # type: ignore
    self._composio: ComposioAdapter | None = composio_adapter
    self._langchain: t.Any | None = langchain_adapter

    for fn in functions:
        if not callable(fn):
            raise TypeError("ToolNode only accepts callables")
        self._funcs[fn.__name__] = fn

    self.mcp_tools: list[str] = []
    self.composio_tools: list[str] = []
    self.langchain_tools: list[str] = []
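Note that `functions` may be a one-shot iterable such as a generator: counting it with `len(list(...))` and then iterating it again would silently register nothing, which is why materializing it once up front matters. A self-contained illustration of the pitfall:

```python
# A generator is a one-shot iterable: it can be consumed only once.
def make_tools():
    yield (lambda x: x)


gen = make_tools()
count = len(list(gen))  # consumes the generator: count == 1
leftover = list(gen)  # already exhausted: []

# Materializing once up front avoids the pitfall:
funcs = list(make_tools())
count2 = len(funcs)  # 1
registered = {fn.__name__: fn for fn in funcs}  # the function survives
```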
all_tools async
all_tools()

Get all available tools from all configured providers.

Retrieves and combines tool definitions from local functions, MCP client, Composio adapter, and LangChain adapter. Each tool definition includes the function schema with parameters and descriptions.

Returns:

Type Description
list[dict]

List of tool definitions in OpenAI function calling format. Each dict contains 'type': 'function' and 'function' with name, description, and parameters schema.

Example
tools = await tool_node.all_tools()
# Returns:
# [
#   {
#     "type": "function",
#     "function": {
#       "name": "weather_tool",
#       "description": "Get weather information for a location",
#       "parameters": {
#         "type": "object",
#         "properties": {
#           "location": {"type": "string"}
#         },
#         "required": ["location"]
#       }
#     }
#   }
# ]
Source code in pyagenity/graph/tool_node/base.py
async def all_tools(self) -> list[dict]:
    """Get all available tools from all configured providers.

    Retrieves and combines tool definitions from local functions, MCP client,
    Composio adapter, and LangChain adapter. Each tool definition includes
    the function schema with parameters and descriptions.

    Returns:
        List of tool definitions in OpenAI function calling format. Each dict
        contains 'type': 'function' and 'function' with name, description,
        and parameters schema.

    Example:
        ```python
        tools = await tool_node.all_tools()
        # Returns:
        # [
        #   {
        #     "type": "function",
        #     "function": {
        #       "name": "weather_tool",
        #       "description": "Get weather information for a location",
        #       "parameters": {
        #         "type": "object",
        #         "properties": {
        #           "location": {"type": "string"}
        #         },
        #         "required": ["location"]
        #       }
        #     }
        #   }
        # ]
        ```
    """
    return await self._all_tools_async()
all_tools_sync
all_tools_sync()

Synchronously get all available tools from all configured providers.

This is a synchronous wrapper around the async all_tools() method. It uses asyncio.run() to handle async operations from MCP, Composio, and LangChain adapters.

Returns:

Type Description
list[dict]

List of tool definitions in OpenAI function calling format.

Note

Prefer using the async all_tools() method when possible, especially in async contexts, to avoid potential event loop issues.
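The event-loop caveat is concrete: `asyncio.run()` raises `RuntimeError` when called from a thread that already has a running loop, which is why async callers should `await all_tools()` directly. A minimal self-contained demonstration of the failure mode and the correct pattern:

```python
import asyncio


async def fetch_tools():
    # Stand-in for an async provider call such as all_tools().
    return [{"type": "function", "function": {"name": "demo"}}]


# Fine from synchronous code: no event loop is running yet.
tools = asyncio.run(fetch_tools())


async def caller():
    # Inside an already-running loop, asyncio.run() raises RuntimeError,
    # so async code should simply await the coroutine instead.
    try:
        asyncio.run(fetch_tools())
        return "unreachable"
    except RuntimeError:
        return await fetch_tools()  # the correct pattern in async contexts


result = asyncio.run(caller())
```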

Source code in pyagenity/graph/tool_node/base.py
def all_tools_sync(self) -> list[dict]:
    """Synchronously get all available tools from all configured providers.

    This is a synchronous wrapper around the async all_tools() method.
    It uses asyncio.run() to handle async operations from MCP, Composio,
    and LangChain adapters.

    Returns:
        List of tool definitions in OpenAI function calling format.

    Note:
        Prefer using the async `all_tools()` method when possible, especially
        in async contexts, to avoid potential event loop issues.
    """
    tools: list[dict] = self.get_local_tool()
    if self._client:
        result = asyncio.run(self._get_mcp_tool())
        if result:
            tools.extend(result)
    comp = asyncio.run(self._get_composio_tools())
    if comp:
        tools.extend(comp)
    lc = asyncio.run(self._get_langchain_tools())
    if lc:
        tools.extend(lc)
    return tools
get_local_tool
get_local_tool()

Generate OpenAI-compatible tool definitions for all registered local functions.

Inspects all registered functions in _funcs and automatically generates tool schemas by analyzing function signatures, type annotations, and docstrings. Excludes injectable parameters that are provided by the framework.

Returns:

Type Description
list[dict]

List of tool definitions in OpenAI function calling format. Each definition includes the function name, description (from docstring), and complete parameter schema with types and required fields.

Example

For a function:

def calculate(a: int, b: int, operation: str = "add") -> int:
    '''Perform arithmetic calculation.'''
    return a + b if operation == "add" else a - b

Returns:

[
    {
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "Perform arithmetic calculation.",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "integer"},
                    "b": {"type": "integer"},
                    "operation": {"type": "string", "default": "add"},
                },
                "required": ["a", "b"],
            },
        },
    }
]

Note

Parameters listed in INJECTABLE_PARAMS (like 'state', 'config', 'tool_call_id') are automatically excluded from the generated schema as they are provided by the framework during execution.
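The signature inspection described above can be approximated with a short sketch. `INJECTABLE` and `TYPE_MAP` below are illustrative stand-ins for the framework's INJECTABLE_PARAMS and its annotation mapper, not the real implementation:

```python
import inspect

# Hypothetical stand-ins for framework-provided names.
INJECTABLE = {"state", "config", "tool_call_id"}
TYPE_MAP = {int: "integer", float: "number", str: "string", bool: "boolean"}

def tool_schema(fn) -> dict:
    sig = inspect.signature(fn)
    props, required = {}, []
    for name, p in sig.parameters.items():
        if name in INJECTABLE:
            continue  # the framework injects these at call time
        ann = p.annotation if p.annotation is not inspect.Parameter.empty else str
        prop = {"type": TYPE_MAP.get(ann, "string")}
        if p.default is inspect.Parameter.empty:
            required.append(name)
        else:
            prop["default"] = p.default
        props[name] = prop
    params = {"type": "object", "properties": props}
    if required:
        params["required"] = required
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": inspect.getdoc(fn) or "No description provided.",
            "parameters": params,
        },
    }

def calculate(a: int, b: int, operation: str = "add") -> int:
    """Perform arithmetic calculation."""
    return a + b if operation == "add" else a - b

schema = tool_schema(calculate)
```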

Source code in pyagenity/graph/tool_node/schema.py
def get_local_tool(self) -> list[dict]:
    """Generate OpenAI-compatible tool definitions for all registered local functions.

    Inspects all registered functions in _funcs and automatically generates
    tool schemas by analyzing function signatures, type annotations, and docstrings.
    Excludes injectable parameters that are provided by the framework.

    Returns:
        List of tool definitions in OpenAI function calling format. Each
        definition includes the function name, description (from docstring),
        and complete parameter schema with types and required fields.

    Example:
        For a function:
        ```python
        def calculate(a: int, b: int, operation: str = "add") -> int:
            '''Perform arithmetic calculation.'''
            return a + b if operation == "add" else a - b
        ```

        Returns:
        ```python
        [
            {
                "type": "function",
                "function": {
                    "name": "calculate",
                    "description": "Perform arithmetic calculation.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {"type": "integer"},
                            "b": {"type": "integer"},
                            "operation": {"type": "string", "default": "add"},
                        },
                        "required": ["a", "b"],
                    },
                },
            }
        ]
        ```

    Note:
        Parameters listed in INJECTABLE_PARAMS (like 'state', 'config',
        'tool_call_id') are automatically excluded from the generated schema
        as they are provided by the framework during execution.
    """
    tools: list[dict] = []
    for name, fn in self._funcs.items():
        sig = inspect.signature(fn)
        params_schema: dict = {"type": "object", "properties": {}, "required": []}

        for p_name, p in sig.parameters.items():
            if p.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue

            if p_name in INJECTABLE_PARAMS:
                continue

            annotation = p.annotation if p.annotation is not inspect._empty else str
            prop = SchemaMixin._annotation_to_schema(annotation, p.default)
            params_schema["properties"][p_name] = prop

            if p.default is inspect._empty:
                params_schema["required"].append(p_name)

        if not params_schema["required"]:
            params_schema.pop("required")

        description = inspect.getdoc(fn) or "No description provided."

        # provider = getattr(fn, "_py_tool_provider", None)
        # tags = getattr(fn, "_py_tool_tags", None)
        # capabilities = getattr(fn, "_py_tool_capabilities", None)

        entry = {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": params_schema,
            },
        }
        # meta: dict[str, t.Any] = {}
        # if provider:
        #     meta["provider"] = provider
        # if tags:
        #     meta["tags"] = tags
        # if capabilities:
        #     meta["capabilities"] = capabilities
        # if meta:
        #     entry["x-pyagenity"] = meta

        tools.append(entry)

    return tools
invoke async
invoke(name, args, tool_call_id, config, state, callback_manager=Inject[CallbackManager])

Execute a specific tool by name with the provided arguments.

This method handles tool execution across all configured providers (local, MCP, Composio, LangChain) with comprehensive error handling, event publishing, and callback management.

Parameters:

Name Type Description Default
name str

The name of the tool to execute.

required
args dict

Dictionary of arguments to pass to the tool function.

required
tool_call_id str

Unique identifier for this tool execution, used for tracking and result correlation.

required
config dict[str, Any]

Configuration dictionary containing execution context and user-specific settings.

required
state AgentState

Current agent state for context-aware tool execution.

required
callback_manager CallbackManager

Manager for executing pre/post execution callbacks. Injected via dependency injection if not provided.

Inject[CallbackManager]

Returns:

Type Description
Any

Message object containing tool execution results, either successful output or error information with appropriate status indicators.

Example
result = await tool_node.invoke(
    name="weather_tool",
    args={"location": "Paris", "units": "metric"},
    tool_call_id="call_abc123",
    config={"user_id": "user1", "session_id": "session1"},
    state=current_agent_state,
)

# result is a Message with tool execution results
print(result.content)  # Tool output or error information
Note

The method publishes execution events throughout the process for monitoring and debugging purposes. Tool execution is routed based on tool provider precedence: MCP → Composio → LangChain → Local.

Source code in pyagenity/graph/tool_node/base.py
async def invoke(  # noqa: PLR0915
    self,
    name: str,
    args: dict,
    tool_call_id: str,
    config: dict[str, t.Any],
    state: AgentState,
    callback_manager: CallbackManager = Inject[CallbackManager],
) -> t.Any:
    """Execute a specific tool by name with the provided arguments.

    This method handles tool execution across all configured providers (local,
    MCP, Composio, LangChain) with comprehensive error handling, event publishing,
    and callback management.

    Args:
        name: The name of the tool to execute.
        args: Dictionary of arguments to pass to the tool function.
        tool_call_id: Unique identifier for this tool execution, used for
            tracking and result correlation.
        config: Configuration dictionary containing execution context and
            user-specific settings.
        state: Current agent state for context-aware tool execution.
        callback_manager: Manager for executing pre/post execution callbacks.
            Injected via dependency injection if not provided.

    Returns:
        Message object containing tool execution results, either successful
        output or error information with appropriate status indicators.

    Raises:
        The method handles all exceptions internally and returns error Messages
        rather than raising exceptions, ensuring robust execution flow.

    Example:
        ```python
        result = await tool_node.invoke(
            name="weather_tool",
            args={"location": "Paris", "units": "metric"},
            tool_call_id="call_abc123",
            config={"user_id": "user1", "session_id": "session1"},
            state=current_agent_state,
        )

        # result is a Message with tool execution results
        print(result.content)  # Tool output or error information
        ```

    Note:
        The method publishes execution events throughout the process for
        monitoring and debugging purposes. Tool execution is routed based
        on tool provider precedence: MCP → Composio → LangChain → Local.
    """
    logger.info("Executing tool '%s' with %d arguments", name, len(args))
    logger.debug("Tool arguments: %s", args)

    event = EventModel.default(
        config,
        data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
        content_type=[ContentType.TOOL_CALL],
        event=Event.TOOL_EXECUTION,
    )
    event.node_name = name
    # Attach structured tool call block
    with contextlib.suppress(Exception):
        event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]
    publish_event(event)

    if name in self.mcp_tools:
        event.metadata["is_mcp"] = True
        publish_event(event)
        res = await self._mcp_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        # Attach tool result block mirroring the tool output
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    if name in self.composio_tools:
        event.metadata["is_composio"] = True
        publish_event(event)
        res = await self._composio_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    if name in self.langchain_tools:
        event.metadata["is_langchain"] = True
        publish_event(event)
        res = await self._langchain_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    if name in self._funcs:
        event.metadata["is_mcp"] = False
        publish_event(event)
        res = await self._internal_execute(
            name,
            args,
            tool_call_id,
            config,
            state,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    error_msg = f"Tool '{name}' not found."
    event.data["error"] = error_msg
    event.event_type = EventType.ERROR
    event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
    publish_event(event)
    return Message.tool_message(
        content=[
            ErrorBlock(message=error_msg),
            ToolResultBlock(
                call_id=tool_call_id,
                output=error_msg,
                is_error=True,
                status="failed",
            ),
        ],
    )
stream async
stream(name, args, tool_call_id, config, state, callback_manager=Inject[CallbackManager])

Execute a tool with streaming support, yielding incremental results.

Similar to invoke() but designed for tools that can provide streaming responses or when you want to process results as they become available. Currently, most tool providers return complete results, so this method typically yields a single Message with the full result.

Parameters:

Name Type Description Default
name str

The name of the tool to execute.

required
args dict

Dictionary of arguments to pass to the tool function.

required
tool_call_id str

Unique identifier for this tool execution.

required
config dict[str, Any]

Configuration dictionary containing execution context.

required
state AgentState

Current agent state for context-aware tool execution.

required
callback_manager CallbackManager

Manager for executing pre/post execution callbacks.

Inject[CallbackManager]

Yields:

Type Description
AsyncIterator[Message]

Message objects containing tool execution results or status updates. For most tools, this will yield a single complete result Message.

Example
async for message in tool_node.stream(
    name="data_processor",
    args={"dataset": "large_data.csv"},
    tool_call_id="call_stream123",
    config={"user_id": "user1"},
    state=current_state,
):
    print(f"Received: {message.content}")
    # Process each streamed result
Note

The streaming interface is designed for future expansion where tools may provide true streaming responses. Currently, it provides a consistent async iterator interface over tool results.

Source code in pyagenity/graph/tool_node/base.py
async def stream(  # noqa: PLR0915
    self,
    name: str,
    args: dict,
    tool_call_id: str,
    config: dict[str, t.Any],
    state: AgentState,
    callback_manager: CallbackManager = Inject[CallbackManager],
) -> t.AsyncIterator[Message]:
    """Execute a tool with streaming support, yielding incremental results.

    Similar to invoke() but designed for tools that can provide streaming responses
    or when you want to process results as they become available. Currently,
    most tool providers return complete results, so this method typically yields
    a single Message with the full result.

    Args:
        name: The name of the tool to execute.
        args: Dictionary of arguments to pass to the tool function.
        tool_call_id: Unique identifier for this tool execution.
        config: Configuration dictionary containing execution context.
        state: Current agent state for context-aware tool execution.
        callback_manager: Manager for executing pre/post execution callbacks.

    Yields:
        Message objects containing tool execution results or status updates.
        For most tools, this will yield a single complete result Message.

    Example:
        ```python
        async for message in tool_node.stream(
            name="data_processor",
            args={"dataset": "large_data.csv"},
            tool_call_id="call_stream123",
            config={"user_id": "user1"},
            state=current_state,
        ):
            print(f"Received: {message.content}")
            # Process each streamed result
        ```

    Note:
        The streaming interface is designed for future expansion where tools
        may provide true streaming responses. Currently, it provides a
        consistent async iterator interface over tool results.
    """
    logger.info("Executing tool '%s' with %d arguments", name, len(args))
    logger.debug("Tool arguments: %s", args)
    event = EventModel.default(
        config,
        data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
        content_type=[ContentType.TOOL_CALL],
        event=Event.TOOL_EXECUTION,
    )
    event.node_name = "ToolNode"
    with contextlib.suppress(Exception):
        event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]

    if name in self.mcp_tools:
        event.metadata["function_type"] = "mcp"
        publish_event(event)
        message = await self._mcp_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = message.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield message
        return

    if name in self.composio_tools:
        event.metadata["function_type"] = "composio"
        publish_event(event)
        message = await self._composio_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = message.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield message
        return

    if name in self.langchain_tools:
        event.metadata["function_type"] = "langchain"
        publish_event(event)
        message = await self._langchain_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = message.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield message
        return

    if name in self._funcs:
        event.metadata["function_type"] = "internal"
        publish_event(event)

        result = await self._internal_execute(
            name,
            args,
            tool_call_id,
            config,
            state,
            callback_manager,
        )
        event.data["message"] = result.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=result.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield result
        return

    error_msg = f"Tool '{name}' not found."
    event.data["error"] = error_msg
    event.event_type = EventType.ERROR
    event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
    publish_event(event)

    yield Message.tool_message(
        content=[
            ErrorBlock(message=error_msg),
            ToolResultBlock(
                call_id=tool_call_id,
                output=error_msg,
                is_error=True,
                status="failed",
            ),
        ],
    )
Modules
base

Tool execution node for PyAgenity graph workflows.

This module provides the ToolNode class, which serves as a unified registry and executor for callable functions from various sources including local functions, MCP (Model Context Protocol) tools, Composio adapters, and LangChain tools. The ToolNode is designed with a modular architecture using mixins to handle different tool providers.

The ToolNode maintains compatibility with PyAgenity's dependency injection system and publishes execution events for monitoring and debugging purposes.

Typical usage example
def my_tool(query: str) -> str:
    return f"Result for: {query}"


tools = ToolNode([my_tool])
result = await tools.invoke("my_tool", {"query": "test"}, "call_id", config, state)

Classes:

Name Description
ToolNode

A unified registry and executor for callable functions from various tool providers.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
ToolNode

Bases: SchemaMixin, LocalExecMixin, MCPMixin, ComposioMixin, LangChainMixin, KwargsResolverMixin

A unified registry and executor for callable functions from various tool providers.

ToolNode serves as the central hub for managing and executing tools from multiple sources:

- Local Python functions
- MCP (Model Context Protocol) tools
- Composio adapter tools
- LangChain tools

The class uses a mixin-based architecture to separate concerns and maintain clean integration with different tool providers. It provides both synchronous and asynchronous execution methods with comprehensive event publishing and error handling.

Attributes:

Name Type Description
_funcs dict[str, Callable]

Dictionary mapping function names to callable functions.

_client Client | None

Optional MCP client for remote tool execution.

_composio ComposioAdapter | None

Optional Composio adapter for external integrations.

_langchain Any | None

Optional LangChain adapter for LangChain tools.

mcp_tools list[str]

List of available MCP tool names.

composio_tools list[str]

List of available Composio tool names.

langchain_tools list[str]

List of available LangChain tool names.

Example
# Define local tools
def weather_tool(location: str) -> str:
    return f"Weather in {location}: Sunny, 25°C"


def calculator(a: int, b: int) -> int:
    return a + b


# Create ToolNode with local functions
tools = ToolNode([weather_tool, calculator])

# Execute a tool
result = await tools.invoke(
    name="weather_tool",
    args={"location": "New York"},
    tool_call_id="call_123",
    config={"user_id": "user1"},
    state=agent_state,
)

Methods:

Name Description
__init__

Initialize ToolNode with functions and optional tool adapters.

all_tools

Get all available tools from all configured providers.

all_tools_sync

Synchronously get all available tools from all configured providers.

get_local_tool

Generate OpenAI-compatible tool definitions for all registered local functions.

invoke

Execute a specific tool by name with the provided arguments.

stream

Execute a tool with streaming support, yielding incremental results.

Source code in pyagenity/graph/tool_node/base.py
class ToolNode(
    SchemaMixin,
    LocalExecMixin,
    MCPMixin,
    ComposioMixin,
    LangChainMixin,
    KwargsResolverMixin,
):
    """A unified registry and executor for callable functions from various tool providers.

    ToolNode serves as the central hub for managing and executing tools from multiple sources:
    - Local Python functions
    - MCP (Model Context Protocol) tools
    - Composio adapter tools
    - LangChain tools

    The class uses a mixin-based architecture to separate concerns and maintain clean
    integration with different tool providers. It provides both synchronous and asynchronous
    execution methods with comprehensive event publishing and error handling.

    Attributes:
        _funcs: Dictionary mapping function names to callable functions.
        _client: Optional MCP client for remote tool execution.
        _composio: Optional Composio adapter for external integrations.
        _langchain: Optional LangChain adapter for LangChain tools.
        mcp_tools: List of available MCP tool names.
        composio_tools: List of available Composio tool names.
        langchain_tools: List of available LangChain tool names.

    Example:
        ```python
        # Define local tools
        def weather_tool(location: str) -> str:
            return f"Weather in {location}: Sunny, 25°C"


        def calculator(a: int, b: int) -> int:
            return a + b


        # Create ToolNode with local functions
        tools = ToolNode([weather_tool, calculator])

        # Execute a tool
        result = await tools.invoke(
            name="weather_tool",
            args={"location": "New York"},
            tool_call_id="call_123",
            config={"user_id": "user1"},
            state=agent_state,
        )
        ```
    """

    def __init__(
        self,
        functions: t.Iterable[t.Callable],
        client: deps.Client | None = None,  # type: ignore
        composio_adapter: ComposioAdapter | None = None,
        langchain_adapter: t.Any | None = None,
    ) -> None:
        """Initialize ToolNode with functions and optional tool adapters.

        Args:
            functions: Iterable of callable functions to register as tools. Each function
                will be registered with its `__name__` as the tool identifier.
            client: Optional MCP (Model Context Protocol) client for remote tool access.
                Requires 'fastmcp' and 'mcp' packages to be installed.
            composio_adapter: Optional Composio adapter for external integrations and
                third-party API access.
            langchain_adapter: Optional LangChain adapter for accessing LangChain tools
                and integrations.

        Raises:
            ImportError: If MCP client is provided but required packages are not installed.
            TypeError: If any item in functions is not callable.

        Note:
            When using MCP client functionality, ensure you have installed the required
            dependencies with: `pip install pyagenity[mcp]`
        """
        # Materialize the iterable first so a generator is not consumed
        # by the length check before registration below.
        functions = list(functions)
        logger.info("Initializing ToolNode with %d functions", len(functions))

        if client is not None:
            # Read flags dynamically so tests can patch pyagenity.graph.tool_node.HAS_*
            mod = sys.modules.get("pyagenity.graph.tool_node")
            has_fastmcp = getattr(mod, "HAS_FASTMCP", deps.HAS_FASTMCP) if mod else deps.HAS_FASTMCP
            has_mcp = getattr(mod, "HAS_MCP", deps.HAS_MCP) if mod else deps.HAS_MCP

            if not has_fastmcp or not has_mcp:
                raise ImportError(
                    "MCP client functionality requires 'fastmcp' and 'mcp' packages. "
                    "Install with: pip install pyagenity[mcp]"
                )
            logger.debug("ToolNode initialized with MCP client")

        self._funcs: dict[str, t.Callable] = {}
        self._client: deps.Client | None = client  # type: ignore
        self._composio: ComposioAdapter | None = composio_adapter
        self._langchain: t.Any | None = langchain_adapter

        for fn in functions:
            if not callable(fn):
                raise TypeError("ToolNode only accepts callables")
            self._funcs[fn.__name__] = fn

        self.mcp_tools: list[str] = []
        self.composio_tools: list[str] = []
        self.langchain_tools: list[str] = []

    async def _all_tools_async(self) -> list[dict]:
        tools: list[dict] = self.get_local_tool()
        tools.extend(await self._get_mcp_tool())
        tools.extend(await self._get_composio_tools())
        tools.extend(await self._get_langchain_tools())
        return tools

    async def all_tools(self) -> list[dict]:
        """Get all available tools from all configured providers.

        Retrieves and combines tool definitions from local functions, MCP client,
        Composio adapter, and LangChain adapter. Each tool definition includes
        the function schema with parameters and descriptions.

        Returns:
            List of tool definitions in OpenAI function calling format. Each dict
            contains 'type': 'function' and 'function' with name, description,
            and parameters schema.

        Example:
            ```python
            tools = await tool_node.all_tools()
            # Returns:
            # [
            #   {
            #     "type": "function",
            #     "function": {
            #       "name": "weather_tool",
            #       "description": "Get weather information for a location",
            #       "parameters": {
            #         "type": "object",
            #         "properties": {
            #           "location": {"type": "string"}
            #         },
            #         "required": ["location"]
            #       }
            #     }
            #   }
            # ]
            ```
        """
        return await self._all_tools_async()

    def all_tools_sync(self) -> list[dict]:
        """Synchronously get all available tools from all configured providers.

        This is a synchronous wrapper around the async all_tools() method.
        It uses asyncio.run() to handle async operations from MCP, Composio,
        and LangChain adapters.

        Returns:
            List of tool definitions in OpenAI function calling format.

        Note:
            Prefer using the async `all_tools()` method when possible, especially
            in async contexts, to avoid potential event loop issues.
        """
        tools: list[dict] = self.get_local_tool()
        if self._client:
            result = asyncio.run(self._get_mcp_tool())
            if result:
                tools.extend(result)
        comp = asyncio.run(self._get_composio_tools())
        if comp:
            tools.extend(comp)
        lc = asyncio.run(self._get_langchain_tools())
        if lc:
            tools.extend(lc)
        return tools

    async def invoke(  # noqa: PLR0915
        self,
        name: str,
        args: dict,
        tool_call_id: str,
        config: dict[str, t.Any],
        state: AgentState,
        callback_manager: CallbackManager = Inject[CallbackManager],
    ) -> t.Any:
        """Execute a specific tool by name with the provided arguments.

        This method handles tool execution across all configured providers (local,
        MCP, Composio, LangChain) with comprehensive error handling, event publishing,
        and callback management.

        Args:
            name: The name of the tool to execute.
            args: Dictionary of arguments to pass to the tool function.
            tool_call_id: Unique identifier for this tool execution, used for
                tracking and result correlation.
            config: Configuration dictionary containing execution context and
                user-specific settings.
            state: Current agent state for context-aware tool execution.
            callback_manager: Manager for executing pre/post execution callbacks.
                Injected via dependency injection if not provided.

        Returns:
            Message object containing tool execution results, either successful
            output or error information with appropriate status indicators.

        Raises:
            This method does not raise; all exceptions are handled internally
            and returned as error Messages, ensuring robust execution flow.

        Example:
            ```python
            result = await tool_node.invoke(
                name="weather_tool",
                args={"location": "Paris", "units": "metric"},
                tool_call_id="call_abc123",
                config={"user_id": "user1", "session_id": "session1"},
                state=current_agent_state,
            )

            # result is a Message with tool execution results
            print(result.content)  # Tool output or error information
            ```

        Note:
            The method publishes execution events throughout the process for
            monitoring and debugging purposes. Tool execution is routed based
            on tool provider precedence: MCP → Composio → LangChain → Local.
        """
        logger.info("Executing tool '%s' with %d arguments", name, len(args))
        logger.debug("Tool arguments: %s", args)

        event = EventModel.default(
            config,
            data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
            content_type=[ContentType.TOOL_CALL],
            event=Event.TOOL_EXECUTION,
        )
        event.node_name = name
        # Attach structured tool call block
        with contextlib.suppress(Exception):
            event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]
        publish_event(event)

        if name in self.mcp_tools:
            event.metadata["is_mcp"] = True
            publish_event(event)
            res = await self._mcp_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            # Attach tool result block mirroring the tool output
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        if name in self.composio_tools:
            event.metadata["is_composio"] = True
            publish_event(event)
            res = await self._composio_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        if name in self.langchain_tools:
            event.metadata["is_langchain"] = True
            publish_event(event)
            res = await self._langchain_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        if name in self._funcs:
            event.metadata["is_mcp"] = False
            publish_event(event)
            res = await self._internal_execute(
                name,
                args,
                tool_call_id,
                config,
                state,
                callback_manager,
            )
            event.data["message"] = res.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res

        error_msg = f"Tool '{name}' not found."
        event.data["error"] = error_msg
        event.event_type = EventType.ERROR
        event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
        publish_event(event)
        return Message.tool_message(
            content=[
                ErrorBlock(message=error_msg),
                ToolResultBlock(
                    call_id=tool_call_id,
                    output=error_msg,
                    is_error=True,
                    status="failed",
                ),
            ],
        )

    async def stream(  # noqa: PLR0915
        self,
        name: str,
        args: dict,
        tool_call_id: str,
        config: dict[str, t.Any],
        state: AgentState,
        callback_manager: CallbackManager = Inject[CallbackManager],
    ) -> t.AsyncIterator[Message]:
        """Execute a tool with streaming support, yielding incremental results.

        Similar to invoke() but designed for tools that can provide streaming responses
        or when you want to process results as they become available. Currently,
        most tool providers return complete results, so this method typically yields
        a single Message with the full result.

        Args:
            name: The name of the tool to execute.
            args: Dictionary of arguments to pass to the tool function.
            tool_call_id: Unique identifier for this tool execution.
            config: Configuration dictionary containing execution context.
            state: Current agent state for context-aware tool execution.
            callback_manager: Manager for executing pre/post execution callbacks.

        Yields:
            Message objects containing tool execution results or status updates.
            For most tools, this will yield a single complete result Message.

        Example:
            ```python
            async for message in tool_node.stream(
                name="data_processor",
                args={"dataset": "large_data.csv"},
                tool_call_id="call_stream123",
                config={"user_id": "user1"},
                state=current_state,
            ):
                print(f"Received: {message.content}")
                # Process each streamed result
            ```

        Note:
            The streaming interface is designed for future expansion where tools
            may provide true streaming responses. Currently, it provides a
            consistent async iterator interface over tool results.
        """
        logger.info("Executing tool '%s' with %d arguments", name, len(args))
        logger.debug("Tool arguments: %s", args)
        event = EventModel.default(
            config,
            data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
            content_type=[ContentType.TOOL_CALL],
            event=Event.TOOL_EXECUTION,
        )
        event.node_name = "ToolNode"
        with contextlib.suppress(Exception):
            event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]

        if name in self.mcp_tools:
            event.metadata["function_type"] = "mcp"
            publish_event(event)
            message = await self._mcp_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = message.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield message
            return

        if name in self.composio_tools:
            event.metadata["function_type"] = "composio"
            publish_event(event)
            message = await self._composio_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = message.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield message
            return

        if name in self.langchain_tools:
            event.metadata["function_type"] = "langchain"
            publish_event(event)
            message = await self._langchain_execute(
                name,
                args,
                tool_call_id,
                config,
                callback_manager,
            )
            event.data["message"] = message.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield message
            return

        if name in self._funcs:
            event.metadata["function_type"] = "internal"
            publish_event(event)

            result = await self._internal_execute(
                name,
                args,
                tool_call_id,
                config,
                state,
                callback_manager,
            )
            event.data["message"] = result.model_dump()
            with contextlib.suppress(Exception):
                event.content_blocks = [
                    ToolResultBlock(call_id=tool_call_id, output=result.model_dump())
                ]
            event.event_type = EventType.END
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            yield result
            return

        error_msg = f"Tool '{name}' not found."
        event.data["error"] = error_msg
        event.event_type = EventType.ERROR
        event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
        publish_event(event)

        yield Message.tool_message(
            content=[
                ErrorBlock(message=error_msg),
                ToolResultBlock(
                    call_id=tool_call_id,
                    output=error_msg,
                    is_error=True,
                    status="failed",
                ),
            ],
        )
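The provider precedence that both invoke() and stream() apply (MCP → Composio → LangChain → Local, with an error Message for unknown names) can be sketched as a plain ordered lookup. This is an illustrative standalone sketch, not PyAgenity API; the function name and return labels are hypothetical:

```python
def route_tool(name, mcp_tools, composio_tools, langchain_tools, local_funcs):
    """Return which provider should handle a tool call, mirroring
    ToolNode's precedence: MCP -> Composio -> LangChain -> Local."""
    if name in mcp_tools:
        return "mcp"
    if name in composio_tools:
        return "composio"
    if name in langchain_tools:
        return "langchain"
    if name in local_funcs:
        return "local"
    return None  # ToolNode builds an error Message for unknown tools


# A name registered with multiple providers resolves to the earliest match.
provider = route_tool("weather_tool", [], [], [], {"weather_tool": lambda loc: "sunny"})
```

Because the checks run in order, a name present in both `mcp_tools` and `_funcs` is always executed via MCP.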
Attributes
composio_tools instance-attribute
composio_tools = []
langchain_tools instance-attribute
langchain_tools = []
mcp_tools instance-attribute
mcp_tools = []
Functions
__init__
__init__(functions, client=None, composio_adapter=None, langchain_adapter=None)

Initialize ToolNode with functions and optional tool adapters.

Parameters:

- functions (Iterable[Callable], required): Iterable of callable functions to register as tools. Each function will be registered with its __name__ as the tool identifier.
- client (Client | None, default None): Optional MCP (Model Context Protocol) client for remote tool access. Requires the 'fastmcp' and 'mcp' packages to be installed.
- composio_adapter (ComposioAdapter | None, default None): Optional Composio adapter for external integrations and third-party API access.
- langchain_adapter (Any | None, default None): Optional LangChain adapter for accessing LangChain tools and integrations.

Raises:

- ImportError: If an MCP client is provided but the required packages are not installed.
- TypeError: If any item in functions is not callable.

Note

When using MCP client functionality, ensure you have installed the required dependencies with: pip install pyagenity[mcp]

Source code in pyagenity/graph/tool_node/base.py
def __init__(
    self,
    functions: t.Iterable[t.Callable],
    client: deps.Client | None = None,  # type: ignore
    composio_adapter: ComposioAdapter | None = None,
    langchain_adapter: t.Any | None = None,
) -> None:
    """Initialize ToolNode with functions and optional tool adapters.

    Args:
        functions: Iterable of callable functions to register as tools. Each function
            will be registered with its `__name__` as the tool identifier.
        client: Optional MCP (Model Context Protocol) client for remote tool access.
            Requires 'fastmcp' and 'mcp' packages to be installed.
        composio_adapter: Optional Composio adapter for external integrations and
            third-party API access.
        langchain_adapter: Optional LangChain adapter for accessing LangChain tools
            and integrations.

    Raises:
        ImportError: If MCP client is provided but required packages are not installed.
        TypeError: If any item in functions is not callable.

    Note:
        When using MCP client functionality, ensure you have installed the required
        dependencies with: `pip install pyagenity[mcp]`
    """
    logger.info("Initializing ToolNode with %d functions", len(list(functions)))

    if client is not None:
        # Read flags dynamically so tests can patch pyagenity.graph.tool_node.HAS_*
        mod = sys.modules.get("pyagenity.graph.tool_node")
        has_fastmcp = getattr(mod, "HAS_FASTMCP", deps.HAS_FASTMCP) if mod else deps.HAS_FASTMCP
        has_mcp = getattr(mod, "HAS_MCP", deps.HAS_MCP) if mod else deps.HAS_MCP

        if not has_fastmcp or not has_mcp:
            raise ImportError(
                "MCP client functionality requires 'fastmcp' and 'mcp' packages. "
                "Install with: pip install pyagenity[mcp]"
            )
        logger.debug("ToolNode initialized with MCP client")

    self._funcs: dict[str, t.Callable] = {}
    self._client: deps.Client | None = client  # type: ignore
    self._composio: ComposioAdapter | None = composio_adapter
    self._langchain: t.Any | None = langchain_adapter

    for fn in functions:
        if not callable(fn):
            raise TypeError("ToolNode only accepts callables")
        self._funcs[fn.__name__] = fn

    self.mcp_tools: list[str] = []
    self.composio_tools: list[str] = []
    self.langchain_tools: list[str] = []
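The validation loop at the end of __init__ can be reproduced on its own: non-callables are rejected eagerly, and each tool is keyed by its `__name__`. A minimal standalone sketch (the helper name and sample tool are illustrative, not part of the API):

```python
import typing as t


def register_tools(functions: t.Iterable[t.Callable]) -> dict[str, t.Callable]:
    # Mirrors the loop in ToolNode.__init__: reject non-callables and
    # key each tool by its function __name__.
    funcs: dict[str, t.Callable] = {}
    for fn in functions:
        if not callable(fn):
            raise TypeError("ToolNode only accepts callables")
        funcs[fn.__name__] = fn
    return funcs


def get_weather(location: str) -> str:
    """Get weather information for a location."""
    return f"Weather in {location}: sunny"


registry = register_tools([get_weather])
```

Note that two functions sharing a `__name__` would silently overwrite each other in the registry, so tool names should be unique.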
all_tools async
all_tools()

Get all available tools from all configured providers.

Retrieves and combines tool definitions from local functions, MCP client, Composio adapter, and LangChain adapter. Each tool definition includes the function schema with parameters and descriptions.

Returns:

- list[dict]: List of tool definitions in OpenAI function calling format. Each dict contains 'type': 'function' and 'function' with name, description, and parameters schema.

Example

```python
tools = await tool_node.all_tools()
# Returns:
# [
#   {
#     "type": "function",
#     "function": {
#       "name": "weather_tool",
#       "description": "Get weather information for a location",
#       "parameters": {
#         "type": "object",
#         "properties": {
#           "location": {"type": "string"}
#         },
#         "required": ["location"]
#       }
#     }
#   }
# ]
```
Source code in pyagenity/graph/tool_node/base.py
async def all_tools(self) -> list[dict]:
    """Get all available tools from all configured providers.

    Retrieves and combines tool definitions from local functions, MCP client,
    Composio adapter, and LangChain adapter. Each tool definition includes
    the function schema with parameters and descriptions.

    Returns:
        List of tool definitions in OpenAI function calling format. Each dict
        contains 'type': 'function' and 'function' with name, description,
        and parameters schema.

    Example:
        ```python
        tools = await tool_node.all_tools()
        # Returns:
        # [
        #   {
        #     "type": "function",
        #     "function": {
        #       "name": "weather_tool",
        #       "description": "Get weather information for a location",
        #       "parameters": {
        #         "type": "object",
        #         "properties": {
        #           "location": {"type": "string"}
        #         },
        #         "required": ["location"]
        #       }
        #     }
        #   }
        # ]
        ```
    """
    return await self._all_tools_async()
all_tools_sync
all_tools_sync()

Synchronously get all available tools from all configured providers.

This is a synchronous wrapper around the async all_tools() method. It uses asyncio.run() to handle async operations from MCP, Composio, and LangChain adapters.

Returns:

- list[dict]: List of tool definitions in OpenAI function calling format.

Note

Prefer using the async all_tools() method when possible, especially in async contexts, to avoid potential event loop issues.

Source code in pyagenity/graph/tool_node/base.py
def all_tools_sync(self) -> list[dict]:
    """Synchronously get all available tools from all configured providers.

    This is a synchronous wrapper around the async all_tools() method.
    It uses asyncio.run() to handle async operations from MCP, Composio,
    and LangChain adapters.

    Returns:
        List of tool definitions in OpenAI function calling format.

    Note:
        Prefer using the async `all_tools()` method when possible, especially
        in async contexts, to avoid potential event loop issues.
    """
    tools: list[dict] = self.get_local_tool()
    if self._client:
        result = asyncio.run(self._get_mcp_tool())
        if result:
            tools.extend(result)
    comp = asyncio.run(self._get_composio_tools())
    if comp:
        tools.extend(comp)
    lc = asyncio.run(self._get_langchain_tools())
    if lc:
        tools.extend(lc)
    return tools
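The sync wrapper's reliance on asyncio.run(), and why the async all_tools() is preferred in async contexts, can be shown with a standalone sketch (the discovery coroutine and tool names below are stand-ins, not PyAgenity APIs):

```python
import asyncio


async def discover_remote_tools() -> list[dict]:
    # Stand-in for an adapter's async discovery call (MCP/Composio/LangChain).
    return [{"type": "function", "function": {"name": "remote_echo"}}]


def all_tools_sync_sketch(local_tools: list[dict]) -> list[dict]:
    tools = list(local_tools)
    # asyncio.run() spins up a fresh event loop per call; it raises
    # RuntimeError if invoked from inside an already-running loop, which
    # is the "event loop issue" the Note above warns about.
    remote = asyncio.run(discover_remote_tools())
    if remote:
        tools.extend(remote)
    return tools


combined = all_tools_sync_sketch(
    [{"type": "function", "function": {"name": "local_add"}}]
)
```

Local tools come first in the combined list, matching the order used by all_tools_sync().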
get_local_tool
get_local_tool()

Generate OpenAI-compatible tool definitions for all registered local functions.

Inspects all registered functions in _funcs and automatically generates tool schemas by analyzing function signatures, type annotations, and docstrings. Excludes injectable parameters that are provided by the framework.

Returns:

- list[dict]: List of tool definitions in OpenAI function calling format. Each definition includes the function name, description (from docstring), and complete parameter schema with types and required fields.

Example

For a function:

```python
def calculate(a: int, b: int, operation: str = "add") -> int:
    '''Perform arithmetic calculation.'''
    return a + b if operation == "add" else a - b
```

Returns:

```python
[
    {
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "Perform arithmetic calculation.",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "integer"},
                    "b": {"type": "integer"},
                    "operation": {"type": "string", "default": "add"},
                },
                "required": ["a", "b"],
            },
        },
    }
]
```

Note

Parameters listed in INJECTABLE_PARAMS (like 'state', 'config', 'tool_call_id') are automatically excluded from the generated schema as they are provided by the framework during execution.

Source code in pyagenity/graph/tool_node/schema.py
def get_local_tool(self) -> list[dict]:
    """Generate OpenAI-compatible tool definitions for all registered local functions.

    Inspects all registered functions in _funcs and automatically generates
    tool schemas by analyzing function signatures, type annotations, and docstrings.
    Excludes injectable parameters that are provided by the framework.

    Returns:
        List of tool definitions in OpenAI function calling format. Each
        definition includes the function name, description (from docstring),
        and complete parameter schema with types and required fields.

    Example:
        For a function:
        ```python
        def calculate(a: int, b: int, operation: str = "add") -> int:
            '''Perform arithmetic calculation.'''
            return a + b if operation == "add" else a - b
        ```

        Returns:
        ```python
        [
            {
                "type": "function",
                "function": {
                    "name": "calculate",
                    "description": "Perform arithmetic calculation.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {"type": "integer"},
                            "b": {"type": "integer"},
                            "operation": {"type": "string", "default": "add"},
                        },
                        "required": ["a", "b"],
                    },
                },
            }
        ]
        ```

    Note:
        Parameters listed in INJECTABLE_PARAMS (like 'state', 'config',
        'tool_call_id') are automatically excluded from the generated schema
        as they are provided by the framework during execution.
    """
    tools: list[dict] = []
    for name, fn in self._funcs.items():
        sig = inspect.signature(fn)
        params_schema: dict = {"type": "object", "properties": {}, "required": []}

        for p_name, p in sig.parameters.items():
            if p.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue

            if p_name in INJECTABLE_PARAMS:
                continue

            annotation = p.annotation if p.annotation is not inspect._empty else str
            prop = SchemaMixin._annotation_to_schema(annotation, p.default)
            params_schema["properties"][p_name] = prop

            if p.default is inspect._empty:
                params_schema["required"].append(p_name)

        if not params_schema["required"]:
            params_schema.pop("required")

        description = inspect.getdoc(fn) or "No description provided."

        # provider = getattr(fn, "_py_tool_provider", None)
        # tags = getattr(fn, "_py_tool_tags", None)
        # capabilities = getattr(fn, "_py_tool_capabilities", None)

        entry = {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": params_schema,
            },
        }
        # meta: dict[str, t.Any] = {}
        # if provider:
        #     meta["provider"] = provider
        # if tags:
        #     meta["tags"] = tags
        # if capabilities:
        #     meta["capabilities"] = capabilities
        # if meta:
        #     entry["x-pyagenity"] = meta

        tools.append(entry)

    return tools
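The core of the schema generation above can be condensed into a single-function sketch: walk the signature, skip injectables, and map annotations to JSON-schema types. The type map and the injectable set below are simplified assumptions, not the framework's actual `INJECTABLE_PARAMS` or `_annotation_to_schema`:

```python
import inspect

INJECTABLE_PARAMS = {"state", "config", "tool_call_id"}  # assumed subset
PY_TO_JSON = {int: "integer", float: "number", bool: "boolean", str: "string"}


def local_tool_schema(fn) -> dict:
    # Simplified get_local_tool() for one function: build a JSON-schema
    # parameters object from the signature, excluding injected params.
    props: dict = {}
    required: list[str] = []
    for p_name, p in inspect.signature(fn).parameters.items():
        if p_name in INJECTABLE_PARAMS:
            continue
        ann = p.annotation if p.annotation is not inspect.Parameter.empty else str
        prop = {"type": PY_TO_JSON.get(ann, "string")}
        if p.default is inspect.Parameter.empty:
            required.append(p_name)
        else:
            prop["default"] = p.default
        props[p_name] = prop
    params = {"type": "object", "properties": props}
    if required:
        params["required"] = required
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": inspect.getdoc(fn) or "No description provided.",
            "parameters": params,
        },
    }


def calculate(a: int, b: int, operation: str = "add", state=None) -> int:
    """Perform arithmetic calculation."""
    return a + b if operation == "add" else a - b


schema = local_tool_schema(calculate)
```

Note that `state` is dropped from the generated schema even though it appears in the signature, matching how the framework hides injected parameters from the model.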
invoke async
invoke(name, args, tool_call_id, config, state, callback_manager=Inject[CallbackManager])

Execute a specific tool by name with the provided arguments.

This method handles tool execution across all configured providers (local, MCP, Composio, LangChain) with comprehensive error handling, event publishing, and callback management.

Parameters:

- name (str, required): The name of the tool to execute.
- args (dict, required): Dictionary of arguments to pass to the tool function.
- tool_call_id (str, required): Unique identifier for this tool execution, used for tracking and result correlation.
- config (dict[str, Any], required): Configuration dictionary containing execution context and user-specific settings.
- state (AgentState, required): Current agent state for context-aware tool execution.
- callback_manager (CallbackManager, default Inject[CallbackManager]): Manager for executing pre/post execution callbacks. Injected via dependency injection if not provided.

Returns:

- Any: Message object containing tool execution results, either successful output or error information with appropriate status indicators.

Example

```python
result = await tool_node.invoke(
    name="weather_tool",
    args={"location": "Paris", "units": "metric"},
    tool_call_id="call_abc123",
    config={"user_id": "user1", "session_id": "session1"},
    state=current_agent_state,
)

# result is a Message with tool execution results
print(result.content)  # Tool output or error information
```
Note

The method publishes execution events throughout the process for monitoring and debugging purposes. Tool execution is routed based on tool provider precedence: MCP → Composio → LangChain → Local.

Source code in pyagenity/graph/tool_node/base.py
async def invoke(  # noqa: PLR0915
    self,
    name: str,
    args: dict,
    tool_call_id: str,
    config: dict[str, t.Any],
    state: AgentState,
    callback_manager: CallbackManager = Inject[CallbackManager],
) -> t.Any:
    """Execute a specific tool by name with the provided arguments.

    This method handles tool execution across all configured providers (local,
    MCP, Composio, LangChain) with comprehensive error handling, event publishing,
    and callback management.

    Args:
        name: The name of the tool to execute.
        args: Dictionary of arguments to pass to the tool function.
        tool_call_id: Unique identifier for this tool execution, used for
            tracking and result correlation.
        config: Configuration dictionary containing execution context and
            user-specific settings.
        state: Current agent state for context-aware tool execution.
        callback_manager: Manager for executing pre/post execution callbacks.
            Injected via dependency injection if not provided.

    Returns:
        Message object containing tool execution results, either successful
        output or error information with appropriate status indicators.

    Raises:
        This method does not raise; all exceptions are handled internally
        and returned as error Messages, ensuring robust execution flow.

    Example:
        ```python
        result = await tool_node.invoke(
            name="weather_tool",
            args={"location": "Paris", "units": "metric"},
            tool_call_id="call_abc123",
            config={"user_id": "user1", "session_id": "session1"},
            state=current_agent_state,
        )

        # result is a Message with tool execution results
        print(result.content)  # Tool output or error information
        ```

    Note:
        The method publishes execution events throughout the process for
        monitoring and debugging purposes. Tool execution is routed based
        on tool provider precedence: MCP → Composio → LangChain → Local.
    """
    logger.info("Executing tool '%s' with %d arguments", name, len(args))
    logger.debug("Tool arguments: %s", args)

    event = EventModel.default(
        config,
        data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
        content_type=[ContentType.TOOL_CALL],
        event=Event.TOOL_EXECUTION,
    )
    event.node_name = name
    # Attach structured tool call block
    with contextlib.suppress(Exception):
        event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]
    publish_event(event)

    if name in self.mcp_tools:
        event.metadata["is_mcp"] = True
        publish_event(event)
        res = await self._mcp_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        # Attach tool result block mirroring the tool output
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    if name in self.composio_tools:
        event.metadata["is_composio"] = True
        publish_event(event)
        res = await self._composio_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    if name in self.langchain_tools:
        event.metadata["is_langchain"] = True
        publish_event(event)
        res = await self._langchain_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    if name in self._funcs:
        event.metadata["is_mcp"] = False
        publish_event(event)
        res = await self._internal_execute(
            name,
            args,
            tool_call_id,
            config,
            state,
            callback_manager,
        )
        event.data["message"] = res.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=res.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        return res

    error_msg = f"Tool '{name}' not found."
    event.data["error"] = error_msg
    event.event_type = EventType.ERROR
    event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
    publish_event(event)
    return Message.tool_message(
        content=[
            ErrorBlock(message=error_msg),
            ToolResultBlock(
                call_id=tool_call_id,
                output=error_msg,
                is_error=True,
                status="failed",
            ),
        ],
    )
stream async
stream(name, args, tool_call_id, config, state, callback_manager=Inject[CallbackManager])

Execute a tool with streaming support, yielding incremental results.

Similar to invoke() but designed for tools that can provide streaming responses or when you want to process results as they become available. Currently, most tool providers return complete results, so this method typically yields a single Message with the full result.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | `str` | The name of the tool to execute. | required |
| `args` | `dict` | Dictionary of arguments to pass to the tool function. | required |
| `tool_call_id` | `str` | Unique identifier for this tool execution. | required |
| `config` | `dict[str, Any]` | Configuration dictionary containing execution context. | required |
| `state` | `AgentState` | Current agent state for context-aware tool execution. | required |
| `callback_manager` | `CallbackManager` | Manager for executing pre/post execution callbacks. | `Inject[CallbackManager]` |

Yields:

| Type | Description |
| --- | --- |
| `AsyncIterator[Message]` | Message objects containing tool execution results or status updates. For most tools, this will yield a single complete result Message. |

Example

```python
async for message in tool_node.stream(
    name="data_processor",
    args={"dataset": "large_data.csv"},
    tool_call_id="call_stream123",
    config={"user_id": "user1"},
    state=current_state,
):
    print(f"Received: {message.content}")
    # Process each streamed result
```
Note

The streaming interface is designed for future expansion where tools may provide true streaming responses. Currently, it provides a consistent async iterator interface over tool results.

Source code in pyagenity/graph/tool_node/base.py
async def stream(  # noqa: PLR0915
    self,
    name: str,
    args: dict,
    tool_call_id: str,
    config: dict[str, t.Any],
    state: AgentState,
    callback_manager: CallbackManager = Inject[CallbackManager],
) -> t.AsyncIterator[Message]:
    """Execute a tool with streaming support, yielding incremental results.

    Similar to invoke() but designed for tools that can provide streaming responses
    or when you want to process results as they become available. Currently,
    most tool providers return complete results, so this method typically yields
    a single Message with the full result.

    Args:
        name: The name of the tool to execute.
        args: Dictionary of arguments to pass to the tool function.
        tool_call_id: Unique identifier for this tool execution.
        config: Configuration dictionary containing execution context.
        state: Current agent state for context-aware tool execution.
        callback_manager: Manager for executing pre/post execution callbacks.

    Yields:
        Message objects containing tool execution results or status updates.
        For most tools, this will yield a single complete result Message.

    Example:
        ```python
        async for message in tool_node.stream(
            name="data_processor",
            args={"dataset": "large_data.csv"},
            tool_call_id="call_stream123",
            config={"user_id": "user1"},
            state=current_state,
        ):
            print(f"Received: {message.content}")
            # Process each streamed result
        ```

    Note:
        The streaming interface is designed for future expansion where tools
        may provide true streaming responses. Currently, it provides a
        consistent async iterator interface over tool results.
    """
    logger.info("Executing tool '%s' with %d arguments", name, len(args))
    logger.debug("Tool arguments: %s", args)
    event = EventModel.default(
        config,
        data={"args": args, "tool_call_id": tool_call_id, "function_name": name},
        content_type=[ContentType.TOOL_CALL],
        event=Event.TOOL_EXECUTION,
    )
    event.node_name = "ToolNode"
    with contextlib.suppress(Exception):
        event.content_blocks = [ToolCallBlock(id=tool_call_id, name=name, args=args)]

    if name in self.mcp_tools:
        event.metadata["function_type"] = "mcp"
        publish_event(event)
        message = await self._mcp_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = message.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield message
        return

    if name in self.composio_tools:
        event.metadata["function_type"] = "composio"
        publish_event(event)
        message = await self._composio_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = message.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield message
        return

    if name in self.langchain_tools:
        event.metadata["function_type"] = "langchain"
        publish_event(event)
        message = await self._langchain_execute(
            name,
            args,
            tool_call_id,
            config,
            callback_manager,
        )
        event.data["message"] = message.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=message.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield message
        return

    if name in self._funcs:
        event.metadata["function_type"] = "internal"
        publish_event(event)

        result = await self._internal_execute(
            name,
            args,
            tool_call_id,
            config,
            state,
            callback_manager,
        )
        event.data["message"] = result.model_dump()
        with contextlib.suppress(Exception):
            event.content_blocks = [
                ToolResultBlock(call_id=tool_call_id, output=result.model_dump())
            ]
        event.event_type = EventType.END
        event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
        publish_event(event)
        yield result
        return

    error_msg = f"Tool '{name}' not found."
    event.data["error"] = error_msg
    event.event_type = EventType.ERROR
    event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
    publish_event(event)

    yield Message.tool_message(
        content=[
            ErrorBlock(message=error_msg),
            ToolResultBlock(
                call_id=tool_call_id,
                output=error_msg,
                is_error=True,
                status="failed",
            ),
        ],
    )
Functions

Modules
constants

Constants for ToolNode package.

This module defines constants used throughout the ToolNode implementation, particularly parameter names that are automatically injected by the PyAgenity framework during tool execution. These parameters are excluded from tool schema generation since they are provided by the execution context.

The constants are separated into their own module to avoid circular imports and maintain a clean public API.

Parameter names that are automatically injected during tool execution.

These parameters are provided by the PyAgenity framework and should be excluded from tool schema generation. They represent execution context and framework services that are available to tool functions but not provided by the user.

Parameters:

| Name | Description |
| --- | --- |
| `tool_call_id` | Unique identifier for the current tool execution. |
| `state` | Current AgentState instance for context-aware execution. |
| `config` | Configuration dictionary with execution settings. |
| `generated_id` | Framework-generated identifier for various purposes. |
| `context_manager` | BaseContextManager instance for cross-node operations. |
| `publisher` | BasePublisher instance for event publishing. |
| `checkpointer` | BaseCheckpointer instance for state persistence. |
| `store` | BaseStore instance for data storage operations. |
Note

Tool functions can declare these parameters in their signatures to receive the corresponding services, but they should not be included in the tool schema since they're not user-provided arguments.

Attributes:

INJECTABLE_PARAMS module-attribute
INJECTABLE_PARAMS = {'tool_call_id', 'state', 'config', 'generated_id', 'context_manager', 'publisher', 'checkpointer', 'store'}
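Since these names are excluded from schema generation, a schema builder only needs to filter a function's signature against the set. A minimal sketch (the `user_facing_params` helper and `weather_tool` are hypothetical; only `INJECTABLE_PARAMS` comes from the module above):

```python
import inspect

# Mirror of the framework's injectable parameter names; the real constant
# lives in pyagenity.graph.tool_node.constants.
INJECTABLE_PARAMS = {
    "tool_call_id", "state", "config", "generated_id",
    "context_manager", "publisher", "checkpointer", "store",
}


def user_facing_params(func) -> list[str]:
    """Hypothetical helper: parameter names that belong in a tool schema."""
    sig = inspect.signature(func)
    return [name for name in sig.parameters if name not in INJECTABLE_PARAMS]


def weather_tool(location: str, units: str = "metric", tool_call_id=None, state=None):
    """Example tool that declares two injectable parameters."""


# Only user-provided arguments survive the filter:
# user_facing_params(weather_tool) -> ["location", "units"]
```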
deps

Dependency flags and optional imports for ToolNode.

This module manages optional third-party dependencies for the ToolNode implementation, providing clean import handling and feature flags. It isolates optional imports to prevent ImportError cascades when optional dependencies are not installed.

The module handles two main optional dependency groups: 1. MCP (Model Context Protocol) support via 'fastmcp' and 'mcp' packages 2. Future extensibility for other optional tool providers

By centralizing optional imports here, other modules can safely import the flags and types without triggering ImportError exceptions, allowing graceful degradation when optional features are not available.

Typical usage

```python
from .deps import HAS_FASTMCP, HAS_MCP, Client

if HAS_FASTMCP and HAS_MCP:
    # Use MCP functionality
    client = Client(...)
else:
    # Graceful fallback or error message
    client = None
```
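The guarded-import pattern this module centralizes can be sketched as follows (a self-contained approximation, not the module's exact source; package names are taken from the description above):

```python
# Hedged sketch of the optional-dependency pattern described above; the real
# flags live in pyagenity.graph.tool_node.deps.
try:
    from fastmcp import Client  # optional 'fastmcp' package

    HAS_FASTMCP = True
except ImportError:
    Client = None
    HAS_FASTMCP = False

try:
    import mcp  # noqa: F401  # optional 'mcp' package

    HAS_MCP = True
except ImportError:
    HAS_MCP = False
```

Importing the flags (rather than the packages) lets downstream modules branch on availability without risking an `ImportError` cascade.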

FastMCP integration support:

HAS_FASTMCP: Boolean flag indicating whether FastMCP is available; True if the 'fastmcp' package is installed and imports successfully.
Client: FastMCP Client class for connecting to MCP servers; None if FastMCP is not available.
CallToolResult: Result type for MCP tool executions; None if FastMCP is not available.

Attributes

HAS_FASTMCP module-attribute
HAS_FASTMCP = True
HAS_MCP module-attribute
HAS_MCP = True
__all__ module-attribute
__all__ = ['HAS_FASTMCP', 'HAS_MCP', 'CallToolResult', 'Client', 'ContentBlock', 'Tool']
executors

Executors for different tool providers and local functions.

Classes:

Name Description
ComposioMixin
KwargsResolverMixin
LangChainMixin
LocalExecMixin
MCPMixin

Attributes:

logger module-attribute
logger = getLogger(__name__)
Classes
ComposioMixin

Attributes:

| Name | Type |
| --- | --- |
| `composio_tools` | `list[str]` |
Source code in pyagenity/graph/tool_node/executors.py
class ComposioMixin:
    _composio: ComposioAdapter | None
    composio_tools: list[str]

    async def _get_composio_tools(self) -> list[dict]:
        tools: list[dict] = []
        if not self._composio:
            return tools
        try:
            raw = self._composio.list_raw_tools_for_llm()
            for tdef in raw:
                fn = tdef.get("function", {})
                name = fn.get("name")
                if name:
                    self.composio_tools.append(name)
                tools.append(tdef)
        except Exception as e:  # pragma: no cover - network/optional
            logger.exception("Failed to fetch Composio tools: %s", e)
        return tools

    async def _composio_execute(  # noqa: PLR0915
        self,
        name: str,
        args: dict,
        tool_call_id: str,
        config: dict[str, t.Any],
        callback_mgr: CallbackManager,
    ) -> Message:
        context = CallbackContext(
            invocation_type=InvocationType.TOOL,
            node_name="ToolNode",
            function_name=name,
            metadata={
                "tool_call_id": tool_call_id,
                "args": args,
                "config": config,
                "composio": True,
            },
        )
        meta = {"function_name": name, "function_argument": args, "tool_call_id": tool_call_id}

        event = EventModel.default(
            base_config=config,
            data={
                "tool_call_id": tool_call_id,
                "args": args,
                "function_name": name,
                "is_composio": True,
            },
            content_type=[ContentType.TOOL_CALL],
            event=Event.TOOL_EXECUTION,
        )
        event.event_type = EventType.PROGRESS
        event.node_name = "ToolNode"
        event.sequence_id = 1
        publish_event(event)

        input_data = {**args}

        def safe_serialize(obj: t.Any) -> dict[str, t.Any]:
            try:
                json.dumps(obj)
                return obj if isinstance(obj, dict) else {"content": obj}
            except (TypeError, OverflowError):
                if hasattr(obj, "model_dump"):
                    dumped = obj.model_dump()  # type: ignore
                    if isinstance(dumped, dict) and dumped.get("type") == "resource":
                        resource = dumped.get("resource", {})
                        if isinstance(resource, dict) and "uri" in resource:
                            resource["uri"] = str(resource["uri"])
                            dumped["resource"] = resource
                    return dumped
                return {"content": str(obj), "type": "fallback"}

        try:
            input_data = await callback_mgr.execute_before_invoke(context, input_data)
            event.event_type = EventType.UPDATE
            event.sequence_id = 2
            event.metadata["status"] = "before_invoke_complete Invoke Composio"
            publish_event(event)

            comp_conf = (config.get("composio") if isinstance(config, dict) else None) or {}
            user_id = comp_conf.get("user_id") or config.get("user_id")
            connected_account_id = comp_conf.get("connected_account_id") or config.get(
                "connected_account_id"
            )

            if not self._composio:
                error_result = Message.tool_message(
                    content=[
                        ErrorBlock(message="Composio adapter not configured"),
                        ToolResultBlock(
                            call_id=tool_call_id,
                            output="Composio adapter not configured",
                            status="failed",
                            is_error=True,
                        ),
                    ],
                    meta=meta,
                )
                event.event_type = EventType.ERROR
                event.metadata["error"] = "Composio adapter not configured"
                publish_event(event)
                return error_result

            res = self._composio.execute(
                slug=name,
                arguments=input_data,
                user_id=user_id,
                connected_account_id=connected_account_id,
            )

            successful = bool(res.get("successful"))
            payload = res.get("data")
            error = res.get("error")

            result_blocks = []
            if error and not successful:
                result_blocks.append(
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output={"success": False, "error": error},
                        status="failed",
                        is_error=True,
                    )
                )
                result_blocks.append(ErrorBlock(message=error))
            else:
                if isinstance(payload, list):
                    output = [safe_serialize(item) for item in payload]
                else:
                    output = [safe_serialize(payload)]
                result_blocks.append(
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=output,
                        status="completed" if successful else "failed",
                        is_error=not successful,
                    )
                )

            result = Message.tool_message(
                content=result_blocks,
                meta=meta,
            )

            res_msg = await callback_mgr.execute_after_invoke(context, input_data, result)
            event.event_type = EventType.END
            event.data["message"] = result.model_dump()
            event.metadata["status"] = "Composio tool execution complete"
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res_msg

        except Exception as e:  # pragma: no cover - error path
            recovery_result = await callback_mgr.execute_on_error(context, input_data, e)
            if isinstance(recovery_result, Message):
                event.event_type = EventType.END
                event.data["message"] = recovery_result.model_dump()
                event.metadata["status"] = "Composio tool execution complete, with recovery"
                event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
                publish_event(event)
                return recovery_result

            event.event_type = EventType.END
            event.data["error"] = str(e)
            event.metadata["status"] = "Composio tool execution complete, with error"
            event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
            publish_event(event)
            return Message.tool_message(
                content=[
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=f"Composio execution error: {e}",
                        status="failed",
                        is_error=True,
                    ),
                    ErrorBlock(message=f"Composio execution error: {e}"),
                ],
                meta=meta,
            )
Attributes
composio_tools instance-attribute
composio_tools
KwargsResolverMixin
Source code in pyagenity/graph/tool_node/executors.py
class KwargsResolverMixin:
    def _should_skip_parameter(self, param: inspect.Parameter) -> bool:
        return param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )

    def _handle_injectable_parameter(
        self,
        p_name: str,
        param: inspect.Parameter,
        injectable_params: dict,
        dependency_container,
    ) -> t.Any | None:
        if p_name in injectable_params:
            injectable_value = injectable_params[p_name]
            if injectable_value is not None:
                return injectable_value

        if dependency_container and dependency_container.has(p_name):
            return dependency_container.get(p_name)

        if param.default is inspect._empty:
            raise TypeError(f"Required injectable parameter '{p_name}' not found")

        return None

    def _get_parameter_value(
        self,
        p_name: str,
        param: inspect.Parameter,
        args: dict,
        injectable_params: dict,
        dependency_container,
    ) -> t.Any | None:
        if p_name in injectable_params:
            return self._handle_injectable_parameter(
                p_name, param, injectable_params, dependency_container
            )

        value_sources = [
            lambda: args.get(p_name),
            lambda: (
                dependency_container.get(p_name)
                if dependency_container and dependency_container.has(p_name)
                else None
            ),
        ]

        for source in value_sources:
            value = source()
            if value is not None:
                return value

        if param.default is not inspect._empty:
            return None

        raise TypeError(f"Missing required parameter '{p_name}' for function")

    def _prepare_kwargs(
        self,
        sig: inspect.Signature,
        args: dict,
        injectable_params: dict,
        dependency_container,
    ) -> dict:
        kwargs: dict = {}
        for p_name, p in sig.parameters.items():
            if self._should_skip_parameter(p):
                continue
            value = self._get_parameter_value(
                p_name, p, args, injectable_params, dependency_container
            )
            if value is not None:
                kwargs[p_name] = value
        return kwargs
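The mixin resolves each parameter in a fixed precedence: an injectable value, then the caller's `args`, then the dependency container, then the function's own default. A self-contained sketch of that order (the `resolve` function and `DummyContainer` are illustrative stand-ins, not the framework's classes):

```python
import inspect


class DummyContainer:
    """Hypothetical stand-in for the dependency container interface."""

    def __init__(self, values: dict):
        self._values = values

    def has(self, name: str) -> bool:
        return name in self._values

    def get(self, name: str):
        return self._values[name]


def resolve(p_name, param, args, injectable_params, container):
    # Injectables first: explicit value, then container, then default.
    if p_name in injectable_params:
        value = injectable_params[p_name]
        if value is not None:
            return value
        if container and container.has(p_name):
            return container.get(p_name)
        if param.default is inspect.Parameter.empty:
            raise TypeError(f"Required injectable parameter '{p_name}' not found")
        return None
    # Regular parameters: caller args win over the container.
    if args.get(p_name) is not None:
        return args[p_name]
    if container and container.has(p_name):
        return container.get(p_name)
    if param.default is not inspect.Parameter.empty:
        return None  # fall through to the function's default
    raise TypeError(f"Missing required parameter '{p_name}'")
```

Returning `None` for defaulted parameters lets `_prepare_kwargs` omit them entirely, so the function's own default applies.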
LangChainMixin

Attributes:

| Name | Type |
| --- | --- |
| `langchain_tools` | `list[str]` |
Source code in pyagenity/graph/tool_node/executors.py
class LangChainMixin:
    _langchain: t.Any | None
    langchain_tools: list[str]

    async def _get_langchain_tools(self) -> list[dict]:
        tools: list[dict] = []
        if not self._langchain:
            return tools
        try:
            raw = self._langchain.list_tools_for_llm()
            for tdef in raw:
                fn = tdef.get("function", {})
                name = fn.get("name")
                if name:
                    self.langchain_tools.append(name)
                tools.append(tdef)
        except Exception as e:  # pragma: no cover - optional
            logger.warning("Failed to fetch LangChain tools: %s", e)
        return tools

    async def _langchain_execute(  # noqa: PLR0915
        self,
        name: str,
        args: dict,
        tool_call_id: str,
        config: dict[str, t.Any],
        callback_mgr: CallbackManager,
    ) -> Message:
        context = CallbackContext(
            invocation_type=InvocationType.TOOL,
            node_name="ToolNode",
            function_name=name,
            metadata={
                "tool_call_id": tool_call_id,
                "args": args,
                "config": config,
                "langchain": True,
            },
        )
        meta = {"function_name": name, "function_argument": args, "tool_call_id": tool_call_id}

        event = EventModel.default(
            base_config=config,
            data={
                "tool_call_id": tool_call_id,
                "args": args,
                "function_name": name,
                "is_langchain": True,
            },
            content_type=[ContentType.TOOL_CALL],
            event=Event.TOOL_EXECUTION,
        )
        event.event_type = EventType.PROGRESS
        event.node_name = "ToolNode"
        event.sequence_id = 1
        publish_event(event)

        input_data = {**args}

        def safe_serialize(obj: t.Any) -> dict[str, t.Any]:
            try:
                json.dumps(obj)
                return obj if isinstance(obj, dict) else {"content": obj}
            except (TypeError, OverflowError):
                if hasattr(obj, "model_dump"):
                    dumped = obj.model_dump()  # type: ignore
                    if isinstance(dumped, dict) and dumped.get("type") == "resource":
                        resource = dumped.get("resource", {})
                        if isinstance(resource, dict) and "uri" in resource:
                            resource["uri"] = str(resource["uri"])
                            dumped["resource"] = resource
                    return dumped
                return {"content": str(obj), "type": "fallback"}

        try:
            input_data = await callback_mgr.execute_before_invoke(context, input_data)
            event.event_type = EventType.UPDATE
            event.sequence_id = 2
            event.metadata["status"] = "before_invoke_complete Invoke LangChain"
            publish_event(event)

            if not self._langchain:
                error_result = Message.tool_message(
                    content=[
                        ErrorBlock(message="LangChain adapter not configured"),
                        ToolResultBlock(
                            call_id=tool_call_id,
                            output="LangChain adapter not configured",
                            status="failed",
                            is_error=True,
                        ),
                    ],
                    meta=meta,
                )
                event.event_type = EventType.ERROR
                event.metadata["error"] = "LangChain adapter not configured"
                publish_event(event)
                return error_result

            res = self._langchain.execute(name=name, arguments=input_data)
            successful = bool(res.get("successful"))
            payload = res.get("data")
            error = res.get("error")

            result_blocks = []
            if error and not successful:
                result_blocks.append(
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output={"success": False, "error": error},
                        status="failed",
                        is_error=True,
                    )
                )
                result_blocks.append(ErrorBlock(message=error))
            else:
                if isinstance(payload, list):
                    output = [safe_serialize(item) for item in payload]
                else:
                    output = [safe_serialize(payload)]
                result_blocks.append(
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=output,
                        status="completed" if successful else "failed",
                        is_error=not successful,
                    )
                )

            result = Message.tool_message(
                content=result_blocks,
                meta=meta,
            )

            res_msg = await callback_mgr.execute_after_invoke(context, input_data, result)
            event.event_type = EventType.END
            event.data["message"] = result.model_dump()
            event.metadata["status"] = "LangChain tool execution complete"
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)
            return res_msg

        except Exception as e:  # pragma: no cover - error path
            recovery_result = await callback_mgr.execute_on_error(context, input_data, e)
            if isinstance(recovery_result, Message):
                event.event_type = EventType.END
                event.data["message"] = recovery_result.model_dump()
                event.metadata["status"] = "LangChain tool execution complete, with recovery"
                event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
                publish_event(event)
                return recovery_result

            event.event_type = EventType.END
            event.data["error"] = str(e)
            event.metadata["status"] = "LangChain tool execution complete, with error"
            event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
            publish_event(event)
            return Message.tool_message(
                content=[
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=f"LangChain execution error: {e}",
                        status="failed",
                        is_error=True,
                    ),
                    ErrorBlock(message=f"LangChain execution error: {e}"),
                ],
                meta=meta,
            )
Attributes

langchain_tools instance-attribute
LocalExecMixin
Source code in pyagenity/graph/tool_node/executors.py
class LocalExecMixin:
    _funcs: dict[str, t.Callable]

    def _prepare_input_data_tool(
        self,
        fn: t.Callable,
        name: str,
        args: dict,
        default_data: dict,
    ) -> dict:
        sig = inspect.signature(fn)
        input_data = {}
        for param_name, param in sig.parameters.items():
            if param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue

            if param_name in ["state", "config", "tool_call_id"]:
                input_data[param_name] = default_data[param_name]
                continue

            if param_name in INJECTABLE_PARAMS:
                continue

            if (
                hasattr(param, "default")
                and param.default is not inspect._empty
                and hasattr(param.default, "__class__")
            ):
                try:
                    if "Inject" in str(type(param.default)):
                        logger.debug(
                            "Skipping injectable parameter '%s' with Inject syntax",
                            param_name,
                        )
                        continue
                except Exception as exc:  # pragma: no cover - defensive
                    logger.exception("Inject detection failed for '%s': %s", param_name, exc)

            if param_name in args:
                input_data[param_name] = args[param_name]
            elif param.default is inspect.Parameter.empty:
                raise TypeError(f"Missing required parameter '{param_name}' for function '{name}'")

        return input_data

    async def _internal_execute(  # noqa: PLR0915
        self,
        name: str,
        args: dict,
        tool_call_id: str,
        config: dict[str, t.Any],
        state: AgentState,
        callback_mgr: CallbackManager,
    ) -> Message:
        context = CallbackContext(
            invocation_type=InvocationType.TOOL,
            node_name="ToolNode",
            function_name=name,
            metadata={"tool_call_id": tool_call_id, "args": args, "config": config},
        )

        fn = self._funcs[name]
        input_data = self._prepare_input_data_tool(
            fn,
            name,
            args,
            {
                "tool_call_id": tool_call_id,
                "state": state,
                "config": config,
            },
        )

        meta = {
            "function_name": name,
            "function_argument": args,
            "tool_call_id": tool_call_id,
        }

        event = EventModel.default(
            base_config=config,
            data={
                "tool_call_id": tool_call_id,
                "args": args,
                "function_name": name,
                "is_mcp": False,
            },
            content_type=[ContentType.TOOL_CALL],
            event=Event.TOOL_EXECUTION,
        )
        event.event_type = EventType.PROGRESS
        event.node_name = "ToolNode"
        event.sequence_id = 1
        publish_event(event)

        def safe_serialize(obj: t.Any) -> dict[str, t.Any]:
            try:
                json.dumps(obj)
                return obj if isinstance(obj, dict) else {"content": obj}
            except (TypeError, OverflowError):
                if hasattr(obj, "model_dump"):
                    dumped = obj.model_dump()  # type: ignore
                    if isinstance(dumped, dict) and dumped.get("type") == "resource":
                        resource = dumped.get("resource", {})
                        if isinstance(resource, dict) and "uri" in resource:
                            resource["uri"] = str(resource["uri"])
                            dumped["resource"] = resource
                    return dumped
                return {"content": str(obj), "type": "fallback"}

        try:
            input_data = await callback_mgr.execute_before_invoke(context, input_data)

            event.event_type = EventType.UPDATE
            event.sequence_id = 2
            event.metadata["status"] = "before_invoke_complete Invoke internal"
            publish_event(event)

            result = await call_sync_or_async(fn, **input_data)

            result = await callback_mgr.execute_after_invoke(
                context,
                input_data,
                result,
            )

            if isinstance(result, Message):
                meta_data = result.metadata or {}
                meta.update(meta_data)
                result.metadata = meta

                event.event_type = EventType.END
                event.data["message"] = result.model_dump()
                event.metadata["status"] = "Internal tool execution complete"
                event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
                publish_event(event)
                return result

            result_blocks = []
            if isinstance(result, str):
                result_blocks.append(
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=result,
                        status="completed",
                        is_error=False,
                    )
                )
            elif isinstance(result, dict):
                result_blocks.append(
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=[safe_serialize(result)],
                        status="completed",
                        is_error=False,
                    )
                )
            elif hasattr(result, "model_dump"):
                result_blocks.append(
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=[safe_serialize(result.model_dump())],
                        status="completed",
                        is_error=False,
                    )
                )
            elif hasattr(result, "__dict__"):
                result_blocks.append(
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=[safe_serialize(result.__dict__)],
                        status="completed",
                        is_error=False,
                    )
                )
            elif isinstance(result, list):
                output = [safe_serialize(item) for item in result]
                result_blocks.append(
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=output,
                        status="completed",
                        is_error=False,
                    )
                )
            else:
                result_blocks.append(
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=str(result),
                        status="completed",
                        is_error=False,
                    )
                )

            msg = Message.tool_message(
                content=result_blocks,
                meta=meta,
            )

            event.event_type = EventType.END
            event.data["message"] = msg.model_dump()
            event.metadata["status"] = "Internal tool execution complete"
            event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
            publish_event(event)

            return msg

        except Exception as e:  # pragma: no cover - error path
            recovery_result = await callback_mgr.execute_on_error(context, input_data, e)

            if isinstance(recovery_result, Message):
                event.event_type = EventType.END
                event.data["message"] = recovery_result.model_dump()
                event.metadata["status"] = "Internal tool execution complete, with recovery"
                event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
                publish_event(event)
                return recovery_result

            event.event_type = EventType.END
            event.data["error"] = str(e)
            event.metadata["status"] = "Internal tool execution complete, with error"
            event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
            publish_event(event)

            return Message.tool_message(
                content=[
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=f"Internal execution error: {e}",
                        status="failed",
                        is_error=True,
                    ),
                    ErrorBlock(message=f"Internal execution error: {e}"),
                ],
                meta=meta,
            )
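A local tool executed by LocalExecMixin is an ordinary Python function: parameters named state, config, or tool_call_id are filled in by the framework, while the remaining parameters come from the model's tool-call arguments. The following standalone sketch mirrors the merging logic of _prepare_input_data_tool (simplified: it omits the Inject-default detection and INJECTABLE_PARAMS handling shown above, and get_weather is a hypothetical tool, not part of PyAgenity):

```python
import inspect


def prepare_input(fn, args: dict, injected: dict) -> dict:
    """Merge model-supplied args with framework-injected values (simplified sketch)."""
    data = {}
    for name, param in inspect.signature(fn).parameters.items():
        if name in injected:  # e.g. state / config / tool_call_id
            data[name] = injected[name]
        elif name in args:
            data[name] = args[name]
        elif param.default is inspect.Parameter.empty:
            raise TypeError(f"Missing required parameter '{name}' for '{fn.__name__}'")
    return data


def get_weather(location: str, unit: str = "celsius", tool_call_id: str = "") -> str:
    # Hypothetical example tool; any parameter not supplied keeps its default.
    return f"[{tool_call_id}] weather for {location} in {unit}"


kwargs = prepare_input(
    get_weather,
    args={"location": "Paris"},  # from the model's tool call
    injected={"tool_call_id": "call_123"},  # supplied by the framework
)
print(get_weather(**kwargs))
```

Parameters with defaults that the model does not supply (unit here) are simply left out and fall back to their defaults, matching the mixin's behavior.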
MCPMixin

Attributes:

Name Type Description
mcp_tools list[str]
Source code in pyagenity/graph/tool_node/executors.py
class MCPMixin:
    _client: t.Any | None
    # The concrete ToolNode defines this
    mcp_tools: list[str]  # type: ignore[assignment]

    def _serialize_result(
        self,
        tool_call_id: str,
        res: t.Any,
    ) -> list[ContentBlock]:
        def safe_serialize(obj: t.Any) -> dict[str, t.Any]:
            try:
                json.dumps(obj)
                return obj if isinstance(obj, dict) else {"content": obj}
            except (TypeError, OverflowError):
                if hasattr(obj, "model_dump"):
                    dumped = obj.model_dump()  # type: ignore
                    if isinstance(dumped, dict) and dumped.get("type") == "resource":
                        resource = dumped.get("resource", {})
                        if isinstance(resource, dict) and "uri" in resource:
                            resource["uri"] = str(resource["uri"])
                            dumped["resource"] = resource
                    return dumped
                return {"content": str(obj), "type": "fallback"}

        for source in [
            getattr(res, "content", None),
            getattr(res, "structured_content", None),
            getattr(res, "data", None),
        ]:
            if source is None:
                continue
            try:
                if isinstance(source, list):
                    result = [safe_serialize(item) for item in source]
                else:
                    result = [safe_serialize(source)]

                return [
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=result,
                        is_error=False,
                        status="completed",
                    )
                ]
            except Exception as e:  # pragma: no cover - defensive
                logger.exception("Serialization failure: %s", e)
                continue

        return [
            ToolResultBlock(
                call_id=tool_call_id,
                output=[
                    {
                        "content": str(res),
                        "type": "fallback",
                    }
                ],
                is_error=False,
                status="completed",
            )
        ]

    async def _get_mcp_tool(self) -> list[dict]:
        tools: list[dict] = []
        if self._client:
            async with self._client:
                res = await self._client.ping()
                if not res:
                    return tools
                mcp_tools: list[t.Any] = await self._client.list_tools()
                for i in mcp_tools:
                    # attribute provided by concrete ToolNode
                    self.mcp_tools.append(i.name)  # type: ignore[attr-defined]
                    tools.append(
                        {
                            "type": "function",
                            "function": {
                                "name": i.name,
                                "description": i.description,
                                "parameters": i.inputSchema,
                            },
                        }
                    )
        return tools

    async def _mcp_execute(
        self,
        name: str,
        args: dict,
        tool_call_id: str,
        config: dict[str, t.Any],
        callback_mgr: CallbackManager,
    ) -> Message:
        context = CallbackContext(
            invocation_type=InvocationType.MCP,
            node_name="ToolNode",
            function_name=name,
            metadata={
                "tool_call_id": tool_call_id,
                "args": args,
                "config": config,
                "mcp_client": bool(self._client),
            },
        )

        meta = {
            "function_name": name,
            "function_argument": args,
            "tool_call_id": tool_call_id,
        }

        event = EventModel.default(
            base_config=config,
            data={
                "tool_call_id": tool_call_id,
                "args": args,
                "function_name": name,
                "is_mcp": True,
            },
            content_type=[ContentType.TOOL_CALL],
            event=Event.TOOL_EXECUTION,
        )
        event.event_type = EventType.PROGRESS
        event.node_name = "ToolNode"
        event.sequence_id = 1
        publish_event(event)

        input_data = {**args}

        try:
            input_data = await callback_mgr.execute_before_invoke(context, input_data)
            event.event_type = EventType.UPDATE
            event.sequence_id = 2
            event.metadata["status"] = "before_invoke_complete Invoke MCP"
            publish_event(event)

            if not self._client:
                error_result = Message.tool_message(
                    content=[
                        ErrorBlock(
                            message="No MCP client configured",
                        ),
                        ToolResultBlock(
                            call_id=tool_call_id,
                            output="No MCP client configured",
                            is_error=True,
                            status="failed",
                        ),
                    ],
                    meta=meta,
                )
                res = await callback_mgr.execute_after_invoke(context, input_data, error_result)
                event.event_type = EventType.ERROR
                event.metadata["error"] = "No MCP client configured"
                publish_event(event)
                return res

            async with self._client:
                if not await self._client.ping():
                    error_result = Message.tool_message(
                        content=[
                            ErrorBlock(message="MCP Server not available. Ping failed."),
                            ToolResultBlock(
                                call_id=tool_call_id,
                                output="MCP Server not available. Ping failed.",
                                is_error=True,
                                status="failed",
                            ),
                        ],
                        meta=meta,
                    )
                    event.event_type = EventType.ERROR
                    event.metadata["error"] = "MCP server not available, ping failed"
                    publish_event(event)
                    return await callback_mgr.execute_after_invoke(
                        context, input_data, error_result
                    )

                res: t.Any = await self._client.call_tool(name, input_data)

                final_res = self._serialize_result(tool_call_id, res)

                result = Message.tool_message(
                    content=final_res,
                    meta=meta,
                )

                res = await callback_mgr.execute_after_invoke(context, input_data, result)
                event.event_type = EventType.END
                event.data["message"] = result.model_dump()
                event.metadata["status"] = "MCP tool execution complete"
                event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
                publish_event(event)
                return res

        except Exception as e:  # pragma: no cover - error path
            recovery_result = await callback_mgr.execute_on_error(context, input_data, e)

            if isinstance(recovery_result, Message):
                event.event_type = EventType.END
                event.data["message"] = recovery_result.model_dump()
                event.metadata["status"] = "MCP tool execution complete, with recovery"
                event.content_type = [ContentType.TOOL_RESULT, ContentType.MESSAGE]
                publish_event(event)
                return recovery_result

            event.event_type = EventType.END
            event.data["error"] = str(e)
            event.metadata["status"] = "MCP tool execution complete, with error"
            event.content_type = [ContentType.TOOL_RESULT, ContentType.ERROR]
            publish_event(event)

            return Message.tool_message(
                content=[
                    ToolResultBlock(
                        call_id=tool_call_id,
                        output=f"MCP execution error: {e}",
                        is_error=True,
                        status="failed",
                    ),
                    ErrorBlock(message=f"MCP execution error: {e}"),
                ],
                meta=meta,
            )
Attributes

mcp_tools instance-attribute
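The tool-discovery conversion performed by _get_mcp_tool above can be shown in isolation. The SimpleTool dataclass below is a hypothetical stand-in for the objects an MCP client's list_tools() returns; the conversion into OpenAI function-calling format follows the dict built inside the mixin:

```python
from dataclasses import dataclass, field


@dataclass
class SimpleTool:
    """Hypothetical stand-in for an MCP tool description."""

    name: str
    description: str
    inputSchema: dict = field(default_factory=dict)


def to_openai_tools(mcp_tools: list[SimpleTool]) -> list[dict]:
    # Mirrors the entry built for each discovered tool in _get_mcp_tool.
    return [
        {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.inputSchema,
            },
        }
        for tool in mcp_tools
    ]


tools = to_openai_tools(
    [
        SimpleTool(
            "search",
            "Search the web",
            {"type": "object", "properties": {"q": {"type": "string"}}},
        )
    ]
)
print(tools[0]["function"]["name"])
```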
Functions
schema

Schema utilities and local tool description building for ToolNode.

This module provides the SchemaMixin class which handles automatic schema generation for local Python functions, converting their type annotations and signatures into OpenAI-compatible function schemas. It supports various Python types including primitives, Optional types, List types, and Literal enums.

The schema generation process inspects function signatures and converts them to JSON Schema format suitable for use with language models and function calling APIs.
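This process can be sketched with a small standalone converter. This is a simplified illustration of the idea, not the framework's implementation: it handles only primitives, list types, and defaults, and falls back to string for anything else, as SchemaMixin does:

```python
import inspect
import typing as t


def annotation_to_schema(annotation: t.Any, default: t.Any) -> dict:
    """Convert a Python annotation to a JSON-Schema fragment (simplified sketch)."""
    primitives = {
        str: {"type": "string"},
        int: {"type": "integer"},
        float: {"type": "number"},
        bool: {"type": "boolean"},
    }
    if t.get_origin(annotation) is list:
        (item,) = t.get_args(annotation) or (str,)
        schema = {"type": "array", "items": annotation_to_schema(item, inspect.Parameter.empty)}
    elif annotation in primitives:
        schema = dict(primitives[annotation])
    else:
        schema = {"type": "string"}  # fallback, as in SchemaMixin
    if default is not inspect.Parameter.empty:
        schema["default"] = default
    return schema


def signature_to_parameters(fn: t.Callable) -> dict:
    """Build an OpenAI-style parameters object from a function signature."""
    props: dict = {}
    required: list[str] = []
    for name, p in inspect.signature(fn).parameters.items():
        ann = p.annotation if p.annotation is not inspect.Parameter.empty else str
        props[name] = annotation_to_schema(ann, p.default)
        if p.default is inspect.Parameter.empty:
            required.append(name)
    schema = {"type": "object", "properties": props}
    if required:
        schema["required"] = required
    return schema


def greet(name: str, times: int = 1) -> str:
    return " ".join(["hello", name] * times)


print(signature_to_parameters(greet))
```

For greet, this yields a string property for name (required) and an integer property for times with default 1, matching the shape produced by get_local_tool below.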

Classes:

Name Description
SchemaMixin

Mixin providing schema generation and local tool description building.

Attributes Classes
SchemaMixin

Mixin providing schema generation and local tool description building.

This mixin provides functionality to automatically generate JSON Schema definitions from Python function signatures. It handles type annotation conversion, parameter analysis, and OpenAI-compatible function schema generation for local tools.

The mixin is designed to be used with ToolNode to automatically generate tool schemas without requiring manual schema definition for Python functions.

Attributes:

Name Type Description
_funcs dict[str, Callable]

Dictionary mapping function names to callable functions. This attribute is expected to be provided by the mixing class.

Methods:

Name Description
get_local_tool

Generate OpenAI-compatible tool definitions for all registered local functions.

Source code in pyagenity/graph/tool_node/schema.py
class SchemaMixin:
    """Mixin providing schema generation and local tool description building.

    This mixin provides functionality to automatically generate JSON Schema definitions
    from Python function signatures. It handles type annotation conversion, parameter
    analysis, and OpenAI-compatible function schema generation for local tools.

    The mixin is designed to be used with ToolNode to automatically generate tool
    schemas without requiring manual schema definition for Python functions.

    Attributes:
        _funcs: Dictionary mapping function names to callable functions. This
            attribute is expected to be provided by the mixing class.
    """

    _funcs: dict[str, t.Callable]

    @staticmethod
    def _handle_optional_annotation(annotation: t.Any, default: t.Any) -> dict | None:
        """Handle Optional type annotations and convert them to appropriate schemas.

        Processes Optional[T] type annotations (Union[T, None]) and generates
        schema for the non-None type. This method handles the common pattern
        of optional parameters in function signatures.

        Args:
            annotation: The type annotation to process, potentially an Optional type.
            default: The default value for the parameter, used for schema generation.

        Returns:
            Dictionary containing the JSON schema for the non-None type if the
            annotation is Optional, None otherwise.

        Example:
            Optional[str] -> {"type": "string"}
            Optional[int] -> {"type": "integer"}
        """
        args = getattr(annotation, "__args__", None)
        if args and any(a is type(None) for a in args):
            non_none = [a for a in args if a is not type(None)]
            if non_none:
                return SchemaMixin._annotation_to_schema(non_none[0], default)
        return None

    @staticmethod
    def _handle_complex_annotation(annotation: t.Any) -> dict:
        """Handle complex type annotations like List, Literal, and generic types.

        Processes generic type annotations that aren't simple primitive types,
        including container types like List and special types like Literal enums.
        Falls back to string type for unrecognized complex types.

        Args:
            annotation: The complex type annotation to process (e.g., List[str],
                Literal["a", "b", "c"]).

        Returns:
            Dictionary containing the appropriate JSON schema for the complex type.
            For List types, returns array schema with item type.
            For Literal types, returns enum schema with allowed values.
            For unknown types, returns string type as fallback.

        Example:
            List[str] -> {"type": "array", "items": {"type": "string"}}
            Literal["red", "green"] -> {"type": "string", "enum": ["red", "green"]}
        """
        origin = getattr(annotation, "__origin__", None)
        if origin is list:
            item_type = getattr(annotation, "__args__", (str,))[0]
            item_schema = SchemaMixin._annotation_to_schema(item_type, None)
            return {"type": "array", "items": item_schema}

        Literal = getattr(t, "Literal", None)
        if Literal is not None and origin is Literal:
            literals = list(getattr(annotation, "__args__", ()))
            if all(isinstance(literal, str) for literal in literals):
                return {"type": "string", "enum": literals}
            return {"enum": literals}

        return {"type": "string"}

    @staticmethod
    def _annotation_to_schema(annotation: t.Any, default: t.Any) -> dict:
        """Convert a Python type annotation to JSON Schema format.

        Main entry point for type annotation conversion. Handles both simple
        and complex types by delegating to appropriate helper methods.
        Includes default value handling when present.

        Args:
            annotation: The Python type annotation to convert (e.g., str, int,
                Optional[str], List[int]).
            default: The default value for the parameter, included in schema
                if not inspect._empty.

        Returns:
            Dictionary containing the JSON schema representation of the type
            annotation, including default values where applicable.

        Example:
            str -> {"type": "string"}
            int -> {"type": "integer"}
            str with default "hello" -> {"type": "string", "default": "hello"}
        """
        schema = SchemaMixin._handle_optional_annotation(annotation, default)
        if schema:
            return schema

        primitive_mappings = {
            str: {"type": "string"},
            int: {"type": "integer"},
            float: {"type": "number"},
            bool: {"type": "boolean"},
        }

        if annotation in primitive_mappings:
            schema = primitive_mappings[annotation]
        else:
            schema = SchemaMixin._handle_complex_annotation(annotation)

        if default is not inspect._empty:
            schema["default"] = default

        return schema

    def get_local_tool(self) -> list[dict]:
        """Generate OpenAI-compatible tool definitions for all registered local functions.

        Inspects all registered functions in _funcs and automatically generates
        tool schemas by analyzing function signatures, type annotations, and docstrings.
        Excludes injectable parameters that are provided by the framework.

        Returns:
            List of tool definitions in OpenAI function calling format. Each
            definition includes the function name, description (from docstring),
            and complete parameter schema with types and required fields.

        Example:
            For a function:
            ```python
            def calculate(a: int, b: int, operation: str = "add") -> int:
                '''Perform arithmetic calculation.'''
                return a + b if operation == "add" else a - b
            ```

            Returns:
            ```python
            [
                {
                    "type": "function",
                    "function": {
                        "name": "calculate",
                        "description": "Perform arithmetic calculation.",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "a": {"type": "integer"},
                                "b": {"type": "integer"},
                                "operation": {"type": "string", "default": "add"},
                            },
                            "required": ["a", "b"],
                        },
                    },
                }
            ]
            ```

        Note:
            Parameters listed in INJECTABLE_PARAMS (like 'state', 'config',
            'tool_call_id') are automatically excluded from the generated schema
            as they are provided by the framework during execution.
        """
        tools: list[dict] = []
        for name, fn in self._funcs.items():
            sig = inspect.signature(fn)
            params_schema: dict = {"type": "object", "properties": {}, "required": []}

            for p_name, p in sig.parameters.items():
                if p.kind in (
                    inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD,
                ):
                    continue

                if p_name in INJECTABLE_PARAMS:
                    continue

                annotation = p.annotation if p.annotation is not inspect._empty else str
                prop = SchemaMixin._annotation_to_schema(annotation, p.default)
                params_schema["properties"][p_name] = prop

                if p.default is inspect._empty:
                    params_schema["required"].append(p_name)

            if not params_schema["required"]:
                params_schema.pop("required")

            description = inspect.getdoc(fn) or "No description provided."

            # provider = getattr(fn, "_py_tool_provider", None)
            # tags = getattr(fn, "_py_tool_tags", None)
            # capabilities = getattr(fn, "_py_tool_capabilities", None)

            entry = {
                "type": "function",
                "function": {
                    "name": name,
                    "description": description,
                    "parameters": params_schema,
                },
            }
            # meta: dict[str, t.Any] = {}
            # if provider:
            #     meta["provider"] = provider
            # if tags:
            #     meta["tags"] = tags
            # if capabilities:
            #     meta["capabilities"] = capabilities
            # if meta:
            #     entry["x-pyagenity"] = meta

            tools.append(entry)

        return tools
Functions
get_local_tool
get_local_tool()

Generate OpenAI-compatible tool definitions for all registered local functions.

Inspects all registered functions in _funcs and automatically generates tool schemas by analyzing function signatures, type annotations, and docstrings. Excludes injectable parameters that are provided by the framework.

Returns:

Type Description
list[dict]

List of tool definitions in OpenAI function calling format. Each definition includes the function name, description (from docstring), and complete parameter schema with types and required fields.

Example

For a function:

def calculate(a: int, b: int, operation: str = "add") -> int:
    '''Perform arithmetic calculation.'''
    return a + b if operation == "add" else a - b

Returns:

[
    {
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "Perform arithmetic calculation.",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "integer"},
                    "b": {"type": "integer"},
                    "operation": {"type": "string", "default": "add"},
                },
                "required": ["a", "b"],
            },
        },
    }
]

Note

Parameters listed in INJECTABLE_PARAMS (like 'state', 'config', 'tool_call_id') are automatically excluded from the generated schema as they are provided by the framework during execution.

Source code in pyagenity/graph/tool_node/schema.py
def get_local_tool(self) -> list[dict]:
    """Generate OpenAI-compatible tool definitions for all registered local functions.

    Inspects all registered functions in _funcs and automatically generates
    tool schemas by analyzing function signatures, type annotations, and docstrings.
    Excludes injectable parameters that are provided by the framework.

    Returns:
        List of tool definitions in OpenAI function calling format. Each
        definition includes the function name, description (from docstring),
        and complete parameter schema with types and required fields.

    Example:
        For a function:
        ```python
        def calculate(a: int, b: int, operation: str = "add") -> int:
            '''Perform arithmetic calculation.'''
            return a + b if operation == "add" else a - b
        ```

        Returns:
        ```python
        [
            {
                "type": "function",
                "function": {
                    "name": "calculate",
                    "description": "Perform arithmetic calculation.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {"type": "integer"},
                            "b": {"type": "integer"},
                            "operation": {"type": "string", "default": "add"},
                        },
                        "required": ["a", "b"],
                    },
                },
            }
        ]
        ```

    Note:
        Parameters listed in INJECTABLE_PARAMS (like 'state', 'config',
        'tool_call_id') are automatically excluded from the generated schema
        as they are provided by the framework during execution.
    """
    tools: list[dict] = []
    for name, fn in self._funcs.items():
        sig = inspect.signature(fn)
        params_schema: dict = {"type": "object", "properties": {}, "required": []}

        for p_name, p in sig.parameters.items():
            if p.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue

            if p_name in INJECTABLE_PARAMS:
                continue

            annotation = p.annotation if p.annotation is not inspect._empty else str
            prop = SchemaMixin._annotation_to_schema(annotation, p.default)
            params_schema["properties"][p_name] = prop

            if p.default is inspect._empty:
                params_schema["required"].append(p_name)

        if not params_schema["required"]:
            params_schema.pop("required")

        description = inspect.getdoc(fn) or "No description provided."

        # provider = getattr(fn, "_py_tool_provider", None)
        # tags = getattr(fn, "_py_tool_tags", None)
        # capabilities = getattr(fn, "_py_tool_capabilities", None)

        entry = {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": params_schema,
            },
        }
        # meta: dict[str, t.Any] = {}
        # if provider:
        #     meta["provider"] = provider
        # if tags:
        #     meta["tags"] = tags
        # if capabilities:
        #     meta["capabilities"] = capabilities
        # if meta:
        #     entry["x-pyagenity"] = meta

        tools.append(entry)

    return tools
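The signature-to-schema conversion above can be sketched as a standalone snippet. This is a simplified illustration, not the framework's implementation: `INJECTABLE_PARAMS` and `TYPE_MAP` here are local stand-ins (the real `INJECTABLE_PARAMS` and `SchemaMixin._annotation_to_schema` live inside PyAgenity), and only a handful of builtin annotations are mapped.

```python
import inspect

# Stand-in for the framework's INJECTABLE_PARAMS set (assumed names).
INJECTABLE_PARAMS = {"state", "config", "tool_call_id"}

# Minimal annotation -> JSON Schema type mapping for the sketch.
TYPE_MAP = {int: "integer", float: "number", str: "string", bool: "boolean"}


def build_tool_schema(fn) -> dict:
    """Build an OpenAI-style tool definition from a function signature."""
    sig = inspect.signature(fn)
    params: dict = {"type": "object", "properties": {}, "required": []}
    for p_name, p in sig.parameters.items():
        if p_name in INJECTABLE_PARAMS:
            continue  # framework-provided, excluded from the schema
        annotation = p.annotation if p.annotation is not inspect.Parameter.empty else str
        prop = {"type": TYPE_MAP.get(annotation, "string")}
        if p.default is not inspect.Parameter.empty:
            prop["default"] = p.default
        else:
            params["required"].append(p_name)
        params["properties"][p_name] = prop
    if not params["required"]:
        params.pop("required")
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": inspect.getdoc(fn) or "No description provided.",
            "parameters": params,
        },
    }


def calculate(a: int, b: int, operation: str = "add", tool_call_id: str = "") -> int:
    """Perform arithmetic calculation."""
    return a + b if operation == "add" else a - b


schema = build_tool_schema(calculate)
print(schema["function"]["parameters"]["required"])  # ['a', 'b']
```

Note that `tool_call_id` never appears in the generated properties: injectable parameters are filtered out before schema generation, exactly as the `get_local_tool` docstring describes.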
utils

Modules:

Name Description
handler_mixins

Shared mixins for graph and node handler classes.

invoke_handler
invoke_node_handler

InvokeNodeHandler utilities for PyAgenity agent graph execution.

stream_handler

Streaming graph execution handler for PyAgenity workflows.

stream_node_handler

Streaming node handler for PyAgenity graph workflows.

stream_utils

Streaming utility functions for PyAgenity graph workflows.

utils

Core utility functions for graph execution and state management.

Modules
handler_mixins

Shared mixins for graph and node handler classes.

This module provides lightweight mixins that add common functionality to handler classes without changing their core runtime behavior. The mixins follow the composition pattern to keep responsibilities explicit and allow handlers to inherit only the capabilities they need.

The mixins provide structured logging, configuration management, and other cross-cutting concerns that are commonly needed across different handler types. By using mixins, the core handler logic remains focused while gaining these shared capabilities.

Typical usage
class MyHandler(BaseLoggingMixin, InterruptConfigMixin):
    def __init__(self):
        self._set_interrupts(["node1"], ["node2"])
        self._log_start("Handler initialized")

Classes:

Name Description
BaseLoggingMixin

Provides structured logging helpers for handler classes.

InterruptConfigMixin

Manages interrupt configuration for graph-level execution handlers.

Classes
BaseLoggingMixin

Provides structured logging helpers for handler classes.

This mixin adds consistent logging capabilities to handler classes without requiring them to manage logger instances directly. It automatically creates loggers based on the module name and provides convenience methods for common logging operations.

The mixin is designed to be lightweight and non-intrusive, adding only logging functionality without affecting the core behavior of the handler.

Attributes:

Name Type Description
_logger Logger

Cached logger instance for the handler class.

Example
class MyHandler(BaseLoggingMixin):
    def process(self):
        self._log_start("Processing started")
        try:
            # Do work
            self._log_debug("Work completed successfully")
        except Exception as e:
            self._log_error("Processing failed: %s", e)
Source code in pyagenity/graph/utils/handler_mixins.py
class BaseLoggingMixin:
    """Provides structured logging helpers for handler classes.

    This mixin adds consistent logging capabilities to handler classes without
    requiring them to manage logger instances directly. It automatically creates
    loggers based on the module name and provides convenience methods for
    common logging operations.

    The mixin is designed to be lightweight and non-intrusive, adding only
    logging functionality without affecting the core behavior of the handler.

    Attributes:
        _logger: Cached logger instance for the handler class.

    Example:
        ```python
        class MyHandler(BaseLoggingMixin):
            def process(self):
                self._log_start("Processing started")
                try:
                    # Do work
                    self._log_debug("Work completed successfully")
                except Exception as e:
                    self._log_error("Processing failed: %s", e)
        ```
    """

    @property
    def _logger(self) -> logging.Logger:
        """Get or create a logger instance for this handler.

        Creates a logger using the handler's module name, providing consistent
        logging across different handler instances while maintaining proper
        logger hierarchy and configuration.

        Returns:
            Logger instance configured for this handler's module.
        """
        return logging.getLogger(getattr(self, "__module__", __name__))

    def _log_start(self, msg: str, *args: Any) -> None:
        """Log an informational message for process start/initialization.

        Args:
            msg: Log message format string.
            *args: Arguments for message formatting.
        """
        self._logger.info(msg, *args)

    def _log_debug(self, msg: str, *args: Any) -> None:
        """Log a debug message for detailed execution information.

        Args:
            msg: Log message format string.
            *args: Arguments for message formatting.
        """
        self._logger.debug(msg, *args)

    def _log_error(self, msg: str, *args: Any) -> None:
        """Log an error message for exceptional conditions.

        Args:
            msg: Log message format string.
            *args: Arguments for message formatting.
        """
        self._logger.error(msg, *args)
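A minimal usage sketch of the mixin above: the `_logger` property keys the logger by the subclass's module, so handlers defined in the same module share one logger and inherit its configuration. The re-created mixin body below mirrors the source for self-containment.

```python
import logging


class BaseLoggingMixin:
    """Re-creation of the mixin above, for illustration only."""

    @property
    def _logger(self) -> logging.Logger:
        # Logger is keyed by the subclass's module, preserving hierarchy.
        return logging.getLogger(getattr(self, "__module__", __name__))

    def _log_start(self, msg: str, *args) -> None:
        # Lazy %-style formatting: args are only interpolated if emitted.
        self._logger.info(msg, *args)


class MyHandler(BaseLoggingMixin):
    def process(self) -> None:
        self._log_start("Processing started for %s", type(self).__name__)


handler = MyHandler()
handler.process()
print(handler._logger.name)  # the module MyHandler was defined in
```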
InterruptConfigMixin

Manages interrupt configuration for graph-level execution handlers.

This mixin provides functionality to store and manage interrupt points configuration for graph execution. Interrupts allow graph execution to be paused before or after specific nodes for debugging, human intervention, or checkpoint creation.

The mixin maintains separate lists for "before" and "after" interrupts, allowing fine-grained control over when graph execution should pause.

Attributes:

Name Type Description
interrupt_before list[str] | None

List of node names where execution should pause before node execution begins.

interrupt_after list[str] | None

List of node names where execution should pause after node execution completes.

Example
class GraphHandler(InterruptConfigMixin):
    def __init__(self):
        self._set_interrupts(
            interrupt_before=["approval_node"], interrupt_after=["data_processing"]
        )
Source code in pyagenity/graph/utils/handler_mixins.py
class InterruptConfigMixin:
    """Manages interrupt configuration for graph-level execution handlers.

    This mixin provides functionality to store and manage interrupt points
    configuration for graph execution. Interrupts allow graph execution to be
    paused before or after specific nodes for debugging, human intervention,
    or checkpoint creation.

    The mixin maintains separate lists for "before" and "after" interrupts,
    allowing fine-grained control over when graph execution should pause.

    Attributes:
        interrupt_before: List of node names where execution should pause
            before node execution begins.
        interrupt_after: List of node names where execution should pause
            after node execution completes.

    Example:
        ```python
        class GraphHandler(InterruptConfigMixin):
            def __init__(self):
                self._set_interrupts(
                    interrupt_before=["approval_node"], interrupt_after=["data_processing"]
                )
        ```
    """

    interrupt_before: list[str] | None
    interrupt_after: list[str] | None

    def _set_interrupts(
        self,
        interrupt_before: list[str] | None,
        interrupt_after: list[str] | None,
    ) -> None:
        """Configure interrupt points for graph execution control.

        Sets up the interrupt configuration for this handler, defining which
        nodes should trigger execution pauses. This method normalizes None
        values to empty lists for consistent handling.

        Args:
            interrupt_before: List of node names where execution should be
                interrupted before the node begins execution. Pass None to
                disable before-interrupts.
            interrupt_after: List of node names where execution should be
                interrupted after the node completes execution. Pass None to
                disable after-interrupts.

        Note:
            This method should be called during handler initialization to
            establish the interrupt configuration before graph execution begins.
        """
        self.interrupt_before = interrupt_before or []
        self.interrupt_after = interrupt_after or []
Attributes
interrupt_after instance-attribute
interrupt_after
interrupt_before instance-attribute
interrupt_before
invoke_handler

Classes:

Name Description
InvokeHandler

Attributes:

Name Type Description
StateT
logger
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
logger module-attribute
logger = getLogger(__name__)
Classes
InvokeHandler

Bases: BaseLoggingMixin, InterruptConfigMixin

Methods:

Name Description
__init__
invoke

Execute the graph asynchronously with event publishing.

Attributes:

Name Type Description
edges list[Edge]
interrupt_after
interrupt_before
nodes dict[str, Node]
Source code in pyagenity/graph/utils/invoke_handler.py
class InvokeHandler[StateT: AgentState](
    BaseLoggingMixin,
    InterruptConfigMixin,
):
    @inject
    def __init__(
        self,
        nodes: dict[str, Node],
        edges: list[Edge],
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
    ):
        self.nodes: dict[str, Node] = nodes
        self.edges: list[Edge] = edges
        # Keep existing attributes for backward-compatibility
        self.interrupt_before = interrupt_before or []
        self.interrupt_after = interrupt_after or []
        # And set via mixin for a single source of truth
        self._set_interrupts(interrupt_before, interrupt_after)

    async def _check_interrupted(
        self,
        state: StateT,
        input_data: dict[str, Any],
        config: dict[str, Any],
    ) -> dict[str, Any]:
        if state.is_interrupted():
            logger.info(
                "Resuming from interrupted state at node '%s'", state.execution_meta.current_node
            )
            # This is a resume case - clear interrupt and merge input data
            if input_data:
                config["resume_data"] = input_data
                logger.debug("Added resume data with %d keys", len(input_data))
            state.clear_interrupt()
        elif not input_data.get("messages") and not state.context:
            # This is a fresh execution - validate input data
            error_msg = "Input data must contain 'messages' for new execution."
            logger.error(error_msg)
            raise ValueError(error_msg)
        else:
            logger.info(
                "Starting fresh execution with %d messages", len(input_data.get("messages", []))
            )

        return config

    async def _check_and_handle_interrupt(
        self,
        current_node: str,
        interrupt_type: str,
        state: StateT,
        config: dict[str, Any],
    ) -> bool:
        """Check for interrupts and save state if needed. Returns True if interrupted."""
        interrupt_nodes: list[str] = (
            self.interrupt_before if interrupt_type == "before" else self.interrupt_after
        ) or []

        if current_node in interrupt_nodes:
            status = (
                ExecutionStatus.INTERRUPTED_BEFORE
                if interrupt_type == "before"
                else ExecutionStatus.INTERRUPTED_AFTER
            )
            state.set_interrupt(
                current_node,
                f"interrupt_{interrupt_type}: {current_node}",
                status,
            )
            # Save state and interrupt
            await sync_data(
                state=state,
                config=config,
                messages=[],
                trim=True,
            )
            logger.debug("Node '%s' interrupted", current_node)
            return True

        logger.debug(
            "No interrupts found for node '%s', continuing execution",
            current_node,
        )
        return False

    async def _check_stop_requested(
        self,
        state: StateT,
        current_node: str,
        event: EventModel,
        messages: list[Message],
        config: dict[str, Any],
    ) -> bool:
        """Check if a stop has been requested externally."""
        state = await reload_state(config, state)  # type: ignore

        # Check if a stop was requested externally (e.g., frontend)
        if state.is_stopped_requested():
            logger.info(
                "Stop requested for thread '%s' at node '%s'",
                config.get("thread_id"),
                current_node,
            )
            state.set_interrupt(
                current_node,
                "stop_requested",
                ExecutionStatus.INTERRUPTED_AFTER,
                data={"source": "stop", "info": "requested via is_stopped_requested"},
            )
            await sync_data(state=state, config=config, messages=messages, trim=True)
            event.event_type = EventType.INTERRUPTED
            event.metadata["interrupted"] = "Stop"
            event.metadata["status"] = "Graph execution stopped by request"
            event.data["state"] = state.model_dump()
            publish_event(event)
            return True
        return False

    async def _execute_graph(  # noqa: PLR0912, PLR0915
        self,
        state: StateT,
        config: dict[str, Any],
    ) -> tuple[StateT, list[Message]]:
        """Execute the entire graph with support for interrupts and resuming."""
        logger.info(
            "Starting graph execution from node '%s' at step %d",
            state.execution_meta.current_node,
            state.execution_meta.step,
        )
        logger.debug("DEBUG: Current node value: %r", state.execution_meta.current_node)
        logger.debug("DEBUG: END constant value: %r", END)
        logger.debug("DEBUG: Are they equal? %s", state.execution_meta.current_node == END)
        messages: list[Message] = []
        max_steps = config.get("recursion_limit", 25)
        logger.debug("Max steps limit set to %d", max_steps)

        # get the last message from state as that is human message
        last_human_message = state.context[-1] if state.context else None
        if last_human_message and last_human_message.role != "user":
            msg = [msg for msg in reversed(state.context) if msg.role == "user"]
            last_human_message = msg[0] if msg else None

        if last_human_message:
            logger.debug("Last human message: %s", last_human_message.content)
            messages.append(last_human_message)

        # Get current execution info from state
        current_node = state.execution_meta.current_node
        step = state.execution_meta.step

        # Create event for graph execution
        event = EventModel.default(
            config,
            data={"state": state.model_dump()},
            event=Event.GRAPH_EXECUTION,
            content_type=[ContentType.STATE],
            node_name=current_node,
            extra={
                "current_node": current_node,
                "step": step,
                "max_steps": max_steps,
            },
        )

        try:
            while current_node != END and step < max_steps:
                logger.debug("Executing step %d at node '%s'", step, current_node)
                # Reload state in each iteration to get latest (in case of external updates)
                res = await self._check_stop_requested(
                    state,
                    current_node,
                    event,
                    messages,
                    config,
                )
                if res:
                    return state, messages

                # Update execution metadata
                state.set_current_node(current_node)
                state.execution_meta.step = step
                await call_realtime_sync(state, config)
                event.data["state"] = state.model_dump()
                event.metadata["step"] = step
                event.metadata["current_node"] = current_node
                event.event_type = EventType.PROGRESS
                publish_event(event)

                # Check for interrupt_before
                if await self._check_and_handle_interrupt(
                    current_node,
                    "before",
                    state,
                    config,
                ):
                    logger.info("Graph execution interrupted before node '%s'", current_node)
                    event.event_type = EventType.INTERRUPTED
                    event.metadata["interrupted"] = "Before"
                    event.metadata["status"] = "Graph execution interrupted before node execution"
                    event.data["interrupted"] = "Before"
                    publish_event(event)
                    return state, messages

                # Execute current node
                logger.debug("Executing node '%s'", current_node)
                node = self.nodes[current_node]

                # Publish node invocation event

                ###############################################
                ##### Node Execution Started ##################
                ###############################################

                result = await node.execute(config, state)  # type: ignore

                ###############################################
                ##### Node Execution Finished #################
                ###############################################

                logger.debug("Node '%s' execution completed", current_node)

                next_node = None

                # Process result and get next node
                if isinstance(result, list):
                    # If result is a list of Message, append to messages
                    messages.extend(result)
                    logger.debug(
                        "Node '%s' returned %d messages, total messages now %d",
                        current_node,
                        len(result),
                        len(messages),
                    )
                    # Add messages to state context so they're visible to subsequent nodes
                    state.context = add_messages(state.context, result)

                # No state change beyond adding messages, just advance to next node
                if isinstance(result, dict):
                    state = result.get("state", state)
                    next_node = result.get("next_node")
                    new_messages = result.get("messages", [])
                    if new_messages:
                        messages.extend(new_messages)
                        logger.debug(
                            "Node '%s' returned %d messages, total messages now %d",
                            current_node,
                            len(new_messages),
                            len(messages),
                        )

                logger.debug(
                    "Node result processed, next_node=%s, total_messages=%d",
                    next_node,
                    len(messages),
                )

                # Check stop again after node execution
                res = await self._check_stop_requested(
                    state,
                    current_node,
                    event,
                    messages,
                    config,
                )
                if res:
                    return state, messages

                # Call realtime sync after node execution (if state/messages changed)
                await call_realtime_sync(state, config)
                event.event_type = EventType.UPDATE
                event.data["state"] = state.model_dump()
                event.data["messages"] = [m.model_dump() for m in messages] if messages else []
                if messages:
                    lm = messages[-1]
                    event.content = lm.text() if isinstance(lm.content, list) else lm.content
                    if isinstance(lm.content, list):
                        event.content_blocks = lm.content
                event.content_type = [ContentType.STATE, ContentType.MESSAGE]
                publish_event(event)

                # Check for interrupt_after
                if await self._check_and_handle_interrupt(
                    current_node,
                    "after",
                    state,
                    config,
                ):
                    logger.info("Graph execution interrupted after node '%s'", current_node)
                    # For interrupt_after, advance to next node before pausing
                    if next_node is None:
                        next_node = get_next_node(current_node, state, self.edges)
                    state.set_current_node(next_node)

                    event.event_type = EventType.INTERRUPTED
                    event.data["interrupted"] = "After"
                    event.metadata["interrupted"] = "After"
                    event.data["state"] = state.model_dump()
                    publish_event(event)
                    return state, messages

                # Get next node (only if no explicit navigation from Command)
                if next_node is None:
                    current_node = get_next_node(current_node, state, self.edges)
                    logger.debug("Next node determined by graph logic: '%s'", current_node)
                else:
                    current_node = next_node
                    logger.debug("Next node determined by command: '%s'", current_node)

                # Check if we've reached the end after determining next node
                logger.debug("Checking if current_node '%s' == END '%s'", current_node, END)
                if current_node == END:
                    logger.info("Graph execution reached END node, completing")
                    break

                # Advance step after successful node execution
                step += 1
                state.advance_step()
                await call_realtime_sync(state, config)
                event.event_type = EventType.UPDATE

                event.metadata["State_Updated"] = "State Updated"
                event.data["state"] = state.model_dump()
                publish_event(event)

                if step >= max_steps:
                    error_msg = "Graph execution exceeded maximum steps"
                    logger.error(error_msg)
                    state.error(error_msg)
                    await call_realtime_sync(state, config)
                    event.event_type = EventType.ERROR
                    event.data["state"] = state.model_dump()
                    event.metadata["error"] = error_msg
                    event.metadata["step"] = step
                    event.metadata["current_node"] = current_node

                    publish_event(event)
                    raise GraphRecursionError(
                        f"Graph execution exceeded recursion limit: {max_steps}"
                    )

            # Execution completed successfully
            logger.info(
                "Graph execution completed successfully at node '%s' after %d steps",
                current_node,
                step,
            )
            state.complete()
            res = await sync_data(
                state=state,
                config=config,
                messages=messages,
                trim=True,
            )
            event.event_type = EventType.END
            event.data["state"] = state.model_dump()
            event.data["messages"] = [m.model_dump() for m in messages] if messages else []
            if messages:
                fm = messages[-1]
                event.content = fm.text() if isinstance(fm.content, list) else fm.content
                if isinstance(fm.content, list):
                    event.content_blocks = fm.content
            event.content_type = [ContentType.STATE, ContentType.MESSAGE]
            event.metadata["status"] = "Graph execution completed"
            event.metadata["step"] = step
            event.metadata["current_node"] = current_node
            event.metadata["is_context_trimmed"] = res

            publish_event(event)

            return state, messages

        except Exception as e:
            # Handle execution errors
            logger.exception("Graph execution failed: %s", e)
            state.error(str(e))

            # Publish error event
            event.event_type = EventType.ERROR
            event.metadata["error"] = str(e)
            event.data["state"] = state.model_dump()
            publish_event(event)

            await sync_data(
                state=state,
                config=config,
                messages=messages,
                trim=True,
            )
            raise

    async def invoke(
        self,
        input_data: dict[str, Any],
        config: dict[str, Any],
        default_state: StateT,
        response_granularity: ResponseGranularity = ResponseGranularity.LOW,
    ):
        """Execute the graph asynchronously with event publishing."""
        logger.info(
            "Starting asynchronous graph execution with %d input keys, granularity=%s",
            len(input_data) if input_data else 0,
            response_granularity,
        )
        input_data = input_data or {}

        # Load or initialize state
        logger.debug("Loading or creating state from input data")
        new_state = await load_or_create_state(
            input_data,
            config,
            default_state,
        )
        state: StateT = new_state  # type: ignore[assignment]
        logger.debug(
            "State loaded: interrupted=%s, current_node=%s, step=%d",
            state.is_interrupted(),
            state.execution_meta.current_node,
            state.execution_meta.step,
        )

        # Event publishing logic
        event = EventModel.default(
            config,
            data={"state": state.model_dump()},
            event=Event.GRAPH_EXECUTION,
            content_type=[ContentType.STATE],
            node_name=state.execution_meta.current_node,
            extra={
                "current_node": state.execution_meta.current_node,
                "step": state.execution_meta.step,
            },
        )
        event.event_type = EventType.START
        publish_event(event)

        # Check if this is a resume case
        config = await self._check_interrupted(state, input_data, config)

        event.event_type = EventType.UPDATE
        event.metadata["status"] = "Graph invoked"
        publish_event(event)

        try:
            logger.debug("Beginning graph execution")
            event.event_type = EventType.PROGRESS
            event.metadata["status"] = "Graph execution started"
            publish_event(event)

            final_state, messages = await self._execute_graph(state, config)
            logger.info("Graph execution completed with %d final messages", len(messages))

            event.event_type = EventType.END
            event.metadata["status"] = "Graph execution completed"
            event.data["state"] = final_state.model_dump()
            event.data["messages"] = [m.model_dump() for m in messages] if messages else []
            publish_event(event)

            return await parse_response(
                final_state,
                messages,
                response_granularity,
            )
        except Exception as e:
            logger.exception("Graph execution failed: %s", e)
            event.event_type = EventType.ERROR
            event.metadata["status"] = f"Graph execution failed: {e}"
            event.data["error"] = str(e)
            publish_event(event)
            raise
Attributes
edges instance-attribute
edges = edges
interrupt_after instance-attribute
interrupt_after = interrupt_after or []
interrupt_before instance-attribute
interrupt_before = interrupt_before or []
nodes instance-attribute
nodes = nodes
Functions
__init__
__init__(nodes, edges, interrupt_before=None, interrupt_after=None)
Source code in pyagenity/graph/utils/invoke_handler.py
@inject
def __init__(
    self,
    nodes: dict[str, Node],
    edges: list[Edge],
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
):
    self.nodes: dict[str, Node] = nodes
    self.edges: list[Edge] = edges
    # Keep existing attributes for backward-compatibility
    self.interrupt_before = interrupt_before or []
    self.interrupt_after = interrupt_after or []
    # And set via mixin for a single source of truth
    self._set_interrupts(interrupt_before, interrupt_after)
invoke async
invoke(input_data, config, default_state, response_granularity=ResponseGranularity.LOW)

Execute the graph asynchronously with event publishing.

Source code in pyagenity/graph/utils/invoke_handler.py
async def invoke(
    self,
    input_data: dict[str, Any],
    config: dict[str, Any],
    default_state: StateT,
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
):
    """Execute the graph asynchronously with event publishing."""
    logger.info(
        "Starting asynchronous graph execution with %d input keys, granularity=%s",
        len(input_data) if input_data else 0,
        response_granularity,
    )
    input_data = input_data or {}

    # Load or initialize state
    logger.debug("Loading or creating state from input data")
    new_state = await load_or_create_state(
        input_data,
        config,
        default_state,
    )
    state: StateT = new_state  # type: ignore[assignment]
    logger.debug(
        "State loaded: interrupted=%s, current_node=%s, step=%d",
        state.is_interrupted(),
        state.execution_meta.current_node,
        state.execution_meta.step,
    )

    # Event publishing logic
    event = EventModel.default(
        config,
        data={"state": state.model_dump()},
        event=Event.GRAPH_EXECUTION,
        content_type=[ContentType.STATE],
        node_name=state.execution_meta.current_node,
        extra={
            "current_node": state.execution_meta.current_node,
            "step": state.execution_meta.step,
        },
    )
    event.event_type = EventType.START
    publish_event(event)

    # Check if this is a resume case
    config = await self._check_interrupted(state, input_data, config)

    event.event_type = EventType.UPDATE
    event.metadata["status"] = "Graph invoked"
    publish_event(event)

    try:
        logger.debug("Beginning graph execution")
        event.event_type = EventType.PROGRESS
        event.metadata["status"] = "Graph execution started"
        publish_event(event)

        final_state, messages = await self._execute_graph(state, config)
        logger.info("Graph execution completed with %d final messages", len(messages))

        event.event_type = EventType.END
        event.metadata["status"] = "Graph execution completed"
        event.data["state"] = final_state.model_dump()
        event.data["messages"] = [m.model_dump() for m in messages] if messages else []
        publish_event(event)

        return await parse_response(
            final_state,
            messages,
            response_granularity,
        )
    except Exception as e:
        logger.exception("Graph execution failed: %s", e)
        event.event_type = EventType.ERROR
        event.metadata["status"] = f"Graph execution failed: {e}"
        event.data["error"] = str(e)
        publish_event(event)
        raise
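The `invoke` flow above follows a fixed event lifecycle: START on entry, UPDATE after the interrupt check, PROGRESS when execution begins, then END on success or ERROR on failure. A minimal sketch of that lifecycle, using a stand-in `publish_event` recorder rather than PyAgenity's real publisher (all names below are illustrative):

```python
import asyncio

# Events recorded in the order a real publisher would emit them.
published: list[tuple[str, str]] = []


def publish_event(event_type: str, status: str = "") -> None:
    """Stand-in for PyAgenity's publish_event; just records the event."""
    published.append((event_type, status))


async def invoke_sketch(input_data: dict) -> dict:
    publish_event("START")
    publish_event("UPDATE", "Graph invoked")
    try:
        publish_event("PROGRESS", "Graph execution started")
        result = {"echo": input_data}  # placeholder for _execute_graph
        publish_event("END", "Graph execution completed")
        return result
    except Exception as e:  # noqa: BLE001 - mirror the broad catch in invoke
        publish_event("ERROR", f"Graph execution failed: {e}")
        raise


result = asyncio.run(invoke_sketch({"messages": ["hi"]}))
# Event types observed: START, UPDATE, PROGRESS, END
```

On the error path, the real handler re-raises after publishing ERROR, so callers always see both the event and the exception.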
Functions
invoke_node_handler

InvokeNodeHandler utilities for PyAgenity agent graph execution.

This module provides the InvokeNodeHandler class, which manages the invocation of node functions and tool nodes within the agent graph. It supports dependency injection, callback hooks, event publishing, and error recovery for both regular and tool-based nodes.

Classes:

Name Description
InvokeNodeHandler

Handles execution of node functions and tool nodes with DI and callbacks.

Usage

handler = InvokeNodeHandler(name, func, publisher)
result = await handler.invoke(config, state)

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
InvokeNodeHandler

Bases: BaseLoggingMixin

Handles invocation of node functions and tool nodes in the agent graph.

Supports dependency injection, callback hooks, event publishing, and error recovery.

Parameters:

Name Type Description Default
name str

Name of the node.

required
func Callable | ToolNode

The function or ToolNode to execute.

required
publisher BasePublisher

Event publisher for execution events.

Inject[BasePublisher]

Methods:

Name Description
__init__
clear_signature_cache

Clear the function signature cache. Useful for testing or memory management.

invoke

Execute the node function or ToolNode with dependency injection and callback hooks.

Attributes:

Name Type Description
func
name
publisher
Source code in pyagenity/graph/utils/invoke_node_handler.py
class InvokeNodeHandler(BaseLoggingMixin):
    """
    Handles invocation of node functions and tool nodes in the agent graph.

    Supports dependency injection, callback hooks, event publishing, and error recovery.

    Args:
        name (str): Name of the node.
        func (Callable | ToolNode): The function or ToolNode to execute.
        publisher (BasePublisher, optional): Event publisher for execution events.
    """

    # Class-level cache for function signatures to avoid repeated inspection
    _signature_cache: dict[Callable, inspect.Signature] = {}

    @classmethod
    def clear_signature_cache(cls) -> None:
        """Clear the function signature cache. Useful for testing or memory management."""
        cls._signature_cache.clear()

    def __init__(
        self,
        name: str,
        func: Union[Callable, "ToolNode"],
        publisher: BasePublisher | None = Inject[BasePublisher],
    ):
        self.name = name
        self.func = func
        self.publisher = publisher

    async def _handle_single_tool(
        self,
        tool_call: dict[str, Any],
        state: AgentState,
        config: dict[str, Any],
    ) -> Message:
        """
        Execute a single tool call using the ToolNode.

        Args:
            tool_call (dict): Tool call specification.
            state (AgentState): Current agent state.
            config (dict): Node configuration.

        Returns:
            Message: Resulting message from tool execution.
        """
        function_name = tool_call.get("function", {}).get("name", "")
        function_args: dict = json.loads(tool_call.get("function", {}).get("arguments", "{}"))
        tool_call_id = tool_call.get("id", "")

        logger.info(
            "Node '%s' executing tool '%s' with %d arguments",
            self.name,
            function_name,
            len(function_args),
        )
        logger.debug("Tool arguments: %s", function_args)

        # Execute the tool function with injectable parameters
        tool_result = await self.func.invoke(  # type: ignore
            function_name,  # type: ignore
            function_args,
            tool_call_id=tool_call_id,
            state=state,
            config=config,
        )
        logger.debug("Node '%s' tool execution completed successfully", self.name)

        return tool_result

    async def _call_tools(
        self,
        last_message: Message,
        state: "AgentState",
        config: dict[str, Any],
    ) -> list[Message]:
        """
        Execute all tool calls present in the last message.

        Args:
            last_message (Message): The last message containing tool calls.
            state (AgentState): Current agent state.
            config (dict): Node configuration.

        Returns:
            list[Message]: List of messages from tool executions.

        Raises:
            NodeError: If no tool calls are present.
        """
        logger.debug("Node '%s' calling tools from message", self.name)
        result: list[Message] = []
        if (
            hasattr(last_message, "tools_calls")
            and last_message.tools_calls
            and len(last_message.tools_calls) > 0
        ):
            # Execute each tool call in order
            for tool_call in last_message.tools_calls:
                res = await self._handle_single_tool(
                    tool_call,
                    state,
                    config,
                )
                result.append(res)
        else:
            # No tool calls present; fail rather than silently continuing
            logger.error("Node '%s': No tool calls to execute", self.name)
            raise NodeError("No tool calls to execute")

        return result

    def _get_cached_signature(self, func: Callable) -> inspect.Signature:
        """Get cached signature for a function, computing it if not cached."""
        if func not in self._signature_cache:
            self._signature_cache[func] = inspect.signature(func)
        return self._signature_cache[func]

    def _prepare_input_data(
        self,
        state: "AgentState",
        config: dict[str, Any],
    ) -> dict:
        """
        Prepare input data for function invocation, handling injectable parameters.
        Uses cached function signature to avoid repeated inspection overhead.

        Args:
            state (AgentState): Current agent state.
            config (dict): Node configuration.

        Returns:
            dict: Input data for function call.

        Raises:
            TypeError: If required parameters are missing.
        """
        # Use cached signature inspection for performance
        sig = self._get_cached_signature(self.func)  # type: ignore[arg-type]  # ToolNode never reaches this path
        input_data = {}
        default_data = {
            "state": state,
            "config": config,
        }

        # Prepare function arguments, excluding injectable parameters
        for param_name, param in sig.parameters.items():
            # Skip *args/**kwargs
            if param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue

            # Inject the known parameters (state, config)
            if param_name in ["state", "config"]:
                input_data[param_name] = default_data[param_name]
            # Include regular function arguments
            elif param.default is inspect.Parameter.empty:
                raise TypeError(
                    f"Missing required parameter '{param_name}' for function '{self.func}'"
                )

        return input_data

    async def _call_normal_node(
        self,
        state: "AgentState",
        config: dict[str, Any],
        callback_mgr: CallbackManager,
    ) -> dict[str, Any]:
        """
        Execute a regular node function with callback hooks and event publishing.

        Args:
            state (AgentState): Current agent state.
            config (dict): Node configuration.
            callback_mgr (CallbackManager): Callback manager for hooks.

        Returns:
            dict: Result containing new state, messages, and next node.

        Raises:
            Exception: If function execution fails and cannot be recovered.
        """
        logger.debug("Node '%s' calling normal function", self.name)
        result: dict[str, Any] = {}

        logger.debug("Node '%s' is a regular function, executing with callbacks", self.name)
        # This is a regular function - likely AI function
        # Create callback context for AI invocation
        context = CallbackContext(
            invocation_type=InvocationType.AI,
            node_name=self.name,
            function_name=getattr(self.func, "__name__", str(self.func)),
            metadata={"config": config},
        )

        # Event publishing logic (similar to stream_node_handler)

        input_data = self._prepare_input_data(
            state,
            config,
        )

        last_message = state.context[-1] if state.context and len(state.context) > 0 else None

        event = EventModel.default(
            config,
            data={"state": state.model_dump()},
            event=Event.NODE_EXECUTION,
            content_type=[ContentType.STATE],
            node_name=self.name,
            extra={
                "node": self.name,
                "function_name": getattr(self.func, "__name__", str(self.func)),
                "last_message": last_message.model_dump() if last_message else None,
            },
        )
        publish_event(event)

        try:
            logger.debug("Node '%s' executing before_invoke callbacks", self.name)
            # Execute before_invoke callbacks
            input_data = await callback_mgr.execute_before_invoke(context, input_data)
            logger.debug("Node '%s' executing function", self.name)
            event.event_type = EventType.PROGRESS
            event.metadata["status"] = "Function execution started"
            publish_event(event)

            # Execute the actual function
            result = await call_sync_or_async(
                self.func,  # type: ignore
                **input_data,
            )
            logger.debug("Node '%s' function execution completed", self.name)

            logger.debug("Node '%s' executing after_invoke callbacks", self.name)
            # Execute after_invoke callbacks
            result = await callback_mgr.execute_after_invoke(context, input_data, result)

            # Process result and publish END event
            messages = []
            new_state, messages, next_node = await process_node_result(result, state, messages)
            event.data["state"] = new_state.model_dump()
            event.event_type = EventType.END
            event.metadata["status"] = "Function execution completed"
            event.data["messages"] = [m.model_dump() for m in messages] if messages else []
            event.data["next_node"] = next_node
            # mirror simple content + structured blocks for the last message
            if messages:
                last = messages[-1]
                event.content = last.text() if isinstance(last.content, list) else last.content
                if isinstance(last.content, list):
                    event.content_blocks = last.content

            publish_event(event)

            return {
                "state": new_state,
                "messages": messages,
                "next_node": next_node,
            }

        except Exception as e:
            logger.warning(
                "Node '%s' execution failed, executing error callbacks: %s", self.name, e
            )
            # Execute error callbacks
            recovery_result = await callback_mgr.execute_on_error(context, input_data, e)

            if recovery_result is not None:
                logger.info(
                    "Node '%s' recovered from error using callback result",
                    self.name,
                )
                # Use recovery result instead of raising the error
                event.event_type = EventType.END
                event.metadata["status"] = "Function execution recovered from error"
                event.data["message"] = recovery_result.model_dump()
                event.content_type = [ContentType.MESSAGE, ContentType.STATE]
                publish_event(event)
                return {
                    "state": state,
                    "messages": [recovery_result],
                    "next_node": None,
                }
            # Re-raise the original error
            logger.error("Node '%s' could not recover from error", self.name)
            event.event_type = EventType.ERROR
            event.metadata["status"] = f"Function execution failed: {e}"
            event.data["error"] = str(e)
            event.content_type = [ContentType.ERROR, ContentType.STATE]
            publish_event(event)
            raise

    async def invoke(
        self,
        config: dict[str, Any],
        state: AgentState,
        callback_mgr: CallbackManager = Inject[CallbackManager],
    ) -> dict[str, Any] | list[Message]:
        """
        Execute the node function or ToolNode with dependency injection and callback hooks.

        Args:
            config (dict): Node configuration.
            state (AgentState): Current agent state.
            callback_mgr (CallbackManager, optional): Callback manager for hooks.

        Returns:
            dict | list[Message]: Result of node execution (regular node or tool node).

        Raises:
            NodeError: If execution fails or context is missing for tool nodes.
        """
        logger.info("Executing node '%s'", self.name)
        logger.debug(
            "Node '%s' execution with state context size=%d, config keys=%s",
            self.name,
            len(state.context) if state.context else 0,
            list(config.keys()) if config else [],
        )

        try:
            if isinstance(self.func, ToolNode):
                logger.debug("Node '%s' is a ToolNode, executing tool calls", self.name)
                # This is tool execution - handled separately in ToolNode
                if state.context and len(state.context) > 0:
                    last_message = state.context[-1]
                    logger.debug("Node '%s' processing tool calls from last message", self.name)
                    result = await self._call_tools(
                        last_message,
                        state,
                        config,
                    )
                else:
                    # No context available; cannot execute tools
                    error_msg = "No context available for tool execution"
                    logger.error("Node '%s': %s", self.name, error_msg)
                    raise NodeError(error_msg)

            else:
                result = await self._call_normal_node(
                    state,
                    config,
                    callback_mgr,
                )

            logger.info("Node '%s' execution completed successfully", self.name)
            return result
        except Exception as e:
            # This is the final catch-all for node execution errors
            logger.exception("Node '%s' execution failed: %s", self.name, e)
            raise NodeError(f"Error in node '{self.name}': {e!s}") from e
Attributes
func instance-attribute
func = func
name instance-attribute
name = name
publisher instance-attribute
publisher = publisher
Functions
__init__
__init__(name, func, publisher=Inject[BasePublisher])
Source code in pyagenity/graph/utils/invoke_node_handler.py
def __init__(
    self,
    name: str,
    func: Union[Callable, "ToolNode"],
    publisher: BasePublisher | None = Inject[BasePublisher],
):
    self.name = name
    self.func = func
    self.publisher = publisher
clear_signature_cache classmethod
clear_signature_cache()

Clear the function signature cache. Useful for testing or memory management.

Source code in pyagenity/graph/utils/invoke_node_handler.py
@classmethod
def clear_signature_cache(cls) -> None:
    """Clear the function signature cache. Useful for testing or memory management."""
    cls._signature_cache.clear()
invoke async
invoke(config, state, callback_mgr=Inject[CallbackManager])

Execute the node function or ToolNode with dependency injection and callback hooks.

Parameters:

Name Type Description Default
config dict

Node configuration.

required
state AgentState

Current agent state.

required
callback_mgr CallbackManager

Callback manager for hooks.

Inject[CallbackManager]

Returns:

Type Description
dict[str, Any] | list[Message]

dict | list[Message]: Result of node execution (regular node or tool node).

Raises:

Type Description
NodeError

If execution fails or context is missing for tool nodes.

Source code in pyagenity/graph/utils/invoke_node_handler.py
async def invoke(
    self,
    config: dict[str, Any],
    state: AgentState,
    callback_mgr: CallbackManager = Inject[CallbackManager],
) -> dict[str, Any] | list[Message]:
    """
    Execute the node function or ToolNode with dependency injection and callback hooks.

    Args:
        config (dict): Node configuration.
        state (AgentState): Current agent state.
        callback_mgr (CallbackManager, optional): Callback manager for hooks.

    Returns:
        dict | list[Message]: Result of node execution (regular node or tool node).

    Raises:
        NodeError: If execution fails or context is missing for tool nodes.
    """
    logger.info("Executing node '%s'", self.name)
    logger.debug(
        "Node '%s' execution with state context size=%d, config keys=%s",
        self.name,
        len(state.context) if state.context else 0,
        list(config.keys()) if config else [],
    )

    try:
        if isinstance(self.func, ToolNode):
            logger.debug("Node '%s' is a ToolNode, executing tool calls", self.name)
            # This is tool execution - handled separately in ToolNode
            if state.context and len(state.context) > 0:
                last_message = state.context[-1]
                logger.debug("Node '%s' processing tool calls from last message", self.name)
                result = await self._call_tools(
                    last_message,
                    state,
                    config,
                )
            else:
            # No context available; cannot execute tools
                error_msg = "No context available for tool execution"
                logger.error("Node '%s': %s", self.name, error_msg)
                raise NodeError(error_msg)

        else:
            result = await self._call_normal_node(
                state,
                config,
                callback_mgr,
            )

        logger.info("Node '%s' execution completed successfully", self.name)
        return result
    except Exception as e:
        # This is the final catch-all for node execution errors
        logger.exception("Node '%s' execution failed: %s", self.name, e)
        raise NodeError(f"Error in node '{self.name}': {e!s}") from e
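When `invoke` dispatches to a ToolNode, `_handle_single_tool` unpacks each tool-call payload from the last message. The shape shown in the source (`function.name`, `function.arguments` as a JSON string, and `id`) follows the OpenAI-style tool-call convention; the values below are made up for illustration:

```python
import json

# Example payload in the shape _handle_single_tool expects (values invented).
tool_call = {
    "id": "call_123",
    "function": {
        "name": "get_weather",
        "arguments": json.dumps({"city": "Paris", "unit": "celsius"}),
    },
}

# The same three extractions performed in _handle_single_tool.
function_name = tool_call.get("function", {}).get("name", "")
function_args: dict = json.loads(tool_call.get("function", {}).get("arguments", "{}"))
tool_call_id = tool_call.get("id", "")
```

Note that `arguments` arrives as a JSON string, not a dict, so a missing or empty field safely decodes to `{}` via the `"{}"` default.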
Functions
stream_handler

Streaming graph execution handler for PyAgenity workflows.

This module provides the StreamHandler class, which manages the execution of graph workflows with support for streaming output, interrupts, state persistence, and event publishing. It enables incremental result processing, pause/resume capabilities, and robust error handling for agent workflows that require real-time or chunked responses.

Classes:

Name Description
StreamHandler

Handles streaming execution for graph workflows in PyAgenity.

Attributes:

Name Type Description
StateT
logger
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
logger module-attribute
logger = getLogger(__name__)
Classes
StreamHandler

Bases: BaseLoggingMixin, InterruptConfigMixin

Handles streaming execution for graph workflows in PyAgenity.

StreamHandler manages the execution of agent workflows as directed graphs, supporting streaming output, pause/resume via interrupts, state persistence, and event publishing for monitoring and debugging. It enables incremental result processing and robust error handling for complex agent workflows.

Attributes:

Name Type Description
nodes dict[str, Node]

Dictionary mapping node names to Node instances.

edges list[Edge]

List of Edge instances defining graph connections and routing.

interrupt_before

List of node names where execution should pause before execution.

interrupt_after

List of node names where execution should pause after execution.

Example
handler = StreamHandler(nodes, edges)
async for chunk in handler.stream(input_data, config, state):
    print(chunk)
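The consumption pattern in the example above relies on `stream` being an async generator. A toy, self-contained illustration of that pattern, where `stream` is a stand-in generator rather than the real handler:

```python
import asyncio
from typing import AsyncIterator


async def stream(input_data: dict) -> AsyncIterator[str]:
    """Stand-in for StreamHandler.stream; yields incremental chunks."""
    for word in input_data["messages"][0].split():
        yield word  # the real handler yields structured chunks/events


async def main() -> list[str]:
    chunks = []
    # Same `async for` consumption shown in the StreamHandler example.
    async for chunk in stream({"messages": ["hello streaming world"]}):
        chunks.append(chunk)
    return chunks


chunks = asyncio.run(main())
```

Because chunks arrive incrementally, callers can render partial output (or abort) before the graph finishes, which is the point of the streaming path over `invoke`.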

Methods:

Name Description
__init__
stream

Execute the graph asynchronously with streaming output.

Source code in pyagenity/graph/utils/stream_handler.py
class StreamHandler[StateT: AgentState](
    BaseLoggingMixin,
    InterruptConfigMixin,
):
    """Handles streaming execution for graph workflows in PyAgenity.

    StreamHandler manages the execution of agent workflows as directed graphs,
    supporting streaming output, pause/resume via interrupts, state persistence,
    and event publishing for monitoring and debugging. It enables incremental
    result processing and robust error handling for complex agent workflows.

    Attributes:
        nodes: Dictionary mapping node names to Node instances.
        edges: List of Edge instances defining graph connections and routing.
        interrupt_before: List of node names where execution pauses before the node runs.
        interrupt_after: List of node names where execution pauses after the node runs.

    Example:
        ```python
        handler = StreamHandler(nodes, edges)
        async for chunk in handler.stream(input_data, config, state):
            print(chunk)
        ```
    """

    @inject
    def __init__(
        self,
        nodes: dict[str, Node],
        edges: list[Edge],
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
    ):
        self.nodes: dict[str, Node] = nodes
        self.edges: list[Edge] = edges
        self.interrupt_before = interrupt_before or []
        self.interrupt_after = interrupt_after or []
        self._set_interrupts(interrupt_before, interrupt_after)

    async def _check_interrupted(
        self,
        state: StateT,
        input_data: dict[str, Any],
        config: dict[str, Any],
    ) -> dict[str, Any]:
        if state.is_interrupted():
            logger.info(
                "Resuming from interrupted state at node '%s'", state.execution_meta.current_node
            )
            # This is a resume case - clear interrupt and merge input data
            if input_data:
                config["resume_data"] = input_data
                logger.debug("Added resume data with %d keys", len(input_data))
            state.clear_interrupt()
        elif not input_data.get("messages") and not state.context:
            # This is a fresh execution - validate input data
            error_msg = "Input data must contain 'messages' for new execution."
            logger.error(error_msg)
            raise ValueError(error_msg)
        else:
            logger.info(
                "Starting fresh execution with %d messages", len(input_data.get("messages", []))
            )

        return config

    async def _check_and_handle_interrupt(
        self,
        current_node: str,
        interrupt_type: str,
        state: StateT,
        config: dict[str, Any],
    ) -> bool:
        """Check for interrupts and save state if needed. Returns True if interrupted."""
        interrupt_nodes: list[str] = (
            self.interrupt_before if interrupt_type == "before" else self.interrupt_after
        ) or []

        if current_node in interrupt_nodes:
            status = (
                ExecutionStatus.INTERRUPTED_BEFORE
                if interrupt_type == "before"
                else ExecutionStatus.INTERRUPTED_AFTER
            )
            state.set_interrupt(
                current_node,
                f"interrupt_{interrupt_type}: {current_node}",
                status,
            )
            # Save state and interrupt
            await sync_data(
                state=state,
                config=config,
                messages=[],
                trim=True,
            )
            logger.debug("Node '%s' interrupted", current_node)
            return True

        logger.debug(
            "No interrupts found for node '%s', continuing execution",
            current_node,
        )
        return False

    async def _check_stop_requested(
        self,
        state: StateT,
        current_node: str,
        event: EventModel,
        messages: list[Message],
        config: dict[str, Any],
    ) -> bool:
        """Check if a stop has been requested externally."""
        state = await reload_state(config, state)  # type: ignore

        # Check if a stop was requested externally (e.g., frontend)
        if state.is_stopped_requested():
            logger.info(
                "Stop requested for thread '%s' at node '%s'",
                config.get("thread_id"),
                current_node,
            )
            state.set_interrupt(
                current_node,
                "stop_requested",
                ExecutionStatus.INTERRUPTED_AFTER,
                data={"source": "stop", "info": "requested via is_stopped_requested"},
            )
            await sync_data(state=state, config=config, messages=messages, trim=True)
            event.event_type = EventType.INTERRUPTED
            event.metadata["interrupted"] = "Stop"
            event.metadata["status"] = "Graph execution stopped by request"
            event.data["state"] = state.model_dump()
            publish_event(event)
            return True
        return False

    async def _execute_graph(  # noqa: PLR0912, PLR0915
        self,
        state: StateT,
        input_data: dict[str, Any],
        config: dict[str, Any],
    ) -> AsyncIterable[Message]:
        """
        Execute the entire graph with support for interrupts and resuming.

        Why are so many chunks yielded? Callers can set the response
        granularity; at low granularity, only a few chunks (e.g., complete
        Message objects) are surfaced to the user.
        """
        logger.info(
            "Starting graph execution from node '%s' at step %d",
            state.execution_meta.current_node,
            state.execution_meta.step,
        )
        messages: list[Message] = []
        messages_ids = set()
        max_steps = config.get("recursion_limit", 25)
        logger.debug("Max steps limit set to %d", max_steps)

        last_human_messages = input_data.get("messages", []) or []
        # Stream initial input messages (e.g., human messages) so callers see full conversation
        # Only emit when present and avoid duplicates by tracking message_ids and existing context
        for m in last_human_messages:
            if m.message_id not in messages_ids:
                messages.append(m)
                messages_ids.add(m.message_id)
                yield m

        # Get current execution info from state
        current_node = state.execution_meta.current_node
        step = state.execution_meta.step

        # Create event for graph execution
        event = EventModel.default(
            config,
            data={"state": state.model_dump(exclude={"execution_meta"})},
            content_type=[ContentType.STATE],
            extra={"step": step, "current_node": current_node},
            event=Event.GRAPH_EXECUTION,
            node_name=current_node,
        )

        try:
            while current_node != END and step < max_steps:
                logger.debug("Executing step %d at node '%s'", step, current_node)

                res = await self._check_stop_requested(
                    state,
                    current_node,
                    event,
                    messages,
                    config,
                )
                if res:
                    return

                # Update execution metadata
                state.set_current_node(current_node)
                state.execution_meta.step = step
                await call_realtime_sync(state, config)

                # Update event with current step info
                event.data["step"] = step
                event.data["current_node"] = current_node
                event.event_type = EventType.PROGRESS
                event.metadata["status"] = f"Executing step {step} at node '{current_node}'"
                publish_event(event)

                # Check for interrupt_before
                if await self._check_and_handle_interrupt(
                    current_node,
                    "before",
                    state,
                    config,
                ):
                    logger.info("Graph execution interrupted before node '%s'", current_node)
                    event.event_type = EventType.INTERRUPTED
                    event.metadata["status"] = "Graph execution interrupted before node execution"
                    event.metadata["interrupted"] = "Before"
                    event.data["interrupted"] = "Before"
                    publish_event(event)
                    return

                # Execute current node
                logger.debug("Executing node '%s'", current_node)
                node = self.nodes[current_node]

                # Node execution
                result = node.stream(config, state)  # type: ignore

                logger.debug("Node '%s' execution completed", current_node)

                res = await self._check_stop_requested(
                    state,
                    current_node,
                    event,
                    messages,
                    config,
                )
                if res:
                    return

                # Process result and get next node
                next_node = None
                async for rs in result:
                    # Allow stop to break inner result loop as well
                    if isinstance(rs, Message) and rs.delta:
                        # Yield delta messages immediately for streaming
                        yield rs

                    elif isinstance(rs, Message) and not rs.delta:
                        yield rs

                        if rs.message_id not in messages_ids:
                            messages.append(rs)
                            messages_ids.add(rs.message_id)

                    elif isinstance(rs, dict) and "is_non_streaming" in rs:
                        if rs["is_non_streaming"]:
                            state = rs.get("state", state)
                            new_messages = rs.get("messages", [])
                            for m in new_messages:
                                if m.message_id not in messages_ids and not m.delta:
                                    messages.append(m)
                                    messages_ids.add(m.message_id)
                                yield m
                            next_node = rs.get("next_node", next_node)
                        else:
                            # Streaming path completed: ensure any collected messages are persisted
                            new_messages = rs.get("messages", [])
                            for m in new_messages:
                                if m.message_id not in messages_ids and not m.delta:
                                    messages.append(m)
                                    messages_ids.add(m.message_id)
                                    yield m
                            next_node = rs.get("next_node", next_node)
                    else:
                        # Process as node result (non-streaming path)
                        try:
                            state, new_messages, next_node = await process_node_result(
                                rs,
                                state,
                                [],
                            )
                            for m in new_messages:
                                if m.message_id not in messages_ids and not m.delta:
                                    messages.append(m)
                                    messages_ids.add(m.message_id)
                                    state.context = add_messages(state.context, [m])
                                    yield m
                        except Exception as e:
                            logger.error("Failed to process node result: %s", e)

                logger.debug(
                    "Node result processed, next_node=%s, total_messages=%d",
                    next_node,
                    len(messages),
                )

                # Add collected messages to state context
                if messages:
                    state.context = add_messages(state.context, messages)
                    logger.debug("Added %d messages to state context", len(messages))

                # Call realtime sync after node execution
                await call_realtime_sync(state, config)
                event.event_type = EventType.UPDATE
                event.data["state"] = state.model_dump()
                event.data["messages"] = [m.model_dump() for m in messages] if messages else []
                if messages:
                    lm = messages[-1]
                    event.content = lm.text() if isinstance(lm.content, list) else lm.content
                    if isinstance(lm.content, list):
                        event.content_blocks = lm.content
                event.content_type = [ContentType.STATE, ContentType.MESSAGE]
                publish_event(event)

                # Check for interrupt_after
                if await self._check_and_handle_interrupt(
                    current_node,
                    "after",
                    state,
                    config,
                ):
                    logger.info("Graph execution interrupted after node '%s'", current_node)
                    # For interrupt_after, advance to next node before pausing
                    if next_node is None:
                        next_node = get_next_node(current_node, state, self.edges)
                    state.set_current_node(next_node)

                    event.event_type = EventType.INTERRUPTED
                    event.data["interrupted"] = "After"
                    event.metadata["interrupted"] = "After"
                    event.data["state"] = state.model_dump()
                    publish_event(event)
                    return

                # Get next node
                if next_node is None:
                    current_node = get_next_node(current_node, state, self.edges)
                    logger.debug("Next node determined by graph logic: '%s'", current_node)
                else:
                    current_node = next_node
                    logger.debug("Next node determined by command: '%s'", current_node)

                # Advance step after successful node execution
                step += 1
                state.advance_step()
                await call_realtime_sync(state, config)

                event.event_type = EventType.UPDATE
                event.metadata["State_Updated"] = "State Updated"
                event.data["state"] = state.model_dump()
                publish_event(event)

                if step >= max_steps:
                    error_msg = "Graph execution exceeded maximum steps"
                    logger.error(error_msg)
                    state.error(error_msg)
                    await call_realtime_sync(state, config)

                    event.event_type = EventType.ERROR
                    event.data["state"] = state.model_dump()
                    event.metadata["error"] = error_msg
                    event.metadata["step"] = step
                    event.metadata["current_node"] = current_node
                    publish_event(event)

                    yield Message(
                        role="assistant",
                        content=[ErrorBlock(text=error_msg)],  # type: ignore
                    )

                    raise GraphRecursionError(
                        f"Graph execution exceeded recursion limit: {max_steps}"
                    )

            # Execution completed successfully
            logger.info(
                "Graph execution completed successfully at node '%s' after %d steps",
                current_node,
                step,
            )
            state.complete()
            is_context_trimmed = await sync_data(
                state=state,
                config=config,
                messages=messages,
                trim=True,
            )

            # Create completion event
            event.event_type = EventType.END
            event.data["state"] = state.model_dump()
            event.data["messages"] = [m.model_dump() for m in messages] if messages else []
            if messages:
                fm = messages[-1]
                event.content = fm.text() if isinstance(fm.content, list) else fm.content
                if isinstance(fm.content, list):
                    event.content_blocks = fm.content
            event.content_type = [ContentType.STATE, ContentType.MESSAGE]
            event.metadata["status"] = "Graph execution completed"
            event.metadata["step"] = step
            event.metadata["current_node"] = current_node
            event.metadata["is_context_trimmed"] = is_context_trimmed
            publish_event(event)

        except Exception as e:
            # Handle execution errors
            logger.exception("Graph execution failed: %s", e)
            state.error(str(e))

            # Publish error event
            event.event_type = EventType.ERROR
            event.metadata["error"] = str(e)
            event.data["state"] = state.model_dump()
            publish_event(event)

            await sync_data(
                state=state,
                config=config,
                messages=messages,
                trim=True,
            )
            raise

    async def stream(
        self,
        input_data: dict[str, Any],
        config: dict[str, Any],
        default_state: StateT,
        response_granularity: ResponseGranularity = ResponseGranularity.LOW,
    ) -> AsyncGenerator[Message]:
        """Execute the graph asynchronously with streaming output.

        Runs the graph workflow from start to finish, yielding incremental results
        as they become available. Automatically detects whether to start a fresh
        execution or resume from an interrupted state, supporting pause/resume
        and checkpointing.

        Args:
            input_data: Input dictionary for graph execution. For new executions,
                should contain 'messages' key with initial messages. For resumed
                executions, can contain additional data to merge.
            config: Configuration dictionary containing execution settings and context.
            default_state: Initial or template AgentState for workflow execution.
            response_granularity: Level of detail in the response (LOW, PARTIAL, FULL).

        Yields:
            Message objects representing incremental results from graph execution.
            The exact type and frequency of yields depends on node implementations
            and workflow configuration.

        Raises:
            GraphRecursionError: If execution exceeds recursion limit.
            ValueError: If input_data is invalid for new execution.
            Various exceptions: Depending on node execution failures.

        Example:
            ```python
            async for chunk in handler.stream(input_data, config, state):
                print(chunk)
            ```
        """
        logger.info(
            "Starting asynchronous graph execution with %d input keys, granularity=%s",
            len(input_data) if input_data else 0,
            response_granularity,
        )
        config = config or {}
        input_data = input_data or {}

        start_time = time.time()

        # Load or initialize state
        logger.debug("Loading or creating state from input data")
        new_state = await load_or_create_state(
            input_data,
            config,
            default_state,
        )
        state: StateT = new_state  # type: ignore[assignment]
        logger.debug(
            "State loaded: interrupted=%s, current_node=%s, step=%d",
            state.is_interrupted(),
            state.execution_meta.current_node,
            state.execution_meta.step,
        )

        cfg = config.copy()
        if "user" in cfg:
            # The "user" key is present when calling via the PyAgenity API;
            # strip it from the copy used for event metadata.
            del cfg["user"]

        event = EventModel.default(
            config,
            data={"state": state},
            content_type=[ContentType.STATE],
            extra={
                "is_interrupted": state.is_interrupted(),
                "current_node": state.execution_meta.current_node,
                "step": state.execution_meta.step,
                "config": cfg,
                "response_granularity": response_granularity.value,
            },
        )

        # Publish graph initialization event
        publish_event(event)

        # Check if this is a resume case
        config = await self._check_interrupted(state, input_data, config)

        # Execute the graph
        logger.debug("Beginning graph execution")
        result = self._execute_graph(state, input_data, config)
        async for chunk in result:
            yield chunk

        # Publish graph completion event
        time_taken = time.time() - start_time
        logger.info("Graph execution finished in %.2f seconds", time_taken)

        event.event_type = EventType.END
        event.metadata.update(
            {
                "time_taken": time_taken,
                "state": state.model_dump(),
                "step": state.execution_meta.step,
                "current_node": state.execution_meta.current_node,
                "is_interrupted": state.is_interrupted(),
                "total_messages": len(state.context) if state.context else 0,
            }
        )
        publish_event(event)
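The inner loop of `_execute_graph` follows a simple contract: delta messages are yielded immediately for streaming, while complete messages are yielded once, de-duplicated by `message_id`, and collected for persistence. The following is an illustrative, self-contained sketch of that pattern (the `Chunk` type and `node_output` generator are hypothetical stand-ins, not PyAgenity API):

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Chunk:
    message_id: str
    text: str
    delta: bool  # True for incremental fragments, False for complete messages


async def node_output():
    # A node that streams two deltas, then emits its final message twice.
    yield Chunk("m1", "Hel", True)
    yield Chunk("m1", "lo", True)
    yield Chunk("m1", "Hello", False)
    yield Chunk("m1", "Hello", False)  # duplicate final message


async def dedup_stream(source):
    seen_ids: set[str] = set()
    async for chunk in source:
        if chunk.delta:
            yield chunk  # stream deltas as they arrive, never de-duplicated
        elif chunk.message_id not in seen_ids:
            seen_ids.add(chunk.message_id)  # complete messages pass through once
            yield chunk


async def _collect():
    return [c async for c in dedup_stream(node_output())]


emitted = asyncio.run(_collect())
```

The duplicate final message is dropped, so `emitted` contains two deltas followed by a single complete message.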
Attributes
edges instance-attribute
edges = edges
interrupt_after instance-attribute
interrupt_after = interrupt_after or []
interrupt_before instance-attribute
interrupt_before = interrupt_before or []
nodes instance-attribute
nodes = nodes
Functions
__init__
__init__(nodes, edges, interrupt_before=None, interrupt_after=None)
Source code in pyagenity/graph/utils/stream_handler.py
@inject
def __init__(
    self,
    nodes: dict[str, Node],
    edges: list[Edge],
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
):
    self.nodes: dict[str, Node] = nodes
    self.edges: list[Edge] = edges
    self.interrupt_before = interrupt_before or []
    self.interrupt_after = interrupt_after or []
    self._set_interrupts(interrupt_before, interrupt_after)
stream async
stream(input_data, config, default_state, response_granularity=ResponseGranularity.LOW)

Execute the graph asynchronously with streaming output.

Runs the graph workflow from start to finish, yielding incremental results as they become available. Automatically detects whether to start a fresh execution or resume from an interrupted state, supporting pause/resume and checkpointing.

Parameters:

Name Type Description Default
input_data dict[str, Any]

Input dictionary for graph execution. For new executions, should contain 'messages' key with initial messages. For resumed executions, can contain additional data to merge.

required
config dict[str, Any]

Configuration dictionary containing execution settings and context.

required
default_state StateT

Initial or template AgentState for workflow execution.

required
response_granularity ResponseGranularity

Level of detail in the response (LOW, PARTIAL, FULL).

LOW

Yields:

Type Description
AsyncGenerator[Message]

Message objects representing incremental results from graph execution. The exact type and frequency of yields depends on node implementations and workflow configuration.

Raises:

Type Description
GraphRecursionError

If execution exceeds recursion limit.

ValueError

If input_data is invalid for new execution.

Various exceptions

Depending on node execution failures.

Example
async for chunk in handler.stream(input_data, config, state):
    print(chunk)
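The fresh-versus-resume decision described above reduces to a small branch: an interrupted state resumes (stashing any caller-supplied data as `resume_data`), while a fresh run with neither messages nor prior context is rejected. A minimal sketch of that validation, with an illustrative `prepare_config` helper that is not part of the PyAgenity API:

```python
def prepare_config(is_interrupted: bool, has_context: bool,
                   input_data: dict, config: dict) -> dict:
    if is_interrupted:
        # Resume case: carry caller-supplied data to the paused node.
        if input_data:
            config["resume_data"] = input_data
        return config
    if not input_data.get("messages") and not has_context:
        # Fresh execution with no messages and no prior context is invalid.
        raise ValueError("Input data must contain 'messages' for new execution.")
    return config  # fresh execution with valid input


resumed = prepare_config(True, False, {"approval": "yes"}, {})
fresh = prepare_config(False, False, {"messages": ["hi"]}, {"recursion_limit": 25})
```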
Source code in pyagenity/graph/utils/stream_handler.py
async def stream(
    self,
    input_data: dict[str, Any],
    config: dict[str, Any],
    default_state: StateT,
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
) -> AsyncGenerator[Message]:
    """Execute the graph asynchronously with streaming output.

    Runs the graph workflow from start to finish, yielding incremental results
    as they become available. Automatically detects whether to start a fresh
    execution or resume from an interrupted state, supporting pause/resume
    and checkpointing.

    Args:
        input_data: Input dictionary for graph execution. For new executions,
            should contain 'messages' key with initial messages. For resumed
            executions, can contain additional data to merge.
        config: Configuration dictionary containing execution settings and context.
        default_state: Initial or template AgentState for workflow execution.
        response_granularity: Level of detail in the response (LOW, PARTIAL, FULL).

    Yields:
        Message objects representing incremental results from graph execution.
        The exact type and frequency of yields depends on node implementations
        and workflow configuration.

    Raises:
        GraphRecursionError: If execution exceeds recursion limit.
        ValueError: If input_data is invalid for new execution.
        Various exceptions: Depending on node execution failures.

    Example:
        ```python
        async for chunk in handler.stream(input_data, config, state):
            print(chunk)
        ```
    """
    logger.info(
        "Starting asynchronous graph execution with %d input keys, granularity=%s",
        len(input_data) if input_data else 0,
        response_granularity,
    )
    config = config or {}
    input_data = input_data or {}

    start_time = time.time()

    # Load or initialize state
    logger.debug("Loading or creating state from input data")
    new_state = await load_or_create_state(
        input_data,
        config,
        default_state,
    )
    state: StateT = new_state  # type: ignore[assignment]
    logger.debug(
        "State loaded: interrupted=%s, current_node=%s, step=%d",
        state.is_interrupted(),
        state.execution_meta.current_node,
        state.execution_meta.step,
    )

    cfg = config.copy()
    if "user" in cfg:
        # The "user" key is present when calling via the PyAgenity API;
        # strip it from the copy used for event metadata.
        del cfg["user"]

    event = EventModel.default(
        config,
        data={"state": state},
        content_type=[ContentType.STATE],
        extra={
            "is_interrupted": state.is_interrupted(),
            "current_node": state.execution_meta.current_node,
            "step": state.execution_meta.step,
            "config": cfg,
            "response_granularity": response_granularity.value,
        },
    )

    # Publish graph initialization event
    publish_event(event)

    # Check if this is a resume case
    config = await self._check_interrupted(state, input_data, config)

    # Execute the graph
    logger.debug("Beginning graph execution")
    result = self._execute_graph(state, input_data, config)
    async for chunk in result:
        yield chunk

    # Publish graph completion event
    time_taken = time.time() - start_time
    logger.info("Graph execution finished in %.2f seconds", time_taken)

    event.event_type = EventType.END
    event.metadata.update(
        {
            "time_taken": time_taken,
            "state": state.model_dump(),
            "step": state.execution_meta.step,
            "current_node": state.execution_meta.current_node,
            "is_interrupted": state.is_interrupted(),
            "total_messages": len(state.context) if state.context else 0,
        }
    )
    publish_event(event)
Functions
stream_node_handler

Streaming node handler for PyAgenity graph workflows.

This module provides the StreamNodeHandler class, which manages the execution of graph nodes that support streaming output. It handles both regular function nodes and ToolNode instances, enabling incremental result processing, dependency injection, callback management, and event publishing.

StreamNodeHandler is a key component for enabling real-time, chunked, or incremental responses in agent workflows, supporting both synchronous and asynchronous execution patterns.
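A streaming node can be pictured as an async generator that yields incremental delta chunks and finishes with a complete message for persistence. The sketch below shows that contract with an assumed shape (plain dicts instead of PyAgenity's `Message` type, and a hypothetical `process_node` function):

```python
import asyncio


async def process_node(config: dict, state: dict):
    # Yield incremental deltas as output is produced...
    parts = ["Str", "eam", "ing"]
    for p in parts:
        yield {"delta": True, "text": p}
    # ...then a final, complete message for the handler to persist.
    yield {"delta": False, "text": "".join(parts)}


async def _run():
    return [c async for c in process_node({}, {})]


chunks = asyncio.run(_run())
```

Consumers such as StreamNodeHandler can forward the deltas immediately and treat the final non-delta chunk as the node's durable result.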

Classes:

Name Description
StreamNodeHandler

Handles streaming execution for graph nodes in PyAgenity workflows.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
StreamNodeHandler

Bases: BaseLoggingMixin

Handles streaming execution for graph nodes in PyAgenity workflows.

StreamNodeHandler manages the execution of nodes that can produce streaming output, including both regular function nodes and ToolNode instances. It supports dependency injection, callback management, event publishing, and incremental result processing.

Attributes:

Name Type Description
name

Unique identifier for the node within the graph.

func

The function or ToolNode to execute. Determines streaming behavior.

Example
handler = StreamNodeHandler("process", process_function)
async for chunk in handler.stream(config, state):
    print(chunk)

Methods:

Name Description
__init__

Initialize a new StreamNodeHandler instance.

stream

Execute the node function with streaming output and callback support.

Source code in pyagenity/graph/utils/stream_node_handler.py
class StreamNodeHandler(BaseLoggingMixin):
    """Handles streaming execution for graph nodes in PyAgenity workflows.

    StreamNodeHandler manages the execution of nodes that can produce streaming output,
    including both regular function nodes and ToolNode instances. It supports dependency
    injection, callback management, event publishing, and incremental result processing.

    Attributes:
        name: Unique identifier for the node within the graph.
        func: The function or ToolNode to execute. Determines streaming behavior.

    Example:
        ```python
        handler = StreamNodeHandler("process", process_function)
        async for chunk in handler.stream(config, state):
            print(chunk)
        ```
    """

    def __init__(
        self,
        name: str,
        func: Union[Callable, "ToolNode"],
    ):
        """Initialize a new StreamNodeHandler instance.

        Args:
            name: Unique identifier for the node within the graph.
            func: The function or ToolNode to execute. Determines streaming behavior.
        """
        self.name = name
        self.func = func

    async def _handle_single_tool(
        self,
        tool_call: dict[str, Any],
        state: AgentState,
        config: dict[str, Any],
    ) -> AsyncIterable[Message]:
        function_name = tool_call.get("function", {}).get("name", "")
        function_args: dict = json.loads(tool_call.get("function", {}).get("arguments", "{}"))
        tool_call_id = tool_call.get("id", "")

        logger.info(
            "Node '%s' executing tool '%s' with %d arguments",
            self.name,
            function_name,
            len(function_args),
        )
        logger.debug("Tool arguments: %s", function_args)

        # Execute the tool function with injectable parameters
        tool_result_gen = self.func.stream(  # type: ignore
            function_name,  # type: ignore
            function_args,
            tool_call_id=tool_call_id,
            state=state,
            config=config,
        )
        logger.debug("Node '%s' tool execution completed successfully", self.name)

        async for result in tool_result_gen:
            if isinstance(result, Message):
                yield result

    async def _call_tools(
        self,
        last_message: Message,
        state: "AgentState",
        config: dict[str, Any],
    ) -> AsyncIterable[Message]:
        logger.debug("Node '%s' calling tools from message", self.name)
        if (
            hasattr(last_message, "tools_calls")
            and last_message.tools_calls
            and len(last_message.tools_calls) > 0
        ):
            # Execute tool calls
            for tool_call in last_message.tools_calls:
                result_gen = self._handle_single_tool(
                    tool_call,
                    state,
                    config,
                )
                async for result in result_gen:
                    if isinstance(result, Message):
                        yield result
        else:
            # No tool calls to execute
            logger.error("Node '%s': No tool calls to execute", self.name)
            raise NodeError("No tool calls to execute")

    def _prepare_input_data(
        self,
        state: "AgentState",
        config: dict[str, Any],
    ) -> dict:
        sig = inspect.signature(self.func)  # type: ignore Tool node won't come here
        input_data = {}
        default_data = {
            "state": state,
            "config": config,
        }

        # # Get injectable parameters to determine which ones to exclude from manual passing
        # # Prepare function arguments (excluding injectable parameters)
        for param_name, param in sig.parameters.items():
            # Skip *args/**kwargs
            if param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue

            # check its state, config
            if param_name in ["state", "config"]:
                input_data[param_name] = default_data[param_name]
            # Include regular function arguments
            elif param.default is inspect.Parameter.empty:
                raise TypeError(
                    f"Missing required parameter '{param_name}' for function '{self.func}'"
                )

        return input_data

    async def _call_normal_node(  # noqa: PLR0912, PLR0915
        self,
        state: "AgentState",
        config: dict[str, Any],
        callback_mgr: CallbackManager,
    ) -> AsyncIterable[dict[str, Any] | Message]:
        logger.debug("Node '%s' calling normal function", self.name)
        result: dict[str, Any] | Message = {}

        logger.debug("Node '%s' is a regular function, executing with callbacks", self.name)
        # This is a regular function - likely AI function
        # Create callback context for AI invocation
        context = CallbackContext(
            invocation_type=InvocationType.AI,
            node_name=self.name,
            function_name=getattr(self.func, "__name__", str(self.func)),
            metadata={"config": config},
        )

        # Execute before_invoke callbacks
        input_data = self._prepare_input_data(
            state,
            config,
        )

        last_message = state.context[-1] if state.context and len(state.context) > 0 else None

        event = EventModel.default(
            config,
            data={"state": state.model_dump()},
            event=Event.NODE_EXECUTION,
            content_type=[ContentType.STATE],
            node_name=self.name,
            extra={
                "node": self.name,
                "function_name": getattr(self.func, "__name__", str(self.func)),
                "last_message": last_message.model_dump() if last_message else None,
            },
        )
        publish_event(event)

        try:
            logger.debug("Node '%s' executing before_invoke callbacks", self.name)
            # Execute before_invoke callbacks
            input_data = await callback_mgr.execute_before_invoke(context, input_data)
            logger.debug("Node '%s' executing function", self.name)
            event.event_type = EventType.PROGRESS
            event.content = "Function execution started"
            publish_event(event)

            # Execute the actual function
            result = await call_sync_or_async(
                self.func,  # type: ignore
                **input_data,
            )
            logger.debug("Node '%s' function execution completed", self.name)

            logger.debug("Node '%s' executing after_invoke callbacks", self.name)
            # Execute after_invoke callbacks
            result = await callback_mgr.execute_after_invoke(context, input_data, result)

            # Now lets convert the response here only, upstream will be easy to handle
            ##############################################################################
            ################### Logics for streaming ##########################
            ##############################################################################
            """
            Check user sending command or not
            if command then we will check its streaming or not
            if streaming then we will yield from converter stream
            if not streaming then we will convert it and yield end event
            if its not command then we will check its streaming or not
            if streaming then we will yield from converter stream
            if not streaming then we will convert it and yield end event
            """
            # first check its sync and not streaming
            next_node = None
            final_result = result
            # if type of command then we will update it
            if isinstance(result, Command):
                # now check the updated
                if result.update:
                    final_result = result.update

                if result.state:
                    state = result.state
                    for msg in state.context:
                        yield msg

                next_node = result.goto

            messages = []
            if check_non_streaming(final_result):
                new_state, messages, next_node = await process_node_result(
                    final_result,
                    state,
                    messages,
                )
                event.data["state"] = new_state.model_dump()
                event.event_type = EventType.END
                event.data["messages"] = [m.model_dump() for m in messages] if messages else []
                event.data["next_node"] = next_node
                publish_event(event)
                for m in messages:
                    yield m

                yield {
                    "is_non_streaming": True,
                    "state": new_state,
                    "messages": messages,
                    "next_node": next_node,
                }
                return  # done

            # If the result is a ConverterCall with stream=True, use the converter
            if isinstance(result, ModelResponseConverter) and result.response:
                stream_gen = result.stream(
                    config,
                    node_name=self.name,
                    meta={
                        "function_name": getattr(self.func, "__name__", str(self.func)),
                    },
                )
                # this will return event_model or message
                async for item in stream_gen:
                    if isinstance(item, Message) and not item.delta:
                        messages.append(item)
                    yield item
            # Things are done, so publish event and yield final response
            event.event_type = EventType.END
            if messages:
                final_msg = messages[-1]
                event.data["message"] = final_msg.model_dump()
                # Populate simple content and structured blocks when available
                event.content = (
                    final_msg.text() if isinstance(final_msg.content, list) else final_msg.content
                )
                if isinstance(final_msg.content, list):
                    event.content_blocks = final_msg.content
            else:
                event.data["message"] = None
                event.content = ""
                event.content_blocks = None
            event.content_type = [ContentType.MESSAGE, ContentType.STATE]
            publish_event(event)
            # if user use command and its streaming in that case we need to handle next node also
            yield {
                "is_non_streaming": False,
                "messages": messages,
                "next_node": next_node,
            }

        except Exception as e:
            logger.warning(
                "Node '%s' execution failed, executing error callbacks: %s", self.name, e
            )
            # Execute error callbacks
            recovery_result = await callback_mgr.execute_on_error(context, input_data, e)

            if isinstance(recovery_result, Message):
                logger.info(
                    "Node '%s' recovered from error using callback result",
                    self.name,
                )
                # Use recovery result instead of raising the error
                event.event_type = EventType.END
                event.content = "Function execution recovered from error"
                event.data["message"] = recovery_result.model_dump()
                event.content_type = [ContentType.MESSAGE, ContentType.STATE]
                publish_event(event)

                yield recovery_result
            else:
                # Re-raise the original error
                logger.error("Node '%s' could not recover from error", self.name)
                event.event_type = EventType.ERROR
                event.content = f"Function execution failed: {e}"
                event.data["error"] = str(e)
                event.content_type = [ContentType.ERROR, ContentType.STATE]
                publish_event(event)
                raise

    async def stream(
        self,
        config: dict[str, Any],
        state: AgentState,
        callback_mgr: CallbackManager = Inject[CallbackManager],
    ) -> AsyncGenerator[dict[str, Any] | Message]:
        """Execute the node function with streaming output and callback support.

        Handles both ToolNode and regular function nodes, yielding incremental results
        as they become available. Supports dependency injection, callback management,
        and event publishing for monitoring and debugging.

        Args:
            config: Configuration dictionary containing execution context and settings.
            state: Current AgentState providing workflow context and shared state.
            callback_mgr: Callback manager for pre/post execution hook handling.

        Yields:
            Dictionary objects or Message instances representing incremental outputs
            from the node function. The exact type and frequency of yields depends on
            the node function's streaming implementation.

        Raises:
            NodeError: If node execution fails or encounters an error.

        Example:
            ```python
            async for chunk in handler.stream(config, state):
                print(chunk)
            ```
        """
        logger.info("Executing node '%s'", self.name)
        logger.debug(
            "Node '%s' execution with state context size=%d, config keys=%s",
            self.name,
            len(state.context) if state.context else 0,
            list(config.keys()) if config else [],
        )

        # In this function publishing events not required
        # If its tool node, its already handled there, from start to end
        # In this class we need to handle normal function calls only
        # We will yield events from here only for normal function calls
        # ToolNode will yield events from its own stream method

        try:
            if isinstance(self.func, ToolNode):
                logger.debug("Node '%s' is a ToolNode, executing tool calls", self.name)
                # This is tool execution - handled separately in ToolNode
                if state.context and len(state.context) > 0:
                    last_message = state.context[-1]
                    logger.debug("Node '%s' processing tool calls from last message", self.name)
                    result = self._call_tools(
                        last_message,
                        state,
                        config,
                    )
                    async for item in result:
                        yield item
                    # Check if last message has tool calls to execute
                else:
                    # No context, return available tools
                    error_msg = "No context available for tool execution"
                    logger.error("Node '%s': %s", self.name, error_msg)
                    raise NodeError(error_msg)

            else:
                result = self._call_normal_node(
                    state,
                    config,
                    callback_mgr,
                )
                async for item in result:
                    yield item

            logger.info("Node '%s' execution completed successfully", self.name)
        except Exception as e:
            # This is the final catch-all for node execution errors
            logger.exception("Node '%s' execution failed: %s", self.name, e)
            raise NodeError(f"Error in node '{self.name}': {e!s}") from e
Attributes
func instance-attribute
func = func
name instance-attribute
name = name
Functions
__init__
__init__(name, func)

Initialize a new StreamNodeHandler instance.

Parameters:

Name Type Description Default
name str

Unique identifier for the node within the graph.

required
func Union[Callable, ToolNode]

The function or ToolNode to execute. Determines streaming behavior.

required
Source code in pyagenity/graph/utils/stream_node_handler.py
def __init__(
    self,
    name: str,
    func: Union[Callable, "ToolNode"],
):
    """Initialize a new StreamNodeHandler instance.

    Args:
        name: Unique identifier for the node within the graph.
        func: The function or ToolNode to execute. Determines streaming behavior.
    """
    self.name = name
    self.func = func
stream async
stream(config, state, callback_mgr=Inject[CallbackManager])

Execute the node function with streaming output and callback support.

Handles both ToolNode and regular function nodes, yielding incremental results as they become available. Supports dependency injection, callback management, and event publishing for monitoring and debugging.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dictionary containing execution context and settings.

required
state AgentState

Current AgentState providing workflow context and shared state.

required
callback_mgr CallbackManager

Callback manager for pre/post execution hook handling.

Inject[CallbackManager]

Yields:

Type Description
AsyncGenerator[dict[str, Any] | Message]

Dictionary objects or Message instances representing incremental outputs from the node function. The exact type and frequency of yields depends on the node function's streaming implementation.

Raises:

Type Description
NodeError

If node execution fails or encounters an error.

Example
async for chunk in handler.stream(config, state):
    print(chunk)
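In practice, callers usually separate the incremental chunks from the final control dictionary that `stream` yields last. A minimal self-contained sketch (using a stand-in async generator, not the real `StreamNodeHandler`) of that consumption pattern:

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Chunk:
    """Stand-in for an incremental Message chunk."""

    text: str


async def fake_stream(config, state):
    """Stand-in for StreamNodeHandler.stream: yields chunks, then a summary dict."""
    for word in ["hello", "world"]:
        yield Chunk(text=word)
    yield {"is_non_streaming": False, "messages": ["hello world"], "next_node": None}


async def consume():
    chunks, summary = [], None
    async for item in fake_stream(config={}, state=None):
        if isinstance(item, dict):  # final control payload
            summary = item
        else:  # incremental output
            chunks.append(item.text)
    return chunks, summary


chunks, summary = asyncio.run(consume())
```

The dict/Message type check mirrors how the handler distinguishes control payloads from content in its own yields.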
Source code in pyagenity/graph/utils/stream_node_handler.py
async def stream(
    self,
    config: dict[str, Any],
    state: AgentState,
    callback_mgr: CallbackManager = Inject[CallbackManager],
) -> AsyncGenerator[dict[str, Any] | Message]:
    """Execute the node function with streaming output and callback support.

    Handles both ToolNode and regular function nodes, yielding incremental results
    as they become available. Supports dependency injection, callback management,
    and event publishing for monitoring and debugging.

    Args:
        config: Configuration dictionary containing execution context and settings.
        state: Current AgentState providing workflow context and shared state.
        callback_mgr: Callback manager for pre/post execution hook handling.

    Yields:
        Dictionary objects or Message instances representing incremental outputs
        from the node function. The exact type and frequency of yields depends on
        the node function's streaming implementation.

    Raises:
        NodeError: If node execution fails or encounters an error.

    Example:
        ```python
        async for chunk in handler.stream(config, state):
            print(chunk)
        ```
    """
    logger.info("Executing node '%s'", self.name)
    logger.debug(
        "Node '%s' execution with state context size=%d, config keys=%s",
        self.name,
        len(state.context) if state.context else 0,
        list(config.keys()) if config else [],
    )

    # In this function publishing events not required
    # If its tool node, its already handled there, from start to end
    # In this class we need to handle normal function calls only
    # We will yield events from here only for normal function calls
    # ToolNode will yield events from its own stream method

    try:
        if isinstance(self.func, ToolNode):
            logger.debug("Node '%s' is a ToolNode, executing tool calls", self.name)
            # This is tool execution - handled separately in ToolNode
            if state.context and len(state.context) > 0:
                last_message = state.context[-1]
                logger.debug("Node '%s' processing tool calls from last message", self.name)
                result = self._call_tools(
                    last_message,
                    state,
                    config,
                )
                async for item in result:
                    yield item
                # Check if last message has tool calls to execute
            else:
                # No context, return available tools
                error_msg = "No context available for tool execution"
                logger.error("Node '%s': %s", self.name, error_msg)
                raise NodeError(error_msg)

        else:
            result = self._call_normal_node(
                state,
                config,
                callback_mgr,
            )
            async for item in result:
                yield item

        logger.info("Node '%s' execution completed successfully", self.name)
    except Exception as e:
        # This is the final catch-all for node execution errors
        logger.exception("Node '%s' execution failed: %s", self.name, e)
        raise NodeError(f"Error in node '{self.name}': {e!s}") from e
Functions
stream_utils

Streaming utility functions for PyAgenity graph workflows.

This module provides helper functions for determining whether a result from a node or tool execution should be treated as non-streaming (i.e., a complete result) or processed incrementally as a stream. These utilities are used throughout the graph execution engine to support both synchronous and streaming workflows.

Functions:

Name Description
check_non_streaming

Determine if a result should be treated as non-streaming.

Classes Functions
check_non_streaming
check_non_streaming(result)

Determine if a result should be treated as non-streaming.

Checks whether the given result is a complete, non-streaming output (such as a list, dict, string, Message, or AgentState) or if it should be processed incrementally as a stream.

Parameters:

Name Type Description Default
result

The result object returned from a node or tool execution. Can be any type.

required

Returns:

Name Type Description
bool bool

True if the result is non-streaming and should be processed as a complete output; False if the result should be handled as a stream.

Example

>>> check_non_streaming([Message.text_message("done")])
True
>>> check_non_streaming(Message.text_message("done"))
True
>>> check_non_streaming({"choices": [...]})
True
>>> check_non_streaming("some text")
True

Source code in pyagenity/graph/utils/stream_utils.py
def check_non_streaming(result) -> bool:
    """Determine if a result should be treated as non-streaming.

    Checks whether the given result is a complete, non-streaming output (such as a list,
    dict, string, Message, or AgentState) or if it should be processed incrementally as a stream.

    Args:
        result: The result object returned from a node or tool execution. Can be any type.

    Returns:
        bool: True if the result is non-streaming and should be processed as a complete output;
        False if the result should be handled as a stream.

    Example:
        >>> check_non_streaming([Message.text_message("done")])
        True
        >>> check_non_streaming(Message.text_message("done"))
        True
        >>> check_non_streaming({"choices": [...]})
        True
        >>> check_non_streaming("some text")
        True
    """
    if isinstance(result, list | dict | str):
        return True

    if isinstance(result, Message):
        return True

    if isinstance(result, AgentState):
        return True

    if isinstance(result, dict) and "choices" in result:
        return True

    return bool(isinstance(result, Message))
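As the checks above imply, any list, dict, or string is classified as complete before the dict-specific `choices` branch is ever reached. A simplified standalone replica (for illustration only, not the library function) captures the effective classification:

```python
def is_non_streaming(result) -> bool:
    # Simplified replica of check_non_streaming for illustration:
    # plain containers and strings are complete results; anything else
    # (e.g. a generator) is treated as a stream.
    return isinstance(result, (list, dict, str))


# A dict (even an LLM-style {"choices": ...} payload) is a complete result;
# a generator expression is handled as a stream.
complete = is_non_streaming({"choices": []})
streamed = is_non_streaming(x for x in [1])
```

In the real function, `Message` and `AgentState` instances are also classified as complete.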
utils

Core utility functions for graph execution and state management.

This module provides essential utilities for PyAgenity graph execution, including state management, message processing, response formatting, and execution flow control. These functions handle the low-level operations that support graph workflow execution.

The utilities in this module are designed to work with PyAgenity's dependency injection system and provide consistent interfaces for common operations across different execution contexts.

Key functionality areas:

- State loading, creation, and synchronization
- Message processing and deduplication
- Response formatting based on granularity levels
- Node execution result processing
- Interrupt handling and execution flow control
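One of those areas, message merging with deduplication, can be sketched as a plain id-keyed merge. This is a hypothetical stand-in to show the idea, not the actual `add_messages` implementation:

```python
from dataclasses import dataclass


@dataclass
class Msg:
    """Minimal stand-in for a PyAgenity Message."""

    id: str
    text: str


def merge_messages(existing: list, new: list) -> list:
    """Append new messages, replacing any existing message with the same id."""
    by_id = {m.id: m for m in existing}
    order = [m.id for m in existing]
    for m in new:
        if m.id not in by_id:
            order.append(m.id)
        by_id[m.id] = m
    return [by_id[i] for i in order]


context = merge_messages(
    [Msg("1", "hi"), Msg("2", "draft")],
    [Msg("2", "final"), Msg("3", "bye")],
)
```

Keeping insertion order while replacing by id preserves conversation history without duplicating retried messages.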

Functions:

Name Description
call_realtime_sync

Call the realtime state sync hook if provided.

check_and_handle_interrupt

Check for interrupts and save state if needed. Returns True if interrupted.

get_next_node

Get the next node to execute based on edges.

load_or_create_state

Load existing state from checkpointer or create new state.

parse_response

Parse and format execution response based on specified granularity level.

process_node_result

Processes the result from a node execution, updating the agent state, message list,

reload_state

Load existing state from checkpointer or create new state.

sync_data

Sync the current state and messages to the checkpointer.

Attributes:

Name Type Description
StateT
logger
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
logger module-attribute
logger = getLogger(__name__)
Classes Functions
call_realtime_sync async
call_realtime_sync(state, config, checkpointer=Inject[BaseCheckpointer])

Call the realtime state sync hook if provided.

Source code in pyagenity/graph/utils/utils.py
async def call_realtime_sync(
    state: AgentState,
    config: dict[str, Any],
    checkpointer: BaseCheckpointer = Inject[BaseCheckpointer],  # will be auto-injected
) -> None:
    """Call the realtime state sync hook if provided."""
    if checkpointer:
        logger.debug("Calling realtime state sync hook")
        # await call_sync_or_async(checkpointer.a, config, state)
        await checkpointer.aput_state_cache(config, state)
check_and_handle_interrupt async
check_and_handle_interrupt(interrupt_before, interrupt_after, current_node, interrupt_type, state, config, _sync_data)

Check for interrupts and save state if needed. Returns True if interrupted.

Source code in pyagenity/graph/utils/utils.py
async def check_and_handle_interrupt(
    interrupt_before: list[str],
    interrupt_after: list[str],
    current_node: str,
    interrupt_type: str,
    state: AgentState,
    config: dict[str, Any],
    _sync_data: Callable,
) -> bool:
    """Check for interrupts and save state if needed. Returns True if interrupted."""
    interrupt_nodes = interrupt_before if interrupt_type == "before" else interrupt_after

    if current_node in interrupt_nodes:
        status = (
            ExecutionStatus.INTERRUPTED_BEFORE
            if interrupt_type == "before"
            else ExecutionStatus.INTERRUPTED_AFTER
        )
        state.set_interrupt(
            current_node,
            f"interrupt_{interrupt_type}: {current_node}",
            status,
        )
        # Save state and interrupt
        await _sync_data(state, config, [])
        logger.debug("Node '%s' interrupted", current_node)
        return True

    logger.debug(
        "No interrupts found for node '%s', continuing execution",
        current_node,
    )
    return False
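The gating decision above reduces to a membership check against the node list matching the requested phase. A self-contained sketch of just that decision (hypothetical helper name):

```python
def should_interrupt(
    interrupt_before: list,
    interrupt_after: list,
    current_node: str,
    interrupt_type: str,
) -> bool:
    """Replicates the decision portion of check_and_handle_interrupt:
    interrupt only when the current node appears in the list that matches
    the requested phase ('before' or 'after')."""
    nodes = interrupt_before if interrupt_type == "before" else interrupt_after
    return current_node in nodes


# A node listed in interrupt_before only pauses before execution, not after.
pause_before = should_interrupt(["review"], [], "review", "before")
pause_after = should_interrupt(["review"], [], "review", "after")
```

The real function additionally records the interrupt on the state and persists it via the `_sync_data` callable before returning.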
get_next_node
get_next_node(current_node, state, edges)

Get the next node to execute based on edges.

Source code in pyagenity/graph/utils/utils.py
def get_next_node(
    current_node: str,
    state: AgentState,
    edges: list,
) -> str:
    """Get the next node to execute based on edges."""
    # Find outgoing edges from current node
    outgoing_edges = [e for e in edges if e.from_node == current_node]

    if not outgoing_edges:
        logger.debug("No outgoing edges from node '%s', ending execution", current_node)
        return END

    # Handle conditional edges
    for edge in outgoing_edges:
        if edge.condition:
            try:
                condition_result = edge.condition(state)
                if hasattr(edge, "condition_result") and edge.condition_result is not None:
                    # Mapped conditional edge
                    if condition_result == edge.condition_result:
                        return edge.to_node
                elif isinstance(condition_result, str):
                    return condition_result
                elif condition_result:
                    return edge.to_node
            except Exception:
                logger.exception("Error evaluating condition for edge: %s", edge)
                continue

    # Return first static edge if no conditions matched
    static_edges = [e for e in outgoing_edges if not e.condition]
    if static_edges:
        return static_edges[0].to_node

    logger.debug("No valid edges found from node '%s', ending execution", current_node)
    return END
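The routing precedence above (mapped conditional edges, then conditions returning a node name, then truthy conditions, then the first static edge) can be exercised with lightweight stand-in edges:

```python
from dataclasses import dataclass
from typing import Callable, Optional

END = "__end__"  # sentinel, mirroring PyAgenity's END


@dataclass
class Edge:
    from_node: str
    to_node: str
    condition: Optional[Callable] = None
    condition_result: Optional[str] = None


def next_node(current: str, state, edges: list) -> str:
    """Simplified replica of get_next_node's routing rules."""
    outgoing = [e for e in edges if e.from_node == current]
    if not outgoing:
        return END
    for e in outgoing:
        if e.condition:
            result = e.condition(state)
            if e.condition_result is not None:
                if result == e.condition_result:  # mapped conditional edge
                    return e.to_node
            elif isinstance(result, str):  # condition names the target directly
                return result
            elif result:  # plain truthy condition
                return e.to_node
    static = [e for e in outgoing if not e.condition]
    return static[0].to_node if static else END


edges = [
    Edge("a", "b", condition=lambda s: s["route"], condition_result="yes"),
    Edge("a", "c"),
]
matched = next_node("a", {"route": "yes"}, edges)  # mapped condition matches -> "b"
fallback = next_node("a", {"route": "no"}, edges)  # falls back to static edge -> "c"
```

When a mapped condition fails, the static edge acts as the default route, which is how conditional branching typically degrades gracefully.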
load_or_create_state async
load_or_create_state(input_data, config, old_state, checkpointer=Inject[BaseCheckpointer])

Load existing state from checkpointer or create new state.

Attempts to fetch a realtime-synced state first, then falls back to the persistent checkpointer. If no existing state is found, creates a new state from the StateGraph's prototype state and merges any incoming messages. Supports partial state update via 'state' in input_data.
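That cache-first lookup order can be sketched with a dict-backed checkpointer stand-in (hypothetical names and methods, not the real BaseCheckpointer API):

```python
class FakeCheckpointer:
    """Dict-backed stand-in illustrating cache-first state lookup."""

    def __init__(self, cache=None, persistent=None):
        self.cache = cache or {}
        self.persistent = persistent or {}

    def get_state_cache(self, thread_id):
        return self.cache.get(thread_id)

    def get_state(self, thread_id):
        return self.persistent.get(thread_id)


def load_or_create(thread_id, checkpointer, default_state):
    """Try the realtime cache, then the persistent store, else create fresh state."""
    if checkpointer:
        state = checkpointer.get_state_cache(thread_id)
        if state is None:
            state = checkpointer.get_state(thread_id)
        if state is not None:
            return state
    return dict(default_state)  # fresh copy, mirroring copy.deepcopy of the prototype


cp = FakeCheckpointer(persistent={"t1": {"step": 3}})
found = load_or_create("t1", cp, {"step": 0})    # loaded from persistent store
created = load_or_create("t2", cp, {"step": 0})  # no stored state, created fresh
```

The real function additionally normalizes legacy node names and merges incoming messages and partial state fields into the loaded state.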

Source code in pyagenity/graph/utils/utils.py
async def load_or_create_state[StateT: AgentState](  # noqa: PLR0912, PLR0915
    input_data: dict[str, Any],
    config: dict[str, Any],
    old_state: StateT,
    checkpointer: BaseCheckpointer = Inject[BaseCheckpointer],  # will be auto-injected
) -> StateT:
    """Load existing state from checkpointer or create new state.

    Attempts to fetch a realtime-synced state first, then falls back to
    the persistent checkpointer. If no existing state is found, creates
    a new state from the `StateGraph`'s prototype state and merges any
    incoming messages. Supports partial state update via 'state' in input_data.
    """
    logger.debug("Loading or creating state with thread_id=%s", config.get("thread_id", "default"))

    # Try to load existing state if checkpointer is available
    if checkpointer:
        logger.debug("Attempting to load existing state from checkpointer")
        # first check realtime-synced state
        existing_state: StateT | None = await checkpointer.aget_state_cache(config)
        if not existing_state:
            logger.debug("No synced state found, trying persistent checkpointer")
            # If no synced state, try to get from persistent checkpointer
            existing_state = await checkpointer.aget_state(config)

        if existing_state:
            logger.info(
                "Loaded existing state with %d context messages, current_node=%s, step=%d",
                len(existing_state.context) if existing_state.context else 0,
                existing_state.execution_meta.current_node,
                existing_state.execution_meta.step,
            )
            # Normalize legacy node names (backward compatibility)
            # Some older runs may have persisted 'start'/'end' instead of '__start__'/'__end__'
            if existing_state.execution_meta.current_node == "start":
                existing_state.execution_meta.current_node = START
                logger.debug("Normalized legacy current_node 'start' to '%s'", START)
            elif existing_state.execution_meta.current_node == "end":
                existing_state.execution_meta.current_node = END
                logger.debug("Normalized legacy current_node 'end' to '%s'", END)
            elif existing_state.execution_meta.current_node == "__start__":
                existing_state.execution_meta.current_node = START
                logger.debug("Normalized legacy current_node '__start__' to '%s'", START)
            elif existing_state.execution_meta.current_node == "__end__":
                existing_state.execution_meta.current_node = END
                logger.debug("Normalized legacy current_node '__end__' to '%s'", END)
            # Merge new messages with existing context
            new_messages = input_data.get("messages", [])
            if new_messages:
                logger.debug("Merging %d new messages with existing context", len(new_messages))
                existing_state.context = add_messages(existing_state.context, new_messages)
            # Merge partial state fields if provided
            partial_state = input_data.get("state", {})
            if partial_state and isinstance(partial_state, dict):
                logger.debug("Merging partial state with %d fields", len(partial_state))
                _update_state_fields(existing_state, partial_state)
                # Update current node if explicitly provided
                if partial_state.get("current_node") is not None:
                    existing_state.set_current_node(partial_state["current_node"])
            return existing_state
    else:
        logger.debug("No checkpointer available, will create new state")

    # Create new state by deep copying the graph's prototype state
    logger.info("Creating new state from graph prototype")
    state = copy.deepcopy(old_state)

    # Ensure core AgentState fields are properly initialized
    if hasattr(state, "context") and not isinstance(state.context, list):
        state.context = []
        logger.debug("Initialized empty context list")
    if hasattr(state, "context_summary"):
        state.context_summary = None
        logger.debug("Reset context_summary for fresh state")
    if hasattr(state, "execution_meta"):
        # Create a fresh execution metadata
        state.execution_meta = ExecMeta(current_node=START)
        logger.debug("Created fresh execution metadata starting at %s", START)

    # Set thread_id in execution metadata
    thread_id = config.get("thread_id", "default")
    state.execution_meta.thread_id = thread_id
    logger.debug("Set thread_id to %s", thread_id)

    # Merge new messages with context
    new_messages = input_data.get("messages", [])
    if new_messages:
        logger.debug("Adding %d new messages to fresh state", len(new_messages))
        state.context = add_messages(state.context, new_messages)
    # Merge partial state fields if provided
    partial_state = input_data.get("state", {})
    if partial_state and isinstance(partial_state, dict):
        logger.debug("Merging partial state with %d fields", len(partial_state))
        _update_state_fields(state, partial_state)
        if partial_state.get("current_node") is not None:
            # Normalize legacy values if provided in partial state
            next_node = partial_state["current_node"]
            if next_node == "__start__":
                next_node = START
            elif next_node == "__end__":
                next_node = END
            state.set_current_node(next_node)

    logger.info(
        "Created new state with %d context messages", len(state.context) if state.context else 0
    )
    return state  # type: ignore[return-value]
parse_response async
parse_response(state, messages, response_granularity=ResponseGranularity.LOW)

Parse and format execution response based on specified granularity level.

Formats the final response from graph execution according to the requested granularity level, allowing clients to receive different levels of detail depending on their needs.

Parameters:

Name Type Description Default
state AgentState

The final agent state after graph execution.

required
messages list[Message]

List of messages generated during execution.

required
response_granularity ResponseGranularity

Level of detail to include in the response:

- FULL: Returns the complete state object and all messages
- PARTIAL: Returns context, summary, and messages
- LOW: Returns only the messages (default)

LOW

Returns:

Type Description
dict[str, Any]

Dictionary containing the formatted response, with keys depending on the granularity level. Always includes a 'messages' key with the execution results.

Example
# LOW granularity (default)
response = await parse_response(state, messages)
# Returns: {"messages": [Message(...), ...]}

# FULL granularity
response = await parse_response(state, messages, ResponseGranularity.FULL)
# Returns: {"state": AgentState(...), "messages": [Message(...), ...]}
Source code in pyagenity/graph/utils/utils.py
async def parse_response(
    state: AgentState,
    messages: list[Message],
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
) -> dict[str, Any]:
    """Parse and format execution response based on specified granularity level.

    Formats the final response from graph execution according to the requested
    granularity level, allowing clients to receive different levels of detail
    depending on their needs.

    Args:
        state: The final agent state after graph execution.
        messages: List of messages generated during execution.
        response_granularity: Level of detail to include in the response:
            - FULL: Returns complete state object and all messages
            - PARTIAL: Returns context, summary, and messages
            - LOW: Returns only the messages (default)

    Returns:
        Dictionary containing the formatted response with keys depending on
        granularity level. Always includes 'messages' key with execution results.

    Example:
        ```python
        # LOW granularity (default)
        response = await parse_response(state, messages)
        # Returns: {"messages": [Message(...), ...]}

        # FULL granularity
        response = await parse_response(state, messages, ResponseGranularity.FULL)
        # Returns: {"state": AgentState(...), "messages": [Message(...), ...]}
        ```
    """
    match response_granularity:
        case ResponseGranularity.FULL:
            # Return full state and messages
            return {"state": state, "messages": messages}
        case ResponseGranularity.PARTIAL:
            # Return state and summary of messages
            return {
                "context": state.context,
                "summary": state.context_summary,
                "messages": messages,
            }
        case ResponseGranularity.LOW:
            # Return all messages from state context
            return {"messages": messages}

    return {"messages": messages}
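The granularity dispatch above can be sketched independently of the framework types. This is a minimal illustration using plain dicts in place of `AgentState`; the `Granularity` enum and `shape_response` names are hypothetical stand-ins, not part of the PyAgenity API:

```python
from enum import Enum


class Granularity(Enum):
    """Hypothetical stand-in for ResponseGranularity."""

    FULL = "full"
    PARTIAL = "partial"
    LOW = "low"


def shape_response(state: dict, messages: list, granularity: Granularity) -> dict:
    """Mirror the match statement: more detail for FULL, only messages for LOW."""
    if granularity is Granularity.FULL:
        return {"state": state, "messages": messages}
    if granularity is Granularity.PARTIAL:
        return {
            "context": state.get("context"),
            "summary": state.get("summary"),
            "messages": messages,
        }
    return {"messages": messages}
```

Note that every shape includes the `messages` key, so clients can always read results the same way regardless of the requested granularity.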
process_node_result async
process_node_result(result, state, messages)

Processes the result from a node execution, updating the agent state, message list, and determining the next node.

Supports:

- Handling results of type Command, AgentState, Message, list, str, dict, or other types.
- Deduplicating messages by message_id.
- Updating the agent state and its context with new messages.
- Extracting navigation information (next node) from Command results.

Parameters:

Name Type Description Default
result Any

The output from a node execution. Can be a Command, AgentState, Message, list, str, dict, ModelResponse, or other types.

required
state StateT

The current agent state.

required
messages list[Message]

The list of messages accumulated so far.

required

Returns:

Type Description
tuple[StateT, list[Message], str | None]

tuple[StateT, list[Message], str | None]:

- The updated agent state.
- The updated list of messages (with new, unique messages added).
- The identifier of the next node to execute, if specified; otherwise, None.

Source code in pyagenity/graph/utils/utils.py
async def process_node_result[StateT: AgentState](  # noqa: PLR0915
    result: Any,
    state: StateT,
    messages: list[Message],
) -> tuple[StateT, list[Message], str | None]:
    """
    Processes the result from a node execution, updating the agent state, message list,
    and determining the next node.

    Supports:
    - Handling results of type Command, AgentState, Message, list, str, dict,
      or other types.
    - Deduplicating messages by message_id.
    - Updating the agent state and its context with new messages.
    - Extracting navigation information (next node) from Command results.

    Args:
        result (Any): The output from a node execution. Can be a Command, AgentState, Message,
            list, str, dict, ModelResponse, or other types.
        state (StateT): The current agent state.
        messages (list[Message]): The list of messages accumulated so far.

    Returns:
        tuple[StateT, list[Message], str | None]:
            - The updated agent state.
            - The updated list of messages (with new, unique messages added).
            - The identifier of the next node to execute, if specified; otherwise, None.
    """
    next_node = None
    existing_ids = {msg.message_id for msg in messages}
    new_messages = []

    def add_unique_message(msg: Message) -> None:
        """Add message only if it doesn't already exist."""
        if msg.message_id not in existing_ids:
            new_messages.append(msg)
            existing_ids.add(msg.message_id)

    async def create_and_add_message(content: Any) -> Message:
        """Create message from content and add if unique."""
        if isinstance(content, Message):
            msg = content
        elif isinstance(content, ModelResponseConverter):
            msg = await content.invoke()
        elif isinstance(content, str):
            msg = Message.text_message(
                content,
                role="assistant",
            )

        else:
            err = (
                f"Unsupported content type for message: {type(content)}. "
                "Supported here: Message, ModelResponseConverter, or str; "
                "Command and AgentState results are handled by the caller."
            )
            raise ValueError(err)

        add_unique_message(msg)
        return msg

    def handle_state_message(old_state: StateT, new_state: StateT) -> None:
        """Handle state messages by updating the context."""
        old_messages = {}
        if old_state.context:
            old_messages = {msg.message_id: msg for msg in old_state.context}

        if not new_state.context:
            return
        # now save all the new messages
        for msg in new_state.context:
            if msg.message_id in old_messages:
                continue
            # otherwise save it
            add_unique_message(msg)

    # Process different result types
    if isinstance(result, Command):
        # Handle state updates
        if result.update:
            if isinstance(result.update, AgentState):
                handle_state_message(state, result.update)  # type: ignore[assignment]
                state = result.update  # type: ignore[assignment]
            elif isinstance(result.update, list):
                for item in result.update:
                    await create_and_add_message(item)
            else:
                await create_and_add_message(result.update)

        # Handle navigation
        next_node = result.goto

    elif isinstance(result, AgentState):
        handle_state_message(state, result)  # type: ignore[assignment]
        state = result  # type: ignore[assignment]

    elif isinstance(result, Message):
        add_unique_message(result)

    elif isinstance(result, list):
        # Handle list of items (convert each to message)
        for item in result:
            await create_and_add_message(item)
    else:
        # Handle single items (str, dict, model_dump-capable, or other)
        await create_and_add_message(result)

    # Add new messages to the main list and state context
    if new_messages:
        messages.extend(new_messages)
        state.context = add_messages(state.context, new_messages)

    return state, messages, next_node
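The deduplication performed by `add_unique_message` can be sketched on its own, outside the framework. This is a minimal illustration using plain dicts with a `message_id` key; the `merge_unique` name is hypothetical:

```python
def merge_unique(existing: list[dict], new: list[dict]) -> list[dict]:
    """Merge new messages into existing ones, deduplicating by message_id.

    The first occurrence of each id wins, mirroring how process_node_result
    seeds existing_ids from the accumulated messages before adding new ones.
    """
    seen = {m["message_id"] for m in existing}
    merged = list(existing)
    for m in new:
        if m["message_id"] not in seen:
            merged.append(m)
            seen.add(m["message_id"])
    return merged
```

Because nodes may re-emit messages already present in the state context, this id-based merge keeps the context free of duplicates without relying on object identity.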
reload_state async
reload_state(config, old_state, checkpointer=Inject[BaseCheckpointer])

Reload existing state from the checkpointer, falling back to the current state.

Attempts to fetch a realtime-synced state first, then falls back to the persistent checkpointer. If no checkpointer is configured or no persisted state is found, the provided old_state is returned unchanged. Legacy node names ('start'/'end') are normalized to the canonical constants.

Source code in pyagenity/graph/utils/utils.py
async def reload_state[StateT: AgentState](
    config: dict[str, Any],
    old_state: StateT,
    checkpointer: BaseCheckpointer = Inject[BaseCheckpointer],  # will be auto-injected
) -> StateT:
    """Reload existing state from the checkpointer.

    Attempts to fetch a realtime-synced state first, then falls back to
    the persistent checkpointer. If no checkpointer is configured or no
    persisted state is found, the provided `old_state` is returned
    unchanged. Legacy node names are normalized for backward compatibility.
    """
    logger.debug("Reloading state with thread_id=%s", config.get("thread_id", "default"))

    if not checkpointer:
        return old_state

    # first check realtime-synced state
    existing_state: StateT | None = await checkpointer.aget_state_cache(config)
    if not existing_state:
        logger.debug("No synced state found, trying persistent checkpointer")
        # If no synced state, try to get from persistent checkpointer
        existing_state = await checkpointer.aget_state(config)

    if not existing_state:
        logger.warning("No existing state found to reload, returning old state")
        return old_state

    logger.info(
        "Loaded existing state with %d context messages, current_node=%s, step=%d",
        len(existing_state.context) if existing_state.context else 0,
        existing_state.execution_meta.current_node,
        existing_state.execution_meta.step,
    )
    # Normalize legacy node names (backward compatibility)
    # Some older runs may have persisted 'start'/'end' instead of '__start__'/'__end__'
    if existing_state.execution_meta.current_node == "start":
        existing_state.execution_meta.current_node = START
        logger.debug("Normalized legacy current_node 'start' to '%s'", START)
    elif existing_state.execution_meta.current_node == "end":
        existing_state.execution_meta.current_node = END
        logger.debug("Normalized legacy current_node 'end' to '%s'", END)
    elif existing_state.execution_meta.current_node == "__start__":
        existing_state.execution_meta.current_node = START
        logger.debug("Normalized legacy current_node '__start__' to '%s'", START)
    elif existing_state.execution_meta.current_node == "__end__":
        existing_state.execution_meta.current_node = END
        logger.debug("Normalized legacy current_node '__end__' to '%s'", END)
    return existing_state
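The legacy-name normalization chain above collapses to a small lookup. This sketch assumes `START` and `END` are the string constants `"__start__"` and `"__end__"`, as the surrounding comments imply; the `normalize_node` helper is hypothetical:

```python
# Assumed values of PyAgenity's START/END constants.
START, END = "__start__", "__end__"

# Older runs may have persisted 'start'/'end'; map every known spelling
# to the canonical constant and pass anything else through unchanged.
_LEGACY = {"start": START, "end": END, "__start__": START, "__end__": END}


def normalize_node(name: str) -> str:
    """Collapse legacy persisted node names to the canonical constants."""
    return _LEGACY.get(name, name)
```

A table-driven form like this replaces the four-branch `if`/`elif` chain and makes it trivial to add further legacy spellings later.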
sync_data async
sync_data(state, config, messages, trim=False, checkpointer=Inject[BaseCheckpointer], context_manager=Inject[BaseContextManager])

Sync the current state and messages to the checkpointer.

Source code in pyagenity/graph/utils/utils.py
async def sync_data(
    state: AgentState,
    config: dict[str, Any],
    messages: list[Message],
    trim: bool = False,
    checkpointer: BaseCheckpointer = Inject[BaseCheckpointer],  # will be auto-injected
    context_manager: BaseContextManager = Inject[BaseContextManager],  # will be auto-injected
) -> bool:
    """Sync the current state and messages to the checkpointer."""
    is_context_trimmed = False

    new_state = copy.deepcopy(state)
    # if context manager is available then utilize it
    if context_manager and trim:
        new_state = await context_manager.atrim_context(state)
        is_context_trimmed = True

    # first sync with realtime then main db (use the possibly trimmed state for both)
    await call_realtime_sync(new_state, config, checkpointer)
    logger.debug("Persisting state and %d messages to checkpointer", len(messages))

    if checkpointer:
        await checkpointer.aput_state(config, new_state)
        if messages:
            await checkpointer.aput_messages(config, messages)

    return is_context_trimmed
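The trim-then-persist ordering in `sync_data` can be illustrated with stub collaborators. Everything here is a hypothetical sketch: the stub classes only mimic the `atrim_context`/`aput_state`/`aput_messages` shapes, they are not the real PyAgenity interfaces:

```python
import asyncio


class StubContextManager:
    """Hypothetical trimmer: keeps only the last `keep` context entries."""

    def __init__(self, keep: int = 2):
        self.keep = keep

    async def atrim_context(self, state: dict) -> dict:
        trimmed = dict(state)
        trimmed["context"] = state["context"][-self.keep :]
        return trimmed


class StubCheckpointer:
    """Records what was persisted, mimicking aput_state/aput_messages."""

    def __init__(self):
        self.saved_state = None
        self.saved_messages = None

    async def aput_state(self, config, state):
        self.saved_state = state

    async def aput_messages(self, config, messages):
        self.saved_messages = messages


async def sync_sketch(state, config, messages, trim, checkpointer, context_manager):
    """Mirror sync_data: trim first (if asked), then persist the trimmed copy."""
    new_state = state
    trimmed = False
    if context_manager and trim:
        new_state = await context_manager.atrim_context(state)
        trimmed = True
    if checkpointer:
        await checkpointer.aput_state(config, new_state)
        if messages:
            await checkpointer.aput_messages(config, messages)
    return trimmed
```

The important property is that trimming happens on a copy before persistence, so the checkpointer always stores the bounded context while the in-memory state the caller holds is untouched.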

prebuilt

Modules:

Name Description
agent

Modules

agent

Modules:

Name Description
branch_join
deep_research
guarded
map_reduce
network
plan_act_reflect
rag
react
router
sequential
supervisor_team
swarm

Classes:

Name Description
BranchJoinAgent

Execute multiple branches then join.

DeepResearchAgent

Deep Research Agent: PLAN → RESEARCH → SYNTHESIZE → CRITIQUE loop.

GuardedAgent

Validate output and repair until valid or attempts exhausted.

MapReduceAgent

Map over items then reduce.

NetworkAgent

Network pattern: define arbitrary node set and routing policies.

PlanActReflectAgent

Plan -> Act -> Reflect looping agent.

RAGAgent

Simple RAG: retrieve -> synthesize; optional follow-up.

ReactAgent
RouterAgent

A configurable router-style agent.

SequentialAgent

A simple sequential agent that executes a fixed pipeline of nodes.

SupervisorTeamAgent

Supervisor routes tasks to worker nodes and aggregates results.

SwarmAgent

Swarm pattern: dispatch to many workers, collect, then reach consensus.

Attributes
__all__ module-attribute
__all__ = ['BranchJoinAgent', 'DeepResearchAgent', 'GuardedAgent', 'MapReduceAgent', 'NetworkAgent', 'PlanActReflectAgent', 'RAGAgent', 'ReactAgent', 'RouterAgent', 'SequentialAgent', 'SupervisorTeamAgent', 'SwarmAgent']
Classes
BranchJoinAgent

Execute multiple branches then join.

Note: This prebuilt models branches sequentially (not true parallel execution). For each provided branch node, we add edges branch_i -> JOIN. The JOIN node decides whether more branches remain or END. A more advanced version could use BackgroundTaskManager for concurrency.

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/branch_join.py
class BranchJoinAgent[StateT: AgentState]:
    """Execute multiple branches then join.

    Note: This prebuilt models branches sequentially (not true parallel execution).
    For each provided branch node, we add edges branch_i -> JOIN. The JOIN node
    decides whether more branches remain or END. A more advanced version could
    use BackgroundTaskManager for concurrency.
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        branches: dict[str, Callable | tuple[Callable, str]],
        join_node: Callable | tuple[Callable, str],
        next_branch_condition: Callable | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        if not branches:
            raise ValueError("branches must be a non-empty dict of name -> callable/tuple")

        # Add branch nodes
        branch_names = []
        for key, fn in branches.items():
            if isinstance(fn, tuple):
                branch_func, branch_name = fn
                if not callable(branch_func):
                    raise ValueError(f"Branch '{key}'[0] must be callable")
            else:
                branch_func = fn
                branch_name = key
                if not callable(branch_func):
                    raise ValueError(f"Branch '{key}' must be callable")
            self._graph.add_node(branch_name, branch_func)
            branch_names.append(branch_name)

        # Handle join_node
        if isinstance(join_node, tuple):
            join_func, join_name = join_node
            if not callable(join_func):
                raise ValueError("join_node[0] must be callable")
        else:
            join_func = join_node
            join_name = "JOIN"
            if not callable(join_func):
                raise ValueError("join_node must be callable")
        self._graph.add_node(join_name, join_func)

        # Wire branches to JOIN
        for name in branch_names:
            self._graph.add_edge(name, join_name)

        # Entry: first branch
        first = branch_names[0]
        self._graph.set_entry_point(first)

        # Decide next branch or END after join
        if next_branch_condition is None:
            # default: END after join
            def _cond(_: AgentState) -> str:
                return END

            next_branch_condition = _cond

        # next_branch_condition returns a branch name or END
        path_map = {k: k for k in branch_names}
        path_map[END] = END
        self._graph.add_conditional_edges(join_name, next_branch_condition, path_map)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
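Since branches run sequentially, a custom `next_branch_condition` is what drives the JOIN node through the remaining branches. This is a minimal sketch of such a condition; for illustration only, the cursor lives in a closure rather than on the `AgentState` (a real condition would read its progress from the state), and `make_round_robin` is a hypothetical helper:

```python
END = "__end__"  # assumed value of PyAgenity's END constant


def make_round_robin(branch_names: list[str], end: str = END):
    """Build a next_branch_condition that visits each branch once, then ends.

    The entry point already executes the first branch, so the condition
    only needs to hand out the remaining names before returning END.
    """
    remaining = list(branch_names[1:])

    def condition(_state) -> str:
        if remaining:
            return remaining.pop(0)
        return end

    return condition
```

The returned names must match the keys of the `path_map` built in `compile`, which maps each branch name to itself and `END` to `END`.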
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/branch_join.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(branches, join_node, next_branch_condition=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/branch_join.py
def compile(
    self,
    branches: dict[str, Callable | tuple[Callable, str]],
    join_node: Callable | tuple[Callable, str],
    next_branch_condition: Callable | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    if not branches:
        raise ValueError("branches must be a non-empty dict of name -> callable/tuple")

    # Add branch nodes
    branch_names = []
    for key, fn in branches.items():
        if isinstance(fn, tuple):
            branch_func, branch_name = fn
            if not callable(branch_func):
                raise ValueError(f"Branch '{key}'[0] must be callable")
        else:
            branch_func = fn
            branch_name = key
            if not callable(branch_func):
                raise ValueError(f"Branch '{key}' must be callable")
        self._graph.add_node(branch_name, branch_func)
        branch_names.append(branch_name)

    # Handle join_node
    if isinstance(join_node, tuple):
        join_func, join_name = join_node
        if not callable(join_func):
            raise ValueError("join_node[0] must be callable")
    else:
        join_func = join_node
        join_name = "JOIN"
        if not callable(join_func):
            raise ValueError("join_node must be callable")
    self._graph.add_node(join_name, join_func)

    # Wire branches to JOIN
    for name in branch_names:
        self._graph.add_edge(name, join_name)

    # Entry: first branch
    first = branch_names[0]
    self._graph.set_entry_point(first)

    # Decide next branch or END after join
    if next_branch_condition is None:
        # default: END after join
        def _cond(_: AgentState) -> str:
            return END

        next_branch_condition = _cond

    # next_branch_condition returns a branch name or END
    path_map = {k: k for k in branch_names}
    path_map[END] = END
    self._graph.add_conditional_edges(join_name, next_branch_condition, path_map)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
DeepResearchAgent

Deep Research Agent: PLAN → RESEARCH → SYNTHESIZE → CRITIQUE loop.

This agent mirrors modern deep-research patterns inspired by DeerFlow and Tongyi DeepResearch: plan tasks, use tools to research, synthesize findings, critique gaps and iterate a bounded number of times.

Nodes:

- PLAN: Decompose the problem and propose search/tool actions; may include tool calls
- RESEARCH: ToolNode executes search/browse/calc/etc. tools
- SYNTHESIZE: Aggregate findings and draft a coherent report or partial answer
- CRITIQUE: Identify gaps, contradictions, or follow-ups; can request more tools

Routing:

- PLAN -> conditional(_route_after_plan): {"RESEARCH": RESEARCH, "SYNTHESIZE": SYNTHESIZE, END: END}
- RESEARCH -> SYNTHESIZE
- SYNTHESIZE -> CRITIQUE
- CRITIQUE -> conditional(_route_after_critique): {"RESEARCH": RESEARCH, END: END}

Iteration Control: uses the following execution_meta.internal_data keys:

- dr_max_iters (int): maximum critique→research loops (default 2)
- dr_iters (int): current loop count (auto-updated)
- dr_heavy_mode (bool): if True, bias towards one more loop when the critique suggests it
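The bounded critique→research loop driven by these keys can be sketched as a small routing function. This is illustrative only: the real `_route_after_critique` takes the agent state and derives `needs_more` from the critique output, whereas here both are passed in directly:

```python
END = "__end__"  # assumed value of PyAgenity's END constant


def route_after_critique(internal_data: dict, needs_more: bool) -> str:
    """Loop back to RESEARCH while the critique asks for more and budget remains."""
    max_iters = internal_data.get("dr_max_iters", 2)
    iters = internal_data.get("dr_iters", 0)
    if needs_more and iters < max_iters:
        internal_data["dr_iters"] = iters + 1  # auto-updated loop counter
        return "RESEARCH"
    return END
```

With `dr_max_iters = 2`, at most two extra research passes run after the first synthesis, after which the router ends the graph regardless of what the critique requests.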

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/deep_research.py
class DeepResearchAgent[StateT: AgentState]:
    """Deep Research Agent: PLAN → RESEARCH → SYNTHESIZE → CRITIQUE loop.

    This agent mirrors modern deep-research patterns inspired by DeerFlow and
    Tongyi DeepResearch: plan tasks, use tools to research, synthesize findings,
    critique gaps and iterate a bounded number of times.

    Nodes:
    - PLAN: Decompose problem, propose search/tool actions; may include tool calls
    - RESEARCH: ToolNode executes search/browse/calc/etc tools
    - SYNTHESIZE: Aggregate and draft a coherent report or partial answer
    - CRITIQUE: Identify gaps, contradictions, or follow-ups; can request more tools

    Routing:
    - PLAN -> conditional(_route_after_plan):
        {"RESEARCH": RESEARCH, "SYNTHESIZE": SYNTHESIZE, END: END}
    - RESEARCH -> SYNTHESIZE
    - SYNTHESIZE -> CRITIQUE
    - CRITIQUE -> conditional(_route_after_critique): {"RESEARCH": RESEARCH, END: END}
    Iteration Control:
    - Uses execution_meta.internal_data keys:
        dr_max_iters (int): maximum critique→research loops (default 2)
        dr_iters (int): current loop count (auto-updated)
        dr_heavy_mode (bool): if True, bias towards one more loop when critique suggests
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
        max_iters: int = 2,
        heavy_mode: bool = False,
    ):
        # initialize graph
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )
        # seed default internal config on prototype state
        # Note: These values will be copied to new state at invoke time.
        exec_meta: ExecutionState = self._graph._state.execution_meta
        exec_meta.internal_data.setdefault("dr_max_iters", max(0, int(max_iters)))
        exec_meta.internal_data.setdefault("dr_iters", 0)
        exec_meta.internal_data.setdefault("dr_heavy_mode", bool(heavy_mode))

    def compile(  # noqa: PLR0912
        self,
        plan_node: Callable | tuple[Callable, str],
        research_tool_node: ToolNode | tuple[ToolNode, str],
        synthesize_node: Callable | tuple[Callable, str],
        critique_node: Callable | tuple[Callable, str],
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Handle plan_node
        if isinstance(plan_node, tuple):
            plan_func, plan_name = plan_node
            if not callable(plan_func):
                raise ValueError("plan_node[0] must be callable")
        else:
            plan_func = plan_node
            plan_name = "PLAN"
            if not callable(plan_func):
                raise ValueError("plan_node must be callable")

        # Handle research_tool_node
        if isinstance(research_tool_node, tuple):
            research_func, research_name = research_tool_node
            if not isinstance(research_func, ToolNode):
                raise ValueError("research_tool_node[0] must be a ToolNode")
        else:
            research_func = research_tool_node
            research_name = "RESEARCH"
            if not isinstance(research_func, ToolNode):
                raise ValueError("research_tool_node must be a ToolNode")

        # Handle synthesize_node
        if isinstance(synthesize_node, tuple):
            synthesize_func, synthesize_name = synthesize_node
            if not callable(synthesize_func):
                raise ValueError("synthesize_node[0] must be callable")
        else:
            synthesize_func = synthesize_node
            synthesize_name = "SYNTHESIZE"
            if not callable(synthesize_func):
                raise ValueError("synthesize_node must be callable")

        # Handle critique_node
        if isinstance(critique_node, tuple):
            critique_func, critique_name = critique_node
            if not callable(critique_func):
                raise ValueError("critique_node[0] must be callable")
        else:
            critique_func = critique_node
            critique_name = "CRITIQUE"
            if not callable(critique_func):
                raise ValueError("critique_node must be callable")

        # Add nodes
        self._graph.add_node(plan_name, plan_func)
        self._graph.add_node(research_name, research_func)
        self._graph.add_node(synthesize_name, synthesize_func)
        self._graph.add_node(critique_name, critique_func)

        # Edges
        self._graph.add_conditional_edges(
            plan_name,
            _route_after_plan,
            {research_name: research_name, synthesize_name: synthesize_name, END: END},
        )
        self._graph.add_edge(research_name, synthesize_name)
        self._graph.add_edge(synthesize_name, critique_name)
        self._graph.add_conditional_edges(
            critique_name,
            _route_after_critique,
            {research_name: research_name, END: END},
        )

        # Entry
        self._graph.set_entry_point(plan_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
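The compile method above wires a PLAN -> RESEARCH -> SYNTHESIZE -> CRITIQUE loop whose iteration budget is seeded in `__init__` via the `dr_iters` / `dr_max_iters` internal-data keys. Independent of PyAgenity, the control flow after critique can be sketched in plain Python; `route_after_critique` here is an illustrative stand-in, not the library's `_route_after_critique`:

```python
# Standalone sketch of the critique-routing loop: after each critique the
# agent returns to RESEARCH until the iteration budget is spent, then ends.
# Key names mirror the internal_data defaults seeded in __init__.
END = "END"

def route_after_critique(internal_data: dict) -> str:
    """Return the next node name based on the remaining iteration budget."""
    iters = internal_data.get("dr_iters", 0)
    max_iters = internal_data.get("dr_max_iters", 2)
    if iters < max_iters:
        internal_data["dr_iters"] = iters + 1
        return "RESEARCH"
    return END

state = {"dr_iters": 0, "dr_max_iters": 2}
path = []
while (nxt := route_after_critique(state)) != END:
    path.append(nxt)
# With max_iters=2, exactly two research iterations run before END.
```

This mirrors why `__init__` clamps `max_iters` with `max(0, int(max_iters))`: a budget of zero routes straight to END.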
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None, max_iters=2, heavy_mode=False)
Source code in pyagenity/prebuilt/agent/deep_research.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
    max_iters: int = 2,
    heavy_mode: bool = False,
):
    # initialize graph
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
    # seed default internal config on prototype state
    # Note: These values will be copied to new state at invoke time.
    exec_meta: ExecutionState = self._graph._state.execution_meta
    exec_meta.internal_data.setdefault("dr_max_iters", max(0, int(max_iters)))
    exec_meta.internal_data.setdefault("dr_iters", 0)
    exec_meta.internal_data.setdefault("dr_heavy_mode", bool(heavy_mode))
compile
compile(plan_node, research_tool_node, synthesize_node, critique_node, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/deep_research.py
def compile(  # noqa: PLR0912
    self,
    plan_node: Callable | tuple[Callable, str],
    research_tool_node: ToolNode | tuple[ToolNode, str],
    synthesize_node: Callable | tuple[Callable, str],
    critique_node: Callable | tuple[Callable, str],
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Handle plan_node
    if isinstance(plan_node, tuple):
        plan_func, plan_name = plan_node
        if not callable(plan_func):
            raise ValueError("plan_node[0] must be callable")
    else:
        plan_func = plan_node
        plan_name = "PLAN"
        if not callable(plan_func):
            raise ValueError("plan_node must be callable")

    # Handle research_tool_node
    if isinstance(research_tool_node, tuple):
        research_func, research_name = research_tool_node
        if not isinstance(research_func, ToolNode):
            raise ValueError("research_tool_node[0] must be a ToolNode")
    else:
        research_func = research_tool_node
        research_name = "RESEARCH"
        if not isinstance(research_func, ToolNode):
            raise ValueError("research_tool_node must be a ToolNode")

    # Handle synthesize_node
    if isinstance(synthesize_node, tuple):
        synthesize_func, synthesize_name = synthesize_node
        if not callable(synthesize_func):
            raise ValueError("synthesize_node[0] must be callable")
    else:
        synthesize_func = synthesize_node
        synthesize_name = "SYNTHESIZE"
        if not callable(synthesize_func):
            raise ValueError("synthesize_node must be callable")

    # Handle critique_node
    if isinstance(critique_node, tuple):
        critique_func, critique_name = critique_node
        if not callable(critique_func):
            raise ValueError("critique_node[0] must be callable")
    else:
        critique_func = critique_node
        critique_name = "CRITIQUE"
        if not callable(critique_func):
            raise ValueError("critique_node must be callable")

    # Add nodes
    self._graph.add_node(plan_name, plan_func)
    self._graph.add_node(research_name, research_func)
    self._graph.add_node(synthesize_name, synthesize_func)
    self._graph.add_node(critique_name, critique_func)

    # Edges
    self._graph.add_conditional_edges(
        plan_name,
        _route_after_plan,
        {research_name: research_name, synthesize_name: synthesize_name, END: END},
    )
    self._graph.add_edge(research_name, synthesize_name)
    self._graph.add_edge(synthesize_name, critique_name)
    self._graph.add_conditional_edges(
        critique_name,
        _route_after_critique,
        {research_name: research_name, END: END},
    )

    # Entry
    self._graph.set_entry_point(plan_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
GuardedAgent

Validate output and repair until valid or attempts exhausted.

Nodes:
- PRODUCE: main generation node
- REPAIR: correction node when validation fails

Edges:
PRODUCE -> conditional(valid? END : REPAIR)
REPAIR -> PRODUCE
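The produce/validate/repair loop can be sketched in plain Python. The factory below is an assumption modeled on the documented behavior of `_guard_condition_factory` (route to REPAIR until the validator passes or attempts run out), not PyAgenity's internal implementation:

```python
# Hypothetical guard-condition factory: returns a routing function that sends
# control to REPAIR on validation failure, and to END on success or once
# max_attempts repairs have been tried.
END = "END"

def guard_condition_factory(validator, max_attempts):
    attempts = {"n": 0}  # closure-held repair counter

    def condition(state) -> str:
        if validator(state):
            return END
        attempts["n"] += 1
        return "REPAIR" if attempts["n"] < max_attempts else END

    return condition

# Usage: outputs shorter than 5 characters are "invalid" and repaired once.
cond = guard_condition_factory(lambda s: len(s["output"]) >= 5, max_attempts=2)
state = {"output": "hi"}
first = cond(state)           # invalid -> route to REPAIR
state["output"] = "hello"     # REPAIR node fixes the output
second = cond(state)          # valid -> END
```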

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/guarded.py
class GuardedAgent[StateT: AgentState]:
    """Validate output and repair until valid or attempts exhausted.

    Nodes:
    - PRODUCE: main generation node
    - REPAIR: correction node when validation fails

    Edges:
    PRODUCE -> conditional(valid? END : REPAIR)
    REPAIR -> PRODUCE
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        produce_node: Callable | tuple[Callable, str],
        repair_node: Callable | tuple[Callable, str],
        validator: Callable[[AgentState], bool],
        max_attempts: int = 2,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Handle produce_node
        if isinstance(produce_node, tuple):
            produce_func, produce_name = produce_node
            if not callable(produce_func):
                raise ValueError("produce_node[0] must be callable")
        else:
            produce_func = produce_node
            produce_name = "PRODUCE"
            if not callable(produce_func):
                raise ValueError("produce_node must be callable")

        # Handle repair_node
        if isinstance(repair_node, tuple):
            repair_func, repair_name = repair_node
            if not callable(repair_func):
                raise ValueError("repair_node[0] must be callable")
        else:
            repair_func = repair_node
            repair_name = "REPAIR"
            if not callable(repair_func):
                raise ValueError("repair_node must be callable")

        self._graph.add_node(produce_name, produce_func)
        self._graph.add_node(repair_name, repair_func)

        # produce -> END or REPAIR
        condition = _guard_condition_factory(validator, max_attempts)
        self._graph.add_conditional_edges(
            produce_name,
            condition,
            {repair_name: repair_name, END: END},
        )
        # repair -> produce
        self._graph.add_edge(repair_name, produce_name)

        self._graph.set_entry_point(produce_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/guarded.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(produce_node, repair_node, validator, max_attempts=2, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/guarded.py
def compile(
    self,
    produce_node: Callable | tuple[Callable, str],
    repair_node: Callable | tuple[Callable, str],
    validator: Callable[[AgentState], bool],
    max_attempts: int = 2,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Handle produce_node
    if isinstance(produce_node, tuple):
        produce_func, produce_name = produce_node
        if not callable(produce_func):
            raise ValueError("produce_node[0] must be callable")
    else:
        produce_func = produce_node
        produce_name = "PRODUCE"
        if not callable(produce_func):
            raise ValueError("produce_node must be callable")

    # Handle repair_node
    if isinstance(repair_node, tuple):
        repair_func, repair_name = repair_node
        if not callable(repair_func):
            raise ValueError("repair_node[0] must be callable")
    else:
        repair_func = repair_node
        repair_name = "REPAIR"
        if not callable(repair_func):
            raise ValueError("repair_node must be callable")

    self._graph.add_node(produce_name, produce_func)
    self._graph.add_node(repair_name, repair_func)

    # produce -> END or REPAIR
    condition = _guard_condition_factory(validator, max_attempts)
    self._graph.add_conditional_edges(
        produce_name,
        condition,
        {repair_name: repair_name, END: END},
    )
    # repair -> produce
    self._graph.add_edge(repair_name, produce_name)

    self._graph.set_entry_point(produce_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
MapReduceAgent

Map over items then reduce.

Nodes:
- SPLIT: optional, prepares per-item tasks (or state already contains items)
- MAP: processes one item per iteration
- REDUCE: aggregates results and decides END or continue

Compile requires:

map_node: Callable | ToolNode
reduce_node: Callable
split_node: Callable | None
condition: Callable[[AgentState], str] returning "MAP" to continue or END
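The SPLIT -> MAP -> REDUCE cycle can be sketched independently of PyAgenity. The node functions and the dict-based state below are illustrative assumptions; only the condition's contract (return "MAP" to continue, END to stop) comes from the docs:

```python
# Plain-Python sketch of the map-reduce loop: SPLIT prepares items, MAP
# processes one item per iteration, REDUCE aggregates, and the condition
# decides whether to loop back to MAP or finish.
END = "END"

def split(state):
    state["pending"] = list(state["items"])
    state["results"] = []
    return state

def map_one(state):
    item = state["pending"].pop(0)
    state["results"].append(item * 2)  # per-item work (here: doubling)
    return state

def reduce_(state):
    state["total"] = sum(state["results"])
    return state

def condition(state) -> str:
    return "MAP" if state["pending"] else END

state = split({"items": [1, 2, 3]})
while True:
    state = map_one(state)
    state = reduce_(state)
    if condition(state) == END:
        break
```

Note the default condition in `compile` returns END unconditionally, i.e. a single map-reduce pass; pass a condition like the one above to iterate over all items.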

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/map_reduce.py
class MapReduceAgent[StateT: AgentState]:
    """Map over items then reduce.

    Nodes:
    - SPLIT: optional, prepares per-item tasks (or state already contains items)
    - MAP: processes one item per iteration
    - REDUCE: aggregates results and decides END or continue

    Compile requires:
      map_node: Callable|ToolNode
      reduce_node: Callable
      split_node: Callable | None
      condition: Callable[[AgentState], str] returns "MAP" to continue or END
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(  # noqa: PLR0912
        self,
        map_node: Callable | ToolNode | tuple[Callable | ToolNode, str],
        reduce_node: Callable | tuple[Callable, str],
        split_node: Callable | tuple[Callable, str] | None = None,
        condition: Callable[[AgentState], str] | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Handle split_node
        split_name = "SPLIT"
        if split_node:
            if isinstance(split_node, tuple):
                split_func, split_name = split_node
                if not callable(split_func):
                    raise ValueError("split_node[0] must be callable")
            else:
                split_func = split_node
                split_name = "SPLIT"
                if not callable(split_func):
                    raise ValueError("split_node must be callable")
            self._graph.add_node(split_name, split_func)

        # Handle map_node
        if isinstance(map_node, tuple):
            map_func, map_name = map_node
            if not (callable(map_func) or isinstance(map_func, ToolNode)):
                raise ValueError("map_node[0] must be callable or ToolNode")
        else:
            map_func = map_node
            map_name = "MAP"
            if not (callable(map_func) or isinstance(map_func, ToolNode)):
                raise ValueError("map_node must be callable or ToolNode")
        self._graph.add_node(map_name, map_func)

        # Handle reduce_node
        if isinstance(reduce_node, tuple):
            reduce_func, reduce_name = reduce_node
            if not callable(reduce_func):
                raise ValueError("reduce_node[0] must be callable")
        else:
            reduce_func = reduce_node
            reduce_name = "REDUCE"
            if not callable(reduce_func):
                raise ValueError("reduce_node must be callable")
        self._graph.add_node(reduce_name, reduce_func)

        # Edges
        if split_node:
            self._graph.add_edge(split_name, map_name)
            self._graph.set_entry_point(split_name)
        else:
            self._graph.set_entry_point(map_name)

        self._graph.add_edge(map_name, reduce_name)

        # Continue mapping or finish
        if condition is None:
            # default: finish after one map-reduce
            def _cond(_: AgentState) -> str:
                return END

            condition = _cond

        self._graph.add_conditional_edges(reduce_name, condition, {map_name: map_name, END: END})

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/map_reduce.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(map_node, reduce_node, split_node=None, condition=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/map_reduce.py
def compile(  # noqa: PLR0912
    self,
    map_node: Callable | ToolNode | tuple[Callable | ToolNode, str],
    reduce_node: Callable | tuple[Callable, str],
    split_node: Callable | tuple[Callable, str] | None = None,
    condition: Callable[[AgentState], str] | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Handle split_node
    split_name = "SPLIT"
    if split_node:
        if isinstance(split_node, tuple):
            split_func, split_name = split_node
            if not callable(split_func):
                raise ValueError("split_node[0] must be callable")
        else:
            split_func = split_node
            split_name = "SPLIT"
            if not callable(split_func):
                raise ValueError("split_node must be callable")
        self._graph.add_node(split_name, split_func)

    # Handle map_node
    if isinstance(map_node, tuple):
        map_func, map_name = map_node
        if not (callable(map_func) or isinstance(map_func, ToolNode)):
            raise ValueError("map_node[0] must be callable or ToolNode")
    else:
        map_func = map_node
        map_name = "MAP"
        if not (callable(map_func) or isinstance(map_func, ToolNode)):
            raise ValueError("map_node must be callable or ToolNode")
    self._graph.add_node(map_name, map_func)

    # Handle reduce_node
    if isinstance(reduce_node, tuple):
        reduce_func, reduce_name = reduce_node
        if not callable(reduce_func):
            raise ValueError("reduce_node[0] must be callable")
    else:
        reduce_func = reduce_node
        reduce_name = "REDUCE"
        if not callable(reduce_func):
            raise ValueError("reduce_node must be callable")
    self._graph.add_node(reduce_name, reduce_func)

    # Edges
    if split_node:
        self._graph.add_edge(split_name, map_name)
        self._graph.set_entry_point(split_name)
    else:
        self._graph.set_entry_point(map_name)

    self._graph.add_edge(map_name, reduce_name)

    # Continue mapping or finish
    if condition is None:
        # default: finish after one map-reduce
        def _cond(_: AgentState) -> str:
            return END

        condition = _cond

    self._graph.add_conditional_edges(reduce_name, condition, {map_name: map_name, END: END})

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
NetworkAgent

Network pattern: define arbitrary node set and routing policies.

  • Nodes can be callables or ToolNode.
  • Edges can be static or conditional via a router function per node.
  • Entry point is explicit.
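A minimal plain-Python sketch of the network pattern: an arbitrary node set, an explicit entry, and a per-node router standing in for one `conditional_edges` entry. The node names and tiny state dict are illustrative assumptions, not part of the library:

```python
# Hypothetical three-node network: TRIAGE routes to BILLING or SUPPORT based
# on the query, analogous to one (source, router, path_map) conditional edge.
END = "END"

def triage(state):
    state["route"] = "billing" if "invoice" in state["query"] else "support"
    return state

def billing(state):
    state["answer"] = "billing team"
    return state

def support(state):
    state["answer"] = "support team"
    return state

nodes = {"TRIAGE": triage, "BILLING": billing, "SUPPORT": support}

def route_after_triage(state) -> str:
    """Router for TRIAGE; the path map would be {"BILLING": ..., "SUPPORT": ...}."""
    return "BILLING" if state["route"] == "billing" else "SUPPORT"

state = nodes["TRIAGE"]({"query": "question about an invoice"})
state = nodes[route_after_triage(state)](state)
```

In PyAgenity this shape maps onto `compile(nodes=..., entry="TRIAGE", conditional_edges=[("TRIAGE", route_after_triage, path_map)])`.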

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/network.py
class NetworkAgent[StateT: AgentState]:
    """Network pattern: define arbitrary node set and routing policies.

    - Nodes can be callables or ToolNode.
    - Edges can be static or conditional via a router function per node.
    - Entry point is explicit.
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        nodes: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
        entry: str,
        static_edges: list[tuple[str, str]] | None = None,
        conditional_edges: list[tuple[str, Callable[[AgentState], str], dict[str, str]]]
        | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        if not nodes:
            raise ValueError("nodes must be a non-empty dict")

        # Add nodes
        for key, fn in nodes.items():
            if isinstance(fn, tuple):
                func, name = fn
            else:
                func, name = fn, key
            if not (callable(func) or isinstance(func, ToolNode)):
                raise ValueError(f"Node '{key}' must be a callable or ToolNode")
            self._graph.add_node(name, func)

        if entry not in self._graph.nodes:
            raise ValueError(f"entry node '{entry}' must be present in nodes")

        # Static edges
        for src, dst in static_edges or []:
            if src not in self._graph.nodes or dst not in self._graph.nodes:
                raise ValueError(f"Invalid static edge {src}->{dst}: unknown node")
            self._graph.add_edge(src, dst)

        # Conditional edges
        for src, cond, pmap in conditional_edges or []:
            if src not in self._graph.nodes:
                raise ValueError(f"Invalid conditional edge: unknown node '{src}'")
            self._graph.add_conditional_edges(src, cond, pmap)

        # Note: callers may include END in path maps; not enforced here.

        self._graph.set_entry_point(entry)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/network.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(nodes, entry, static_edges=None, conditional_edges=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/network.py
def compile(
    self,
    nodes: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
    entry: str,
    static_edges: list[tuple[str, str]] | None = None,
    conditional_edges: list[tuple[str, Callable[[AgentState], str], dict[str, str]]]
    | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    if not nodes:
        raise ValueError("nodes must be a non-empty dict")

    # Add nodes
    for key, fn in nodes.items():
        if isinstance(fn, tuple):
            func, name = fn
        else:
            func, name = fn, key
        if not (callable(func) or isinstance(func, ToolNode)):
            raise ValueError(f"Node '{key}' must be a callable or ToolNode")
        self._graph.add_node(name, func)

    if entry not in self._graph.nodes:
        raise ValueError(f"entry node '{entry}' must be present in nodes")

    # Static edges
    for src, dst in static_edges or []:
        if src not in self._graph.nodes or dst not in self._graph.nodes:
            raise ValueError(f"Invalid static edge {src}->{dst}: unknown node")
        self._graph.add_edge(src, dst)

    # Conditional edges
    for src, cond, pmap in conditional_edges or []:
        if src not in self._graph.nodes:
            raise ValueError(f"Invalid conditional edge: unknown node '{src}'")
        self._graph.add_conditional_edges(src, cond, pmap)

    # Note: callers may include END in path maps; not enforced here.

    self._graph.set_entry_point(entry)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
PlanActReflectAgent

Plan -> Act -> Reflect looping agent.

Pattern

PLAN -> (condition) -> ACT | REFLECT | END
ACT -> REFLECT
REFLECT -> PLAN

Default condition (_should_act):
- If last assistant message contains tool calls -> ACT
- If last message is from a tool -> REFLECT
- Else -> END

Provide a custom condition to override this heuristic and implement
  • Budget / depth limiting
  • Confidence-based early stop
  • Dynamic branch selection (e.g., different tool nodes)

Parameters (constructor):

state: Optional initial state instance
context_manager: Custom context manager
publisher: Optional publisher for streaming / events
id_generator: ID generation strategy
container: InjectQ DI container

compile(...) arguments:

plan_node: Callable (state -> state). Produces next thought / tool requests
tool_node: ToolNode executing declared tools
reflect_node: Callable (state -> state). Consumes tool results & may adjust plan
condition: Optional Callable[[AgentState], str] returning next node name or END
checkpointer/store/interrupt_before/interrupt_after/callback_manager: Standard graph compilation options

Returns:

Type Description

CompiledGraph ready for invoke / ainvoke.

Notes
  • Node names can be customized via (callable, "NAME") tuples.
  • condition must return one of: tool_node_name, reflect_node_name, END.
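The default heuristic described above can be sketched as a standalone function. The message dicts are simplified stand-ins for PyAgenity's Message objects, and `should_act` is an assumption modeled on the documented `_should_act` behavior:

```python
# Sketch of the default routing heuristic: ACT when the last assistant
# message carries tool calls, REFLECT when the last message came from a
# tool, otherwise END.
END = "END"

def should_act(messages: list[dict]) -> str:
    if not messages:
        return END
    last = messages[-1]
    if last.get("role") == "assistant" and last.get("tool_calls"):
        return "ACT"
    if last.get("role") == "tool":
        return "REFLECT"
    return END

act = should_act([{"role": "assistant", "tool_calls": [{"name": "search"}]}])
reflect = should_act([{"role": "tool", "content": "search results"}])
done = should_act([{"role": "assistant", "content": "final answer"}])
```

A custom condition replacing this heuristic must return one of the tool node name, the reflect node name, or END, per the Notes above.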

Methods:

Name Description
__init__
compile

Compile the Plan-Act-Reflect loop.

Source code in pyagenity/prebuilt/agent/plan_act_reflect.py
class PlanActReflectAgent[StateT: AgentState]:
    """Plan -> Act -> Reflect looping agent.

    Pattern:
        PLAN -> (condition) -> ACT | REFLECT | END
        ACT -> REFLECT
        REFLECT -> PLAN

    Default condition (_should_act):
        - If last assistant message contains tool calls -> ACT
        - If last message is from a tool -> REFLECT
        - Else -> END

    Provide a custom condition to override this heuristic and implement:
        * Budget / depth limiting
        * Confidence-based early stop
        * Dynamic branch selection (e.g., different tool nodes)

    Parameters (constructor):
        state: Optional initial state instance
        context_manager: Custom context manager
        publisher: Optional publisher for streaming / events
        id_generator: ID generation strategy
        container: InjectQ DI container

    compile(...) arguments:
        plan_node: Callable (state -> state). Produces next thought / tool requests
        tool_node: ToolNode executing declared tools
        reflect_node: Callable (state -> state). Consumes tool results & may adjust plan
        condition: Optional Callable[[AgentState], str] returning next node name or END
        checkpointer/store/interrupt_before/interrupt_after/callback_manager:
            Standard graph compilation options

    Returns:
        CompiledGraph ready for invoke / ainvoke.

    Notes:
        - Node names can be customized via (callable, "NAME") tuples.
        - condition must return one of: tool_node_name, reflect_node_name, END.
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        plan_node: Callable | tuple[Callable, str],
        tool_node: ToolNode | tuple[ToolNode, str],
        reflect_node: Callable | tuple[Callable, str],
        *,
        condition: Callable[[AgentState], str] | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        """Compile the Plan-Act-Reflect loop.

        Args:
            plan_node: Callable or (callable, name)
            tool_node: ToolNode or (ToolNode, name)
            reflect_node: Callable or (callable, name)
            condition: Optional decision function. Defaults to internal heuristic.
            checkpointer/store/interrupt_* / callback_manager: Standard graph options.

        Returns:
            CompiledGraph
        """
        # PLAN
        if isinstance(plan_node, tuple):
            plan_func, plan_name = plan_node
            if not callable(plan_func):
                raise ValueError("plan_node[0] must be callable")
        else:
            plan_func = plan_node
            plan_name = "PLAN"
            if not callable(plan_func):
                raise ValueError("plan_node must be callable")

        # ACT
        if isinstance(tool_node, tuple):
            tool_func, tool_name = tool_node
            if not isinstance(tool_func, ToolNode):
                raise ValueError("tool_node[0] must be a ToolNode")
        else:
            tool_func = tool_node
            tool_name = "ACT"
            if not isinstance(tool_func, ToolNode):
                raise ValueError("tool_node must be a ToolNode")

        # REFLECT
        if isinstance(reflect_node, tuple):
            reflect_func, reflect_name = reflect_node
            if not callable(reflect_func):
                raise ValueError("reflect_node[0] must be callable")
        else:
            reflect_func = reflect_node
            reflect_name = "REFLECT"
            if not callable(reflect_func):
                raise ValueError("reflect_node must be callable")

        # Register nodes
        self._graph.add_node(plan_name, plan_func)
        self._graph.add_node(tool_name, tool_func)
        self._graph.add_node(reflect_name, reflect_func)

        # Decision
        decision_fn = condition or _should_act
        self._graph.add_conditional_edges(
            plan_name,
            decision_fn,
            {tool_name: tool_name, reflect_name: reflect_name, END: END},
        )

        # Loop edges
        self._graph.add_edge(tool_name, reflect_name)
        self._graph.add_edge(reflect_name, plan_name)

        # Entry
        self._graph.set_entry_point(plan_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/plan_act_reflect.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(plan_node, tool_node, reflect_node, *, condition=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())

Compile the Plan-Act-Reflect loop.

Parameters:

Name Type Description Default
plan_node Callable | tuple[Callable, str]

Callable or (callable, name)

required
tool_node ToolNode | tuple[ToolNode, str]

ToolNode or (ToolNode, name)

required
reflect_node Callable | tuple[Callable, str]

Callable or (callable, name)

required
condition Callable[[AgentState], str] | None

Optional decision function. Defaults to internal heuristic.

None
checkpointer / store / interrupt_* / callback_manager

Standard graph options (all optional).

Returns:

Type Description
CompiledGraph

CompiledGraph

Source code in pyagenity/prebuilt/agent/plan_act_reflect.py
def compile(
    self,
    plan_node: Callable | tuple[Callable, str],
    tool_node: ToolNode | tuple[ToolNode, str],
    reflect_node: Callable | tuple[Callable, str],
    *,
    condition: Callable[[AgentState], str] | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    """Compile the Plan-Act-Reflect loop.

    Args:
        plan_node: Callable or (callable, name)
        tool_node: ToolNode or (ToolNode, name)
        reflect_node: Callable or (callable, name)
        condition: Optional decision function. Defaults to internal heuristic.
        checkpointer/store/interrupt_* / callback_manager: Standard graph options.

    Returns:
        CompiledGraph
    """
    # PLAN
    if isinstance(plan_node, tuple):
        plan_func, plan_name = plan_node
        if not callable(plan_func):
            raise ValueError("plan_node[0] must be callable")
    else:
        plan_func = plan_node
        plan_name = "PLAN"
        if not callable(plan_func):
            raise ValueError("plan_node must be callable")

    # ACT
    if isinstance(tool_node, tuple):
        tool_func, tool_name = tool_node
        if not isinstance(tool_func, ToolNode):
            raise ValueError("tool_node[0] must be a ToolNode")
    else:
        tool_func = tool_node
        tool_name = "ACT"
        if not isinstance(tool_func, ToolNode):
            raise ValueError("tool_node must be a ToolNode")

    # REFLECT
    if isinstance(reflect_node, tuple):
        reflect_func, reflect_name = reflect_node
        if not callable(reflect_func):
            raise ValueError("reflect_node[0] must be callable")
    else:
        reflect_func = reflect_node
        reflect_name = "REFLECT"
        if not callable(reflect_func):
            raise ValueError("reflect_node must be callable")

    # Register nodes
    self._graph.add_node(plan_name, plan_func)
    self._graph.add_node(tool_name, tool_func)
    self._graph.add_node(reflect_name, reflect_func)

    # Decision
    decision_fn = condition or _should_act
    self._graph.add_conditional_edges(
        plan_name,
        decision_fn,
        {tool_name: tool_name, reflect_name: reflect_name, END: END},
    )

    # Loop edges
    self._graph.add_edge(tool_name, reflect_name)
    self._graph.add_edge(reflect_name, plan_name)

    # Entry
    self._graph.set_entry_point(plan_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
RAGAgent

Simple RAG: retrieve -> synthesize; optional follow-up.

Nodes:
  • RETRIEVE: uses a retriever (callable or ToolNode) to fetch context
  • SYNTHESIZE: LLM/composer builds an answer
  • Optional condition: loop back to RETRIEVE for follow-up queries; else END
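The optional follow-up condition is an ordinary function of the state. A minimal sketch, assuming (hypothetically) that the synthesize node marks incomplete answers with a "NEED_MORE_CONTEXT" string and that the state carries a `rounds` counter; adapt both to your AgentState:

```python
from types import SimpleNamespace

END = "__end__"  # stand-in; use pyagenity's END constant in real code


def followup_condition(state, retriever_name="RETRIEVE", max_rounds=3):
    """Return RETRIEVE to loop back for more context, END to finish."""
    if getattr(state, "rounds", 0) >= max_rounds:
        return END  # hard stop to prevent endless retrieval loops
    last = state.context[-1] if state.context else None
    if last is not None and "NEED_MORE_CONTEXT" in last.content:
        return retriever_name
    return END


# tiny duck-typed state for illustration
demo = SimpleNamespace(
    context=[SimpleNamespace(content="NEED_MORE_CONTEXT: publication dates")],
    rounds=1,
)
```

Pass it as `followup_condition=` to `compile(...)`; returned values must match the retriever node name or END.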

Methods:

Name Description
__init__
compile
compile_advanced

Advanced RAG wiring with hybrid retrieval and optional stages.

Source code in pyagenity/prebuilt/agent/rag.py
class RAGAgent[StateT: AgentState]:
    """Simple RAG: retrieve -> synthesize; optional follow-up.

    Nodes:
    - RETRIEVE: uses a retriever (callable or ToolNode) to fetch context
    - SYNTHESIZE: LLM/composer builds an answer
    - Optional condition: loop back to RETRIEVE for follow-up queries; else END
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        retriever_node: Callable | ToolNode | tuple[Callable | ToolNode, str],
        synthesize_node: Callable | tuple[Callable, str],
        followup_condition: Callable[[AgentState], str] | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Nodes
        # Handle retriever_node
        if isinstance(retriever_node, tuple):
            retriever_func, retriever_name = retriever_node
            if not (callable(retriever_func) or isinstance(retriever_func, ToolNode)):
                raise ValueError("retriever_node[0] must be callable or ToolNode")
        else:
            retriever_func = retriever_node
            retriever_name = "RETRIEVE"
            if not (callable(retriever_func) or isinstance(retriever_func, ToolNode)):
                raise ValueError("retriever_node must be callable or ToolNode")

        # Handle synthesize_node
        if isinstance(synthesize_node, tuple):
            synthesize_func, synthesize_name = synthesize_node
            if not callable(synthesize_func):
                raise ValueError("synthesize_node[0] must be callable")
        else:
            synthesize_func = synthesize_node
            synthesize_name = "SYNTHESIZE"
            if not callable(synthesize_func):
                raise ValueError("synthesize_node must be callable")

        self._graph.add_node(retriever_name, retriever_func)  # type: ignore[arg-type]
        self._graph.add_node(synthesize_name, synthesize_func)

        # Edges
        self._graph.add_edge(retriever_name, synthesize_name)
        self._graph.set_entry_point(retriever_name)

        if followup_condition is None:
            # default: END after synthesize
            def _cond(_: AgentState) -> str:
                return END

            followup_condition = _cond

        self._graph.add_conditional_edges(
            synthesize_name,
            followup_condition,
            {retriever_name: retriever_name, END: END},
        )

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )

    def compile_advanced(
        self,
        retriever_nodes: list[Callable | ToolNode | tuple[Callable | ToolNode, str]],
        synthesize_node: Callable | tuple[Callable, str],
        options: dict | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        """Advanced RAG wiring with hybrid retrieval and optional stages.

        Chain:
          (QUERY_PLAN?) -> R1 -> (MERGE?) -> R2 -> (MERGE?) -> ...
          -> (RERANK?) -> (COMPRESS?) -> SYNTHESIZE -> cond
        Each retriever may be a different modality (sparse, dense, self-query, MMR, etc.).
        """

        options = options or {}
        query_plan_node = options.get("query_plan")
        merger_node = options.get("merge")
        rerank_node = options.get("rerank")
        compress_node = options.get("compress")
        followup_condition = options.get("followup_condition")

        qname = self._add_optional_node(
            query_plan_node,
            default_name="QUERY_PLAN",
            label="query_plan",
        )

        # Add retrievers
        r_names = self._add_retriever_nodes(retriever_nodes)

        # Optional stages
        mname = self._add_optional_node(merger_node, default_name="MERGE", label="merge")
        rrname = self._add_optional_node(rerank_node, default_name="RERANK", label="rerank")
        cname = self._add_optional_node(
            compress_node,
            default_name="COMPRESS",
            label="compress",
        )

        # Synthesize
        sname = self._add_synthesize_node(synthesize_node)

        # Wire edges end-to-end and follow-up
        self._wire_advanced_edges(qname, r_names, mname, rrname, cname, sname, followup_condition)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )

    # ---- helpers ----
    def _add_optional_node(
        self,
        node: Callable | tuple[Callable, str] | None,
        *,
        default_name: str,
        label: str,
    ) -> str | None:
        if not node:
            return None
        if isinstance(node, tuple):
            func, name = node
        else:
            func, name = node, default_name
        if not callable(func):
            raise ValueError(f"{label} node must be callable")
        self._graph.add_node(name, func)
        return name

    def _add_retriever_nodes(
        self,
        retriever_nodes: list[Callable | ToolNode | tuple[Callable | ToolNode, str]],
    ) -> list[str]:
        if not retriever_nodes:
            raise ValueError("retriever_nodes must be non-empty")
        names: list[str] = []
        for idx, rn in enumerate(retriever_nodes):
            if isinstance(rn, tuple):
                rfunc, rname = rn
            else:
                rfunc, rname = rn, f"RETRIEVE_{idx + 1}"
            if not (callable(rfunc) or isinstance(rfunc, ToolNode)):
                raise ValueError("retriever must be callable or ToolNode")
            self._graph.add_node(rname, rfunc)  # type: ignore[arg-type]
            names.append(rname)
        return names

    def _add_synthesize_node(self, synthesize_node: Callable | tuple[Callable, str]) -> str:
        if isinstance(synthesize_node, tuple):
            sfunc, sname = synthesize_node
        else:
            sfunc, sname = synthesize_node, "SYNTHESIZE"
        if not callable(sfunc):
            raise ValueError("synthesize_node must be callable")
        self._graph.add_node(sname, sfunc)
        return sname

    def _wire_advanced_edges(
        self,
        qname: str | None,
        r_names: list[str],
        mname: str | None,
        rrname: str | None,
        cname: str | None,
        sname: str,
        followup_condition: Callable[[AgentState], str] | None = None,
    ) -> None:
        entry = qname or r_names[0]
        self._graph.set_entry_point(entry)
        if qname:
            self._graph.add_edge(qname, r_names[0])

        tail_target = rrname or cname or sname
        for i, rname in enumerate(r_names):
            is_last = i == len(r_names) - 1
            nxt = r_names[i + 1] if not is_last else tail_target
            if mname:
                self._graph.add_edge(rname, mname)
                self._graph.add_edge(mname, nxt)
            else:
                self._graph.add_edge(rname, nxt)

        # Tail wiring
        if rrname and cname:
            self._graph.add_edge(rrname, cname)
            self._graph.add_edge(cname, sname)
        elif rrname:
            self._graph.add_edge(rrname, sname)
        elif cname:
            self._graph.add_edge(cname, sname)

        # default follow-up to END
        if followup_condition is None:

            def _cond(_: AgentState) -> str:
                return END

            followup_condition = _cond

        entry_node = qname or r_names[0]
        path_map = {entry_node: entry_node, END: END}
        self._graph.add_conditional_edges(sname, followup_condition, path_map)
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/rag.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(retriever_node, synthesize_node, followup_condition=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/rag.py
def compile(
    self,
    retriever_node: Callable | ToolNode | tuple[Callable | ToolNode, str],
    synthesize_node: Callable | tuple[Callable, str],
    followup_condition: Callable[[AgentState], str] | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Nodes
    # Handle retriever_node
    if isinstance(retriever_node, tuple):
        retriever_func, retriever_name = retriever_node
        if not (callable(retriever_func) or isinstance(retriever_func, ToolNode)):
            raise ValueError("retriever_node[0] must be callable or ToolNode")
    else:
        retriever_func = retriever_node
        retriever_name = "RETRIEVE"
        if not (callable(retriever_func) or isinstance(retriever_func, ToolNode)):
            raise ValueError("retriever_node must be callable or ToolNode")

    # Handle synthesize_node
    if isinstance(synthesize_node, tuple):
        synthesize_func, synthesize_name = synthesize_node
        if not callable(synthesize_func):
            raise ValueError("synthesize_node[0] must be callable")
    else:
        synthesize_func = synthesize_node
        synthesize_name = "SYNTHESIZE"
        if not callable(synthesize_func):
            raise ValueError("synthesize_node must be callable")

    self._graph.add_node(retriever_name, retriever_func)  # type: ignore[arg-type]
    self._graph.add_node(synthesize_name, synthesize_func)

    # Edges
    self._graph.add_edge(retriever_name, synthesize_name)
    self._graph.set_entry_point(retriever_name)

    if followup_condition is None:
        # default: END after synthesize
        def _cond(_: AgentState) -> str:
            return END

        followup_condition = _cond

    self._graph.add_conditional_edges(
        synthesize_name,
        followup_condition,
        {retriever_name: retriever_name, END: END},
    )

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
compile_advanced
compile_advanced(retriever_nodes, synthesize_node, options=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())

Advanced RAG wiring with hybrid retrieval and optional stages.

Chain

(QUERY_PLAN?) -> R1 -> (MERGE?) -> R2 -> (MERGE?) -> ... -> (RERANK?) -> (COMPRESS?) -> SYNTHESIZE -> cond

Each retriever may be a different modality (sparse, dense, self-query, MMR, etc.).
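The optional stages are supplied through the `options` dict, whose recognized keys are `query_plan`, `merge`, `rerank`, `compress`, and `followup_condition`; any absent key simply skips that stage. A minimal sketch with hypothetical placeholder callables:

```python
# Hypothetical stage callables; each takes and returns the agent state.
def plan_queries(state):      # QUERY_PLAN: expand the query into sub-queries
    return state

def merge_results(state):     # MERGE: fuse/deduplicate results between retrievers
    return state

def rerank_results(state):    # RERANK: reorder retrieved passages by relevance
    return state

def compress_context(state):  # COMPRESS: trim context to fit the model window
    return state

options = {
    "query_plan": plan_queries,
    "merge": merge_results,
    "rerank": rerank_results,
    "compress": compress_context,
    # "followup_condition": my_condition,  # optional loop-back decision
}
```

This dict is then passed as `compile_advanced(retriever_nodes, synthesize_node, options=options, ...)`.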

Source code in pyagenity/prebuilt/agent/rag.py
def compile_advanced(
    self,
    retriever_nodes: list[Callable | ToolNode | tuple[Callable | ToolNode, str]],
    synthesize_node: Callable | tuple[Callable, str],
    options: dict | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    """Advanced RAG wiring with hybrid retrieval and optional stages.

    Chain:
      (QUERY_PLAN?) -> R1 -> (MERGE?) -> R2 -> (MERGE?) -> ...
      -> (RERANK?) -> (COMPRESS?) -> SYNTHESIZE -> cond
    Each retriever may be a different modality (sparse, dense, self-query, MMR, etc.).
    """

    options = options or {}
    query_plan_node = options.get("query_plan")
    merger_node = options.get("merge")
    rerank_node = options.get("rerank")
    compress_node = options.get("compress")
    followup_condition = options.get("followup_condition")

    qname = self._add_optional_node(
        query_plan_node,
        default_name="QUERY_PLAN",
        label="query_plan",
    )

    # Add retrievers
    r_names = self._add_retriever_nodes(retriever_nodes)

    # Optional stages
    mname = self._add_optional_node(merger_node, default_name="MERGE", label="merge")
    rrname = self._add_optional_node(rerank_node, default_name="RERANK", label="rerank")
    cname = self._add_optional_node(
        compress_node,
        default_name="COMPRESS",
        label="compress",
    )

    # Synthesize
    sname = self._add_synthesize_node(synthesize_node)

    # Wire edges end-to-end and follow-up
    self._wire_advanced_edges(qname, r_names, mname, rrname, cname, sname, followup_condition)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
ReactAgent

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/react.py
class ReactAgent[StateT: AgentState]:
    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        main_node: tuple[Callable, str] | Callable,
        tool_node: tuple[Callable, str] | Callable,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Determine main node function and name
        if isinstance(main_node, tuple):
            main_func, main_name = main_node
            if not callable(main_func):
                raise ValueError("main_node[0] must be a callable function")
        else:
            main_func = main_node
            main_name = "MAIN"
            if not callable(main_func):
                raise ValueError("main_node must be a callable function")

        # Determine tool node function and name
        if isinstance(tool_node, tuple):
            tool_func, tool_name = tool_node
            # Accept both callable functions and ToolNode instances
            if not callable(tool_func) and not hasattr(tool_func, "invoke"):
                raise ValueError("tool_node[0] must be a callable function or ToolNode")
        else:
            tool_func = tool_node
            tool_name = "TOOL"
            # Accept both callable functions and ToolNode instances
            # ToolNode instances have an 'invoke' method but are not callable
            if not callable(tool_func) and not hasattr(tool_func, "invoke"):
                raise ValueError("tool_node must be a callable function or ToolNode instance")

        self._graph.add_node(main_name, main_func)
        self._graph.add_node(tool_name, tool_func)

        # Now create edges
        self._graph.add_conditional_edges(
            main_name,
            _should_use_tools,
            {tool_name: tool_name, END: END},
        )

        self._graph.add_edge(tool_name, main_name)
        self._graph.set_entry_point(main_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/react.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(main_node, tool_node, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/react.py
def compile(
    self,
    main_node: tuple[Callable, str] | Callable,
    tool_node: tuple[Callable, str] | Callable,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Determine main node function and name
    if isinstance(main_node, tuple):
        main_func, main_name = main_node
        if not callable(main_func):
            raise ValueError("main_node[0] must be a callable function")
    else:
        main_func = main_node
        main_name = "MAIN"
        if not callable(main_func):
            raise ValueError("main_node must be a callable function")

    # Determine tool node function and name
    if isinstance(tool_node, tuple):
        tool_func, tool_name = tool_node
        # Accept both callable functions and ToolNode instances
        if not callable(tool_func) and not hasattr(tool_func, "invoke"):
            raise ValueError("tool_node[0] must be a callable function or ToolNode")
    else:
        tool_func = tool_node
        tool_name = "TOOL"
        # Accept both callable functions and ToolNode instances
        # ToolNode instances have an 'invoke' method but are not callable
        if not callable(tool_func) and not hasattr(tool_func, "invoke"):
            raise ValueError("tool_node must be a callable function or ToolNode instance")

    self._graph.add_node(main_name, main_func)
    self._graph.add_node(tool_name, tool_func)

    # Now create edges
    self._graph.add_conditional_edges(
        main_name,
        _should_use_tools,
        {tool_name: tool_name, END: END},
    )

    self._graph.add_edge(tool_name, main_name)
    self._graph.set_entry_point(main_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
RouterAgent

A configurable router-style agent.

Pattern:
  • A router node runs (LLM or custom logic) and may update state/messages
  • A condition function inspects the state and returns a route key
  • Edges route to the matching node; each route returns back to ROUTER
  • Return END (via condition) to finish

Usage

router = RouterAgent()
app = router.compile(
    router_node=my_router_func,  # def my_router_func(state, config, ...)
    routes={
        "search": search_node,
        "summarize": summarize_node,
    },
    # Condition inspects state and returns one of the keys above or END
    condition=my_condition,  # def my_condition(state) -> str
    # Optional explicit path map if returned keys differ from node names
    # path_map={"SEARCH": "search", "SUM": "summarize", END: END}
)

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/router.py
class RouterAgent[StateT: AgentState]:
    """A configurable router-style agent.

    Pattern:
    - A router node runs (LLM or custom logic) and may update state/messages
    - A condition function inspects the state and returns a route key
    - Edges route to the matching node; each route returns back to ROUTER
    - Return END (via condition) to finish

    Usage:
        router = RouterAgent()
        app = router.compile(
            router_node=my_router_func,  # def my_router_func(state, config, ...)
            routes={
                "search": search_node,
                "summarize": summarize_node,
            },
            # Condition inspects state and returns one of the keys above or END
            condition=my_condition,  # def my_condition(state) -> str
            # Optional explicit path map if returned keys differ from node names
            # path_map={"SEARCH": "search", "SUM": "summarize", END: END}
        )
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(  # noqa: PLR0912
        self,
        router_node: Callable | tuple[Callable, str],
        routes: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
        condition: Callable[[AgentState], str] | None = None,
        path_map: dict[str, str] | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Handle router_node
        if isinstance(router_node, tuple):
            router_func, router_name = router_node
            if not callable(router_func):
                raise ValueError("router_node[0] must be callable")
        else:
            router_func = router_node
            router_name = "ROUTER"
            if not callable(router_func):
                raise ValueError("router_node must be callable")

        if not routes:
            raise ValueError("routes must be a non-empty dict of name -> callable/ToolNode/tuple")

        # Add route nodes
        route_names = []
        for key, func in routes.items():
            if isinstance(func, tuple):
                route_func, route_name = func
                if not (callable(route_func) or isinstance(route_func, ToolNode)):
                    raise ValueError(f"Route '{key}'[0] must be callable or ToolNode")
            else:
                route_func = func
                route_name = key
                if not (callable(route_func) or isinstance(route_func, ToolNode)):
                    raise ValueError(f"Route '{key}' must be callable or ToolNode")
            self._graph.add_node(route_name, route_func)
            route_names.append(route_name)

        # Add router node as entry
        self._graph.add_node(router_name, router_func)

        # Build default condition/path_map if needed
        if condition is None and len(route_names) == 1:
            only = route_names[0]

            def _always(_: AgentState) -> str:
                return only

            condition = _always
            path_map = {only: only, END: END}

        if condition is None and len(route_names) > 1:
            raise ValueError("condition must be provided when multiple routes are defined")

        # If path_map is not provided, assume router returns exact route names
        if path_map is None:
            path_map = {k: k for k in route_names}
            path_map[END] = END

        # Conditional edges from router node based on condition results
        self._graph.add_conditional_edges(
            router_name,
            condition,  # type: ignore[arg-type]
            path_map,
        )

        # Loop back to router node from each route node
        for name in route_names:
            self._graph.add_edge(name, router_name)

        # Entry
        self._graph.set_entry_point(router_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
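Conditional routing in this pattern reduces to: call `condition(state)`, then map its return value through `path_map` to pick the next node. A minimal plain-Python sketch of that resolution step, with illustrative state keys and route names (none of these names come from the PyAgenity API):

```python
# Sketch of condition + path_map resolution as wired by add_conditional_edges.
END = "__end__"  # sentinel, standing in for pyagenity's END constant

def my_condition(state: dict) -> str:
    # Route based on flags a router node might have written into state.
    if state.get("needs_search"):
        return "SEARCH"
    if state.get("needs_summary"):
        return "SUM"
    return END

# Explicit path map: condition keys -> node names (the compile(path_map=...) case)
path_map = {"SEARCH": "search", "SUM": "summarize", END: END}

def resolve_next(state: dict) -> str:
    # What the graph effectively does at the router's conditional edge.
    return path_map[my_condition(state)]
```

With no explicit `path_map`, compile defaults to an identity map over the route names plus `END: END`, so the condition must return exact node names.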
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/router.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(router_node, routes, condition=None, path_map=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/router.py
def compile(  # noqa: PLR0912
    self,
    router_node: Callable | tuple[Callable, str],
    routes: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
    condition: Callable[[AgentState], str] | None = None,
    path_map: dict[str, str] | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Handle router_node
    if isinstance(router_node, tuple):
        router_func, router_name = router_node
        if not callable(router_func):
            raise ValueError("router_node[0] must be callable")
    else:
        router_func = router_node
        router_name = "ROUTER"
        if not callable(router_func):
            raise ValueError("router_node must be callable")

    if not routes:
        raise ValueError("routes must be a non-empty dict of name -> callable/ToolNode/tuple")

    # Add route nodes
    route_names = []
    for key, func in routes.items():
        if isinstance(func, tuple):
            route_func, route_name = func
            if not (callable(route_func) or isinstance(route_func, ToolNode)):
                raise ValueError(f"Route '{key}'[0] must be callable or ToolNode")
        else:
            route_func = func
            route_name = key
            if not (callable(route_func) or isinstance(route_func, ToolNode)):
                raise ValueError(f"Route '{key}' must be callable or ToolNode")
        self._graph.add_node(route_name, route_func)
        route_names.append(route_name)

    # Add router node as entry
    self._graph.add_node(router_name, router_func)

    # Build default condition/path_map if needed
    if condition is None and len(route_names) == 1:
        only = route_names[0]

        def _always(_: AgentState) -> str:
            return only

        condition = _always
        path_map = {only: only, END: END}

    if condition is None and len(route_names) > 1:
        raise ValueError("condition must be provided when multiple routes are defined")

    # If path_map is not provided, assume router returns exact route names
    if path_map is None:
        path_map = {k: k for k in route_names}
        path_map[END] = END

    # Conditional edges from router node based on condition results
    self._graph.add_conditional_edges(
        router_name,
        condition,  # type: ignore[arg-type]
        path_map,
    )

    # Loop back to router node from each route node
    for name in route_names:
        self._graph.add_edge(name, router_name)

    # Entry
    self._graph.set_entry_point(router_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
SequentialAgent

A simple sequential agent that executes a fixed pipeline of nodes.

Pattern:

- Nodes run in the provided order: step1 -> step2 -> ... -> stepN
- After the last step, the graph ends

Usage

seq = SequentialAgent()
app = seq.compile([
    ("ingest", ingest_node),
    ("plan", plan_node),
    ("execute", execute_node),
])

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/sequential.py
class SequentialAgent[StateT: AgentState]:
    """A simple sequential agent that executes a fixed pipeline of nodes.

    Pattern:
    - Nodes run in the provided order: step1 -> step2 -> ... -> stepN
    - After the last step, the graph ends

    Usage:
        seq = SequentialAgent()
        app = seq.compile([
            ("ingest", ingest_node),
            ("plan", plan_node),
            ("execute", execute_node),
        ])
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        steps: Sequence[tuple[str, Callable | ToolNode] | tuple[Callable | ToolNode, str]],
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        if not steps:
            raise ValueError(
                "steps must be a non-empty sequence of (name, callable/ToolNode) "
                "or (callable/ToolNode, name)"
            )

        # Add nodes
        step_names = []
        for step in steps:
            if isinstance(step[0], str):
                name, func = step
            else:
                func, name = step
            if not (callable(func) or isinstance(func, ToolNode)):
                raise ValueError(f"Step '{name}' must be a callable or ToolNode")
            self._graph.add_node(name, func)  # type: ignore[arg-type]
            step_names.append(name)

        # Static edges in order
        for i in range(len(step_names) - 1):
            self._graph.add_edge(step_names[i], step_names[i + 1])

        # Entry is the first step
        self._graph.set_entry_point(step_names[0])

        # No explicit edge to END needed; the engine will end if no outgoing edges remain.
        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
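Each step may be given as `(name, func)` or `(func, name)`; the string position identifies which element is the name, and the nodes are then chained with static edges in order. A self-contained sketch of that normalization and chaining logic (mirroring the loop above; names illustrative):

```python
from collections.abc import Callable

def normalize_step(step: tuple) -> tuple[str, Callable]:
    # Mirrors compile(): a str in position 0 means (name, func),
    # otherwise the tuple is (func, name).
    if isinstance(step[0], str):
        name, func = step
    else:
        func, name = step
    return name, func

def chain_edges(names: list[str]) -> list[tuple[str, str]]:
    # Static edges in order: step1 -> step2 -> ... -> stepN
    return [(names[i], names[i + 1]) for i in range(len(names) - 1)]
```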
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/sequential.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(steps, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/sequential.py
def compile(
    self,
    steps: Sequence[tuple[str, Callable | ToolNode] | tuple[Callable | ToolNode, str]],
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    if not steps:
        raise ValueError(
            "steps must be a non-empty sequence of (name, callable/ToolNode) "
            "or (callable/ToolNode, name)"
        )

    # Add nodes
    step_names = []
    for step in steps:
        if isinstance(step[0], str):
            name, func = step
        else:
            func, name = step
        if not (callable(func) or isinstance(func, ToolNode)):
            raise ValueError(f"Step '{name}' must be a callable or ToolNode")
        self._graph.add_node(name, func)  # type: ignore[arg-type]
        step_names.append(name)

    # Static edges in order
    for i in range(len(step_names) - 1):
        self._graph.add_edge(step_names[i], step_names[i + 1])

    # Entry is the first step
    self._graph.set_entry_point(step_names[0])

    # No explicit edge to END needed; the engine will end if no outgoing edges remain.
    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
SupervisorTeamAgent

Supervisor routes tasks to worker nodes and aggregates results.

Nodes:

- SUPERVISOR: decides which worker to call (by returning a worker key) or END
- Multiple WORKER nodes: functions or ToolNode instances
- AGGREGATE: optional aggregator node after worker runs; loops back to SUPERVISOR

The compile requires:

- supervisor_node: Callable
- workers: dict[str, Callable | ToolNode]
- aggregate_node: Callable | None
- condition: Callable[[AgentState], str] (returns a worker key or END)

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/supervisor_team.py
class SupervisorTeamAgent[StateT: AgentState]:
    """Supervisor routes tasks to worker nodes and aggregates results.

    Nodes:
    - SUPERVISOR: decides which worker to call (by returning a worker key) or END
    - Multiple WORKER nodes: functions or ToolNode instances
    - AGGREGATE: optional aggregator node after worker runs; loops back to SUPERVISOR

    The compile requires:
      supervisor_node: Callable
      workers: dict[str, Callable|ToolNode]
      aggregate_node: Callable | None
      condition: Callable[[AgentState], str] returns worker key or END
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(  # noqa: PLR0912
        self,
        supervisor_node: Callable | tuple[Callable, str],
        workers: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
        condition: Callable[[AgentState], str],
        aggregate_node: Callable | tuple[Callable, str] | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Handle supervisor_node
        if isinstance(supervisor_node, tuple):
            supervisor_func, supervisor_name = supervisor_node
            if not callable(supervisor_func):
                raise ValueError("supervisor_node[0] must be callable")
        else:
            supervisor_func = supervisor_node
            supervisor_name = "SUPERVISOR"
            if not callable(supervisor_func):
                raise ValueError("supervisor_node must be callable")

        if not workers:
            raise ValueError("workers must be a non-empty dict")

        # Add worker nodes
        worker_names = []
        for key, fn in workers.items():
            if isinstance(fn, tuple):
                worker_func, worker_name = fn
                if not (callable(worker_func) or isinstance(worker_func, ToolNode)):
                    raise ValueError(f"Worker '{key}'[0] must be callable or ToolNode")
            else:
                worker_func = fn
                worker_name = key
                if not (callable(worker_func) or isinstance(worker_func, ToolNode)):
                    raise ValueError(f"Worker '{key}' must be callable or ToolNode")
            self._graph.add_node(worker_name, worker_func)
            worker_names.append(worker_name)

        # Handle aggregate_node
        aggregate_name = "AGGREGATE"
        if aggregate_node:
            if isinstance(aggregate_node, tuple):
                aggregate_func, aggregate_name = aggregate_node
                if not callable(aggregate_func):
                    raise ValueError("aggregate_node[0] must be callable")
            else:
                aggregate_func = aggregate_node
                aggregate_name = "AGGREGATE"
                if not callable(aggregate_func):
                    raise ValueError("aggregate_node must be callable")
            self._graph.add_node(aggregate_name, aggregate_func)

        # SUPERVISOR decides next worker
        path_map = {k: k for k in worker_names}
        path_map[END] = END
        self._graph.add_conditional_edges(supervisor_name, condition, path_map)

        # After worker, go to AGGREGATE if present, else back to SUPERVISOR
        for name in worker_names:
            self._graph.add_edge(name, aggregate_name if aggregate_node else supervisor_name)

        if aggregate_node:
            self._graph.add_edge(aggregate_name, supervisor_name)

        self._graph.set_entry_point(supervisor_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
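The resulting control flow is a loop: the condition picks a worker key (or END), the worker runs, and control returns to the supervisor (directly, or via AGGREGATE). A toy plain-Python simulation of that loop, with hypothetical state keys and worker names not drawn from the PyAgenity API:

```python
END = "__end__"  # sentinel, standing in for pyagenity's END constant

def run_loop(state: dict, workers: dict, condition) -> dict:
    # Simulates SUPERVISOR -> worker -> SUPERVISOR until condition says END.
    while True:
        key = condition(state)
        if key == END:
            return state
        state = workers[key](state)

def condition(state: dict) -> str:
    # Supervisor logic: keep dispatching "research" until state says done.
    return END if state.get("done") else "research"

def research(state: dict) -> dict:
    return {**state, "done": True, "notes": "collected"}

final = run_loop({}, {"research": research}, condition)
```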
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/supervisor_team.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(supervisor_node, workers, condition, aggregate_node=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/supervisor_team.py
def compile(  # noqa: PLR0912
    self,
    supervisor_node: Callable | tuple[Callable, str],
    workers: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
    condition: Callable[[AgentState], str],
    aggregate_node: Callable | tuple[Callable, str] | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Handle supervisor_node
    if isinstance(supervisor_node, tuple):
        supervisor_func, supervisor_name = supervisor_node
        if not callable(supervisor_func):
            raise ValueError("supervisor_node[0] must be callable")
    else:
        supervisor_func = supervisor_node
        supervisor_name = "SUPERVISOR"
        if not callable(supervisor_func):
            raise ValueError("supervisor_node must be callable")

    if not workers:
        raise ValueError("workers must be a non-empty dict")

    # Add worker nodes
    worker_names = []
    for key, fn in workers.items():
        if isinstance(fn, tuple):
            worker_func, worker_name = fn
            if not (callable(worker_func) or isinstance(worker_func, ToolNode)):
                raise ValueError(f"Worker '{key}'[0] must be callable or ToolNode")
        else:
            worker_func = fn
            worker_name = key
            if not (callable(worker_func) or isinstance(worker_func, ToolNode)):
                raise ValueError(f"Worker '{key}' must be callable or ToolNode")
        self._graph.add_node(worker_name, worker_func)
        worker_names.append(worker_name)

    # Handle aggregate_node
    aggregate_name = "AGGREGATE"
    if aggregate_node:
        if isinstance(aggregate_node, tuple):
            aggregate_func, aggregate_name = aggregate_node
            if not callable(aggregate_func):
                raise ValueError("aggregate_node[0] must be callable")
        else:
            aggregate_func = aggregate_node
            aggregate_name = "AGGREGATE"
            if not callable(aggregate_func):
                raise ValueError("aggregate_node must be callable")
        self._graph.add_node(aggregate_name, aggregate_func)

    # SUPERVISOR decides next worker
    path_map = {k: k for k in worker_names}
    path_map[END] = END
    self._graph.add_conditional_edges(supervisor_name, condition, path_map)

    # After worker, go to AGGREGATE if present, else back to SUPERVISOR
    for name in worker_names:
        self._graph.add_edge(name, aggregate_name if aggregate_node else supervisor_name)

    if aggregate_node:
        self._graph.add_edge(aggregate_name, supervisor_name)

    self._graph.set_entry_point(supervisor_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
SwarmAgent

Swarm pattern: dispatch to many workers, collect, then reach consensus.

Notes:

- The underlying engine executes nodes sequentially; true parallelism isn't performed at the graph level. For concurrency, worker/collector nodes can internally use BackgroundTaskManager or async to fan out.
- This pattern wires a linear broadcast-collect chain ending in CONSENSUS.

Nodes:

- optional DISPATCH: prepare/plan the swarm task
- WORKER_i: a set of worker nodes (Callable or ToolNode)
- optional COLLECT: consolidate each worker's result into shared state
- CONSENSUS: aggregate all collected results and produce final output

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/swarm.py
class SwarmAgent[StateT: AgentState]:
    """Swarm pattern: dispatch to many workers, collect, then reach consensus.

    Notes:
    - The underlying engine executes nodes sequentially; true parallelism isn't
      performed at the graph level. For concurrency, worker/collector nodes can
      internally use BackgroundTaskManager or async to fan-out.
    - This pattern wires a linear broadcast-collect chain ending in CONSENSUS.

    Nodes:
    - optional DISPATCH: prepare/plan the swarm task
    - WORKER_i: a set of worker nodes (Callable or ToolNode)
    - optional COLLECT: consolidate each worker's result into shared state
    - CONSENSUS: aggregate all collected results and produce final output
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        workers: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
        consensus_node: Callable | tuple[Callable, str],
        options: dict | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        resolved_workers = self._add_worker_nodes(workers)
        worker_sequence = resolved_workers

        options = options or {}
        dispatch_node = options.get("dispatch")
        collect_node = options.get("collect")
        followup_condition = options.get("followup_condition")

        dispatch_name = self._resolve_dispatch(dispatch_node)
        collect_info = self._resolve_collect(collect_node)
        consensus_name = self._resolve_consensus(consensus_node)

        entry = dispatch_name or worker_sequence[0]
        self._graph.set_entry_point(entry)
        if dispatch_name:
            self._graph.add_edge(dispatch_name, worker_sequence[0])

        self._wire_edges(worker_sequence, collect_info, consensus_name)

        if followup_condition is None:

            def _cond(_: AgentState) -> str:
                return END

            followup_condition = _cond

        self._graph.add_conditional_edges(
            consensus_name,
            followup_condition,
            {entry: entry, END: END},
        )

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )

    # ---- helpers ----
    def _add_worker_nodes(
        self,
        workers: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
    ) -> list[str]:
        if not workers:
            raise ValueError("workers must be a non-empty dict")

        names: list[str] = []
        for key, fn in workers.items():
            if isinstance(fn, tuple):
                func, name = fn
            else:
                func, name = fn, key
            if not (callable(func) or isinstance(func, ToolNode)):
                raise ValueError(f"Worker '{key}' must be a callable or ToolNode")
            self._graph.add_node(name, func)
            names.append(name)
        return names

    def _resolve_dispatch(self, node: Callable | tuple[Callable, str] | None) -> str | None:
        if not node:
            return None
        if isinstance(node, tuple):
            func, name = node
        else:
            func, name = node, "DISPATCH"
        if not callable(func):
            raise ValueError("dispatch node must be callable")
        self._graph.add_node(name, func)
        return name

    def _resolve_collect(
        self,
        node: Callable | tuple[Callable, str] | None,
    ) -> tuple[Callable, str] | None:
        if not node:
            return None
        if isinstance(node, tuple):
            func, name = node
        else:
            func, name = node, "COLLECT"
        if not callable(func):
            raise ValueError("collect node must be callable")
        # Do not add a single shared collect node to avoid ambiguous routing.
        # We'll create per-worker collect nodes during wiring using this (func, base_name).
        return func, name

    def _resolve_consensus(self, node: Callable | tuple[Callable, str]) -> str:
        if isinstance(node, tuple):
            func, name = node
        else:
            func, name = node, "CONSENSUS"
        if not callable(func):
            raise ValueError("consensus node must be callable")
        self._graph.add_node(name, func)
        return name

    def _wire_edges(
        self,
        worker_sequence: list[str],
        collect_info: tuple[Callable, str] | None,
        consensus_name: str,
    ) -> None:
        for i, wname in enumerate(worker_sequence):
            is_last = i == len(worker_sequence) - 1
            target = consensus_name if is_last else worker_sequence[i + 1]
            if collect_info:
                cfunc, base = collect_info
                cname = f"{base}_{i + 1}"
                # Create a dedicated collect node per worker to prevent loops
                self._graph.add_node(cname, cfunc)
                self._graph.add_edge(wname, cname)
                self._graph.add_edge(cname, target)
            else:
                self._graph.add_edge(wname, target)
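The per-worker collect wiring can be sketched in isolation. The helper below is hypothetical (not part of PyAgenity) and only illustrates the `f"{base}_{i + 1}"` naming scheme that gives each worker its own dedicated collect node:

```python
def collect_node_names(base: str, worker_count: int) -> list[str]:
    # Mirrors _wire_edges: one dedicated collect node per worker,
    # named "<base>_1", "<base>_2", ..., to avoid ambiguous routing.
    return [f"{base}_{i + 1}" for i in range(worker_count)]

print(collect_node_names("COLLECT", 3))  # ['COLLECT_1', 'COLLECT_2', 'COLLECT_3']
```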
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/swarm.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(workers, consensus_node, options=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/swarm.py
def compile(
    self,
    workers: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
    consensus_node: Callable | tuple[Callable, str],
    options: dict | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    resolved_workers = self._add_worker_nodes(workers)
    worker_sequence = resolved_workers

    options = options or {}
    dispatch_node = options.get("dispatch")
    collect_node = options.get("collect")
    followup_condition = options.get("followup_condition")

    dispatch_name = self._resolve_dispatch(dispatch_node)
    collect_info = self._resolve_collect(collect_node)
    consensus_name = self._resolve_consensus(consensus_node)

    entry = dispatch_name or worker_sequence[0]
    self._graph.set_entry_point(entry)
    if dispatch_name:
        self._graph.add_edge(dispatch_name, worker_sequence[0])

    self._wire_edges(worker_sequence, collect_info, consensus_name)

    if followup_condition is None:

        def _cond(_: AgentState) -> str:
            return END

        followup_condition = _cond

    self._graph.add_conditional_edges(
        consensus_name,
        followup_condition,
        {entry: entry, END: END},
    )

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
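A minimal sketch of a custom `followup_condition` for the `options` dict. This is illustrative only: `END` is a local stand-in for PyAgenity's sentinel, `FakeState` stands in for `AgentState`, and the entry name `"DISPATCH"` is assumed; the real condition must return the entry node's name or `END`, matching the `{entry: entry, END: END}` path map:

```python
END = "__end__"  # local stand-in for pyagenity's END sentinel

class FakeState:
    """Minimal stand-in for AgentState carrying a round counter."""
    def __init__(self, rounds: int) -> None:
        self.rounds = rounds

def followup_condition(state: FakeState) -> str:
    # Loop back to the entry node (assumed here to be "DISPATCH")
    # for up to two consensus rounds, then terminate.
    return "DISPATCH" if state.rounds < 2 else END

print(followup_condition(FakeState(1)))  # DISPATCH
print(followup_condition(FakeState(2)))  # __end__
```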
Modules
branch_join

Classes:

Name Description
BranchJoinAgent

Execute multiple branches then join.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
BranchJoinAgent

Execute multiple branches then join.

Note: This prebuilt models branches sequentially (not true parallel execution). For each provided branch node, we add edges branch_i -> JOIN. The JOIN node decides whether more branches remain or END. A more advanced version could use BackgroundTaskManager for concurrency.

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/branch_join.py
class BranchJoinAgent[StateT: AgentState]:
    """Execute multiple branches then join.

    Note: This prebuilt models branches sequentially (not true parallel execution).
    For each provided branch node, we add edges branch_i -> JOIN. The JOIN node
    decides whether more branches remain or END. A more advanced version could
    use BackgroundTaskManager for concurrency.
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        branches: dict[str, Callable | tuple[Callable, str]],
        join_node: Callable | tuple[Callable, str],
        next_branch_condition: Callable | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        if not branches:
            raise ValueError("branches must be a non-empty dict of name -> callable/tuple")

        # Add branch nodes
        branch_names = []
        for key, fn in branches.items():
            if isinstance(fn, tuple):
                branch_func, branch_name = fn
                if not callable(branch_func):
                    raise ValueError(f"Branch '{key}'[0] must be callable")
            else:
                branch_func = fn
                branch_name = key
                if not callable(branch_func):
                    raise ValueError(f"Branch '{key}' must be callable")
            self._graph.add_node(branch_name, branch_func)
            branch_names.append(branch_name)

        # Handle join_node
        if isinstance(join_node, tuple):
            join_func, join_name = join_node
            if not callable(join_func):
                raise ValueError("join_node[0] must be callable")
        else:
            join_func = join_node
            join_name = "JOIN"
            if not callable(join_func):
                raise ValueError("join_node must be callable")
        self._graph.add_node(join_name, join_func)

        # Wire branches to JOIN
        for name in branch_names:
            self._graph.add_edge(name, join_name)

        # Entry: first branch
        first = branch_names[0]
        self._graph.set_entry_point(first)

        # Decide next branch or END after join
        if next_branch_condition is None:
            # default: END after join
            def _cond(_: AgentState) -> str:
                return END

            next_branch_condition = _cond

        # next_branch_condition returns a branch name or END
        path_map = {k: k for k in branch_names}
        path_map[END] = END
        self._graph.add_conditional_edges(join_name, next_branch_condition, path_map)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/branch_join.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(branches, join_node, next_branch_condition=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/branch_join.py
def compile(
    self,
    branches: dict[str, Callable | tuple[Callable, str]],
    join_node: Callable | tuple[Callable, str],
    next_branch_condition: Callable | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    if not branches:
        raise ValueError("branches must be a non-empty dict of name -> callable/tuple")

    # Add branch nodes
    branch_names = []
    for key, fn in branches.items():
        if isinstance(fn, tuple):
            branch_func, branch_name = fn
            if not callable(branch_func):
                raise ValueError(f"Branch '{key}'[0] must be callable")
        else:
            branch_func = fn
            branch_name = key
            if not callable(branch_func):
                raise ValueError(f"Branch '{key}' must be callable")
        self._graph.add_node(branch_name, branch_func)
        branch_names.append(branch_name)

    # Handle join_node
    if isinstance(join_node, tuple):
        join_func, join_name = join_node
        if not callable(join_func):
            raise ValueError("join_node[0] must be callable")
    else:
        join_func = join_node
        join_name = "JOIN"
        if not callable(join_func):
            raise ValueError("join_node must be callable")
    self._graph.add_node(join_name, join_func)

    # Wire branches to JOIN
    for name in branch_names:
        self._graph.add_edge(name, join_name)

    # Entry: first branch
    first = branch_names[0]
    self._graph.set_entry_point(first)

    # Decide next branch or END after join
    if next_branch_condition is None:
        # default: END after join
        def _cond(_: AgentState) -> str:
            return END

        next_branch_condition = _cond

    # next_branch_condition returns a branch name or END
    path_map = {k: k for k in branch_names}
    path_map[END] = END
    self._graph.add_conditional_edges(join_name, next_branch_condition, path_map)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
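Since `next_branch_condition` must return a branch name or `END` (the path map is built from the branch names plus `END`), a round-robin condition can be sketched as follows. Everything here is illustrative: the branch names, the local `END` stand-in, and the fake state are assumptions, not PyAgenity API:

```python
END = "__end__"  # local stand-in for pyagenity's END sentinel

BRANCH_NAMES = ["fetch", "summarize", "verify"]  # hypothetical branches

class FakeState:
    """Minimal stand-in for AgentState tracking completed branches."""
    def __init__(self, completed: set[str]) -> None:
        self.completed = completed

def next_branch_condition(state: FakeState) -> str:
    # Visit each remaining branch in order; END once all are done.
    for name in BRANCH_NAMES:
        if name not in state.completed:
            return name
    return END

print(next_branch_condition(FakeState({"fetch"})))          # summarize
print(next_branch_condition(FakeState(set(BRANCH_NAMES))))  # __end__
```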
deep_research

Classes:

Name Description
DeepResearchAgent

Deep Research Agent: PLAN → RESEARCH → SYNTHESIZE → CRITIQUE loop.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
DeepResearchAgent

Deep Research Agent: PLAN → RESEARCH → SYNTHESIZE → CRITIQUE loop.

This agent mirrors modern deep-research patterns inspired by DeerFlow and Tongyi DeepResearch: plan tasks, use tools to research, synthesize findings, critique gaps and iterate a bounded number of times.

Nodes:
- PLAN: Decompose the problem and propose search/tool actions; may include tool calls
- RESEARCH: ToolNode executes search/browse/calc/etc. tools
- SYNTHESIZE: Aggregate findings and draft a coherent report or partial answer
- CRITIQUE: Identify gaps, contradictions, or follow-ups; can request more tools

Routing:
- PLAN -> conditional(_route_after_plan): {"RESEARCH": RESEARCH, "SYNTHESIZE": SYNTHESIZE, END: END}
- RESEARCH -> SYNTHESIZE
- SYNTHESIZE -> CRITIQUE
- CRITIQUE -> conditional(_route_after_critique): {"RESEARCH": RESEARCH, END: END}

Iteration Control:
- Uses execution_meta.internal_data keys:
    - dr_max_iters (int): maximum critique→research loops (default 2)
    - dr_iters (int): current loop count (auto-updated)
    - dr_heavy_mode (bool): if True, bias towards one more loop when the critique suggests it

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/deep_research.py
class DeepResearchAgent[StateT: AgentState]:
    """Deep Research Agent: PLAN → RESEARCH → SYNTHESIZE → CRITIQUE loop.

    This agent mirrors modern deep-research patterns inspired by DeerFlow and
    Tongyi DeepResearch: plan tasks, use tools to research, synthesize findings,
    critique gaps and iterate a bounded number of times.

    Nodes:
    - PLAN: Decompose problem, propose search/tool actions; may include tool calls
    - RESEARCH: ToolNode executes search/browse/calc/etc tools
    - SYNTHESIZE: Aggregate and draft a coherent report or partial answer
    - CRITIQUE: Identify gaps, contradictions, or follow-ups; can request more tools

    Routing:
    - PLAN -> conditional(_route_after_plan):
        {"RESEARCH": RESEARCH, "SYNTHESIZE": SYNTHESIZE, END: END}
    - RESEARCH -> SYNTHESIZE
    - SYNTHESIZE -> CRITIQUE
    - CRITIQUE -> conditional(_route_after_critique): {"RESEARCH": RESEARCH, END: END}

    Iteration Control:
    - Uses execution_meta.internal_data keys:
        dr_max_iters (int): maximum critique→research loops (default 2)
        dr_iters (int): current loop count (auto-updated)
        dr_heavy_mode (bool): if True, bias towards one more loop when critique suggests
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
        max_iters: int = 2,
        heavy_mode: bool = False,
    ):
        # initialize graph
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )
        # seed default internal config on prototype state
        # Note: These values will be copied to new state at invoke time.
        exec_meta: ExecutionState = self._graph._state.execution_meta
        exec_meta.internal_data.setdefault("dr_max_iters", max(0, int(max_iters)))
        exec_meta.internal_data.setdefault("dr_iters", 0)
        exec_meta.internal_data.setdefault("dr_heavy_mode", bool(heavy_mode))

    def compile(  # noqa: PLR0912
        self,
        plan_node: Callable | tuple[Callable, str],
        research_tool_node: ToolNode | tuple[ToolNode, str],
        synthesize_node: Callable | tuple[Callable, str],
        critique_node: Callable | tuple[Callable, str],
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Handle plan_node
        if isinstance(plan_node, tuple):
            plan_func, plan_name = plan_node
            if not callable(plan_func):
                raise ValueError("plan_node[0] must be callable")
        else:
            plan_func = plan_node
            plan_name = "PLAN"
            if not callable(plan_func):
                raise ValueError("plan_node must be callable")

        # Handle research_tool_node
        if isinstance(research_tool_node, tuple):
            research_func, research_name = research_tool_node
            if not isinstance(research_func, ToolNode):
                raise ValueError("research_tool_node[0] must be a ToolNode")
        else:
            research_func = research_tool_node
            research_name = "RESEARCH"
            if not isinstance(research_func, ToolNode):
                raise ValueError("research_tool_node must be a ToolNode")

        # Handle synthesize_node
        if isinstance(synthesize_node, tuple):
            synthesize_func, synthesize_name = synthesize_node
            if not callable(synthesize_func):
                raise ValueError("synthesize_node[0] must be callable")
        else:
            synthesize_func = synthesize_node
            synthesize_name = "SYNTHESIZE"
            if not callable(synthesize_func):
                raise ValueError("synthesize_node must be callable")

        # Handle critique_node
        if isinstance(critique_node, tuple):
            critique_func, critique_name = critique_node
            if not callable(critique_func):
                raise ValueError("critique_node[0] must be callable")
        else:
            critique_func = critique_node
            critique_name = "CRITIQUE"
            if not callable(critique_func):
                raise ValueError("critique_node must be callable")

        # Add nodes
        self._graph.add_node(plan_name, plan_func)
        self._graph.add_node(research_name, research_func)
        self._graph.add_node(synthesize_name, synthesize_func)
        self._graph.add_node(critique_name, critique_func)

        # Edges
        self._graph.add_conditional_edges(
            plan_name,
            _route_after_plan,
            {research_name: research_name, synthesize_name: synthesize_name, END: END},
        )
        self._graph.add_edge(research_name, synthesize_name)
        self._graph.add_edge(synthesize_name, critique_name)
        self._graph.add_conditional_edges(
            critique_name,
            _route_after_critique,
            {research_name: research_name, END: END},
        )

        # Entry
        self._graph.set_entry_point(plan_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None, max_iters=2, heavy_mode=False)
Source code in pyagenity/prebuilt/agent/deep_research.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
    max_iters: int = 2,
    heavy_mode: bool = False,
):
    # initialize graph
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
    # seed default internal config on prototype state
    # Note: These values will be copied to new state at invoke time.
    exec_meta: ExecutionState = self._graph._state.execution_meta
    exec_meta.internal_data.setdefault("dr_max_iters", max(0, int(max_iters)))
    exec_meta.internal_data.setdefault("dr_iters", 0)
    exec_meta.internal_data.setdefault("dr_heavy_mode", bool(heavy_mode))
compile
compile(plan_node, research_tool_node, synthesize_node, critique_node, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/deep_research.py
def compile(  # noqa: PLR0912
    self,
    plan_node: Callable | tuple[Callable, str],
    research_tool_node: ToolNode | tuple[ToolNode, str],
    synthesize_node: Callable | tuple[Callable, str],
    critique_node: Callable | tuple[Callable, str],
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Handle plan_node
    if isinstance(plan_node, tuple):
        plan_func, plan_name = plan_node
        if not callable(plan_func):
            raise ValueError("plan_node[0] must be callable")
    else:
        plan_func = plan_node
        plan_name = "PLAN"
        if not callable(plan_func):
            raise ValueError("plan_node must be callable")

    # Handle research_tool_node
    if isinstance(research_tool_node, tuple):
        research_func, research_name = research_tool_node
        if not isinstance(research_func, ToolNode):
            raise ValueError("research_tool_node[0] must be a ToolNode")
    else:
        research_func = research_tool_node
        research_name = "RESEARCH"
        if not isinstance(research_func, ToolNode):
            raise ValueError("research_tool_node must be a ToolNode")

    # Handle synthesize_node
    if isinstance(synthesize_node, tuple):
        synthesize_func, synthesize_name = synthesize_node
        if not callable(synthesize_func):
            raise ValueError("synthesize_node[0] must be callable")
    else:
        synthesize_func = synthesize_node
        synthesize_name = "SYNTHESIZE"
        if not callable(synthesize_func):
            raise ValueError("synthesize_node must be callable")

    # Handle critique_node
    if isinstance(critique_node, tuple):
        critique_func, critique_name = critique_node
        if not callable(critique_func):
            raise ValueError("critique_node[0] must be callable")
    else:
        critique_func = critique_node
        critique_name = "CRITIQUE"
        if not callable(critique_func):
            raise ValueError("critique_node must be callable")

    # Add nodes
    self._graph.add_node(plan_name, plan_func)
    self._graph.add_node(research_name, research_func)
    self._graph.add_node(synthesize_name, synthesize_func)
    self._graph.add_node(critique_name, critique_func)

    # Edges
    self._graph.add_conditional_edges(
        plan_name,
        _route_after_plan,
        {research_name: research_name, synthesize_name: synthesize_name, END: END},
    )
    self._graph.add_edge(research_name, synthesize_name)
    self._graph.add_edge(synthesize_name, critique_name)
    self._graph.add_conditional_edges(
        critique_name,
        _route_after_critique,
        {research_name: research_name, END: END},
    )

    # Entry
    self._graph.set_entry_point(plan_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
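The documented iteration-control keys can be exercised with a routing sketch like the one below. This is not the library's internal `_route_after_critique`; it is a hypothetical reimplementation of the documented contract (`dr_iters` vs. `dr_max_iters`, with `dr_heavy_mode` assumed to grant one extra loop):

```python
END = "__end__"  # local stand-in for pyagenity's END sentinel

def route_after_critique(internal_data: dict) -> str:
    # Continue the critique -> research loop while the iteration budget
    # allows; dr_heavy_mode grants one extra loop on top of dr_max_iters.
    iters = internal_data.get("dr_iters", 0)
    budget = internal_data.get("dr_max_iters", 2)
    if internal_data.get("dr_heavy_mode", False):
        budget += 1
    return "RESEARCH" if iters < budget else END

print(route_after_critique({"dr_iters": 1, "dr_max_iters": 2}))  # RESEARCH
print(route_after_critique({"dr_iters": 2, "dr_max_iters": 2}))  # __end__
```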
guarded

Classes:

Name Description
GuardedAgent

Validate output and repair until valid or attempts exhausted.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
GuardedAgent

Validate output and repair until valid or attempts exhausted.

Nodes:
- PRODUCE: main generation node
- REPAIR: correction node when validation fails

Edges:
- PRODUCE -> conditional(valid? END : REPAIR)
- REPAIR -> PRODUCE

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/guarded.py
class GuardedAgent[StateT: AgentState]:
    """Validate output and repair until valid or attempts exhausted.

    Nodes:
    - PRODUCE: main generation node
    - REPAIR: correction node when validation fails

    Edges:
    PRODUCE -> conditional(valid? END : REPAIR)
    REPAIR -> PRODUCE
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        produce_node: Callable | tuple[Callable, str],
        repair_node: Callable | tuple[Callable, str],
        validator: Callable[[AgentState], bool],
        max_attempts: int = 2,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Handle produce_node
        if isinstance(produce_node, tuple):
            produce_func, produce_name = produce_node
            if not callable(produce_func):
                raise ValueError("produce_node[0] must be callable")
        else:
            produce_func = produce_node
            produce_name = "PRODUCE"
            if not callable(produce_func):
                raise ValueError("produce_node must be callable")

        # Handle repair_node
        if isinstance(repair_node, tuple):
            repair_func, repair_name = repair_node
            if not callable(repair_func):
                raise ValueError("repair_node[0] must be callable")
        else:
            repair_func = repair_node
            repair_name = "REPAIR"
            if not callable(repair_func):
                raise ValueError("repair_node must be callable")

        self._graph.add_node(produce_name, produce_func)
        self._graph.add_node(repair_name, repair_func)

        # produce -> END or REPAIR
        condition = _guard_condition_factory(validator, max_attempts)
        self._graph.add_conditional_edges(
            produce_name,
            condition,
            {repair_name: repair_name, END: END},
        )
        # repair -> produce
        self._graph.add_edge(repair_name, produce_name)

        self._graph.set_entry_point(produce_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/guarded.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(produce_node, repair_node, validator, max_attempts=2, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/guarded.py
def compile(
    self,
    produce_node: Callable | tuple[Callable, str],
    repair_node: Callable | tuple[Callable, str],
    validator: Callable[[AgentState], bool],
    max_attempts: int = 2,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Handle produce_node
    if isinstance(produce_node, tuple):
        produce_func, produce_name = produce_node
        if not callable(produce_func):
            raise ValueError("produce_node[0] must be callable")
    else:
        produce_func = produce_node
        produce_name = "PRODUCE"
        if not callable(produce_func):
            raise ValueError("produce_node must be callable")

    # Handle repair_node
    if isinstance(repair_node, tuple):
        repair_func, repair_name = repair_node
        if not callable(repair_func):
            raise ValueError("repair_node[0] must be callable")
    else:
        repair_func = repair_node
        repair_name = "REPAIR"
        if not callable(repair_func):
            raise ValueError("repair_node must be callable")

    self._graph.add_node(produce_name, produce_func)
    self._graph.add_node(repair_name, repair_func)

    # produce -> END or REPAIR
    condition = _guard_condition_factory(validator, max_attempts)
    self._graph.add_conditional_edges(
        produce_name,
        condition,
        {repair_name: repair_name, END: END},
    )
    # repair -> produce
    self._graph.add_edge(repair_name, produce_name)

    self._graph.set_entry_point(produce_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
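A `validator` for `GuardedAgent.compile` is just a predicate over the state. The sketch below assumes, hypothetically, that the text to validate is reachable as `state.last_text`; adapt the attribute access to your actual `AgentState` schema:

```python
import json

class FakeState:
    """Minimal stand-in for AgentState exposing the produced text."""
    def __init__(self, last_text: str) -> None:
        self.last_text = last_text

def json_validator(state: FakeState) -> bool:
    # Accept the PRODUCE output only if it parses as JSON; otherwise
    # the conditional edge routes to REPAIR (until max_attempts).
    try:
        json.loads(state.last_text)
        return True
    except (TypeError, ValueError):
        return False

print(json_validator(FakeState('{"ok": true}')))  # True
print(json_validator(FakeState("not json")))      # False
```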
map_reduce

Classes:

Name Description
MapReduceAgent

Map over items then reduce.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
MapReduceAgent

Map over items then reduce.

Nodes:
- SPLIT: optional; prepares per-item tasks (or the state already contains items)
- MAP: processes one item per iteration
- REDUCE: aggregates results and decides END or continue

Compile requires:
- map_node: Callable | ToolNode
- reduce_node: Callable
- split_node: Callable | None
- condition: Callable[[AgentState], str]; returns "MAP" to continue or END

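The `condition` callable decides whether to loop MAP again. A sketch, with a local `END` stand-in and a fake state holding the remaining items (both hypothetical):

```python
END = "__end__"  # local stand-in for pyagenity's END sentinel

class FakeState:
    """Minimal stand-in for AgentState holding unprocessed items."""
    def __init__(self, pending: list) -> None:
        self.pending = pending

def map_condition(state: FakeState) -> str:
    # Keep mapping while items remain; REDUCE's conditional edge then
    # routes back to "MAP" or terminates at END.
    return "MAP" if state.pending else END

print(map_condition(FakeState(["a", "b"])))  # MAP
print(map_condition(FakeState([])))          # __end__
```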
Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/map_reduce.py
class MapReduceAgent[StateT: AgentState]:
    """Map over items then reduce.

    Nodes:
    - SPLIT: optional, prepares per-item tasks (or state already contains items)
    - MAP: processes one item per iteration
    - REDUCE: aggregates results and decides END or continue

    Compile requires:
      map_node: Callable|ToolNode
      reduce_node: Callable
      split_node: Callable | None
      condition: Callable[[AgentState], str] returns "MAP" to continue or END
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(  # noqa: PLR0912
        self,
        map_node: Callable | ToolNode | tuple[Callable | ToolNode, str],
        reduce_node: Callable | tuple[Callable, str],
        split_node: Callable | tuple[Callable, str] | None = None,
        condition: Callable[[AgentState], str] | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Handle split_node
        split_name = "SPLIT"
        if split_node:
            if isinstance(split_node, tuple):
                split_func, split_name = split_node
                if not callable(split_func):
                    raise ValueError("split_node[0] must be callable")
            else:
                split_func = split_node
                split_name = "SPLIT"
                if not callable(split_func):
                    raise ValueError("split_node must be callable")
            self._graph.add_node(split_name, split_func)

        # Handle map_node
        if isinstance(map_node, tuple):
            map_func, map_name = map_node
            if not (callable(map_func) or isinstance(map_func, ToolNode)):
                raise ValueError("map_node[0] must be callable or ToolNode")
        else:
            map_func = map_node
            map_name = "MAP"
            if not (callable(map_func) or isinstance(map_func, ToolNode)):
                raise ValueError("map_node must be callable or ToolNode")
        self._graph.add_node(map_name, map_func)

        # Handle reduce_node
        if isinstance(reduce_node, tuple):
            reduce_func, reduce_name = reduce_node
            if not callable(reduce_func):
                raise ValueError("reduce_node[0] must be callable")
        else:
            reduce_func = reduce_node
            reduce_name = "REDUCE"
            if not callable(reduce_func):
                raise ValueError("reduce_node must be callable")
        self._graph.add_node(reduce_name, reduce_func)

        # Edges
        if split_node:
            self._graph.add_edge(split_name, map_name)
            self._graph.set_entry_point(split_name)
        else:
            self._graph.set_entry_point(map_name)

        self._graph.add_edge(map_name, reduce_name)

        # Continue mapping or finish
        if condition is None:
            # default: finish after one map-reduce
            def _cond(_: AgentState) -> str:
                return END

            condition = _cond

        self._graph.add_conditional_edges(reduce_name, condition, {map_name: map_name, END: END})

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/map_reduce.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(map_node, reduce_node, split_node=None, condition=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/map_reduce.py
def compile(  # noqa: PLR0912
    self,
    map_node: Callable | ToolNode | tuple[Callable | ToolNode, str],
    reduce_node: Callable | tuple[Callable, str],
    split_node: Callable | tuple[Callable, str] | None = None,
    condition: Callable[[AgentState], str] | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Handle split_node
    split_name = "SPLIT"
    if split_node:
        if isinstance(split_node, tuple):
            split_func, split_name = split_node
            if not callable(split_func):
                raise ValueError("split_node[0] must be callable")
        else:
            split_func = split_node
            split_name = "SPLIT"
            if not callable(split_func):
                raise ValueError("split_node must be callable")
        self._graph.add_node(split_name, split_func)

    # Handle map_node
    if isinstance(map_node, tuple):
        map_func, map_name = map_node
        if not (callable(map_func) or isinstance(map_func, ToolNode)):
            raise ValueError("map_node[0] must be callable or ToolNode")
    else:
        map_func = map_node
        map_name = "MAP"
        if not (callable(map_func) or isinstance(map_func, ToolNode)):
            raise ValueError("map_node must be callable or ToolNode")
    self._graph.add_node(map_name, map_func)

    # Handle reduce_node
    if isinstance(reduce_node, tuple):
        reduce_func, reduce_name = reduce_node
        if not callable(reduce_func):
            raise ValueError("reduce_node[0] must be callable")
    else:
        reduce_func = reduce_node
        reduce_name = "REDUCE"
        if not callable(reduce_func):
            raise ValueError("reduce_node must be callable")
    self._graph.add_node(reduce_name, reduce_func)

    # Edges
    if split_node:
        self._graph.add_edge(split_name, map_name)
        self._graph.set_entry_point(split_name)
    else:
        self._graph.set_entry_point(map_name)

    self._graph.add_edge(map_name, reduce_name)

    # Continue mapping or finish
    if condition is None:
        # default: finish after one map-reduce
        def _cond(_: AgentState) -> str:
            return END

        condition = _cond

    self._graph.add_conditional_edges(reduce_name, condition, {map_name: map_name, END: END})

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
network

Classes:

Name Description
NetworkAgent

Network pattern: define arbitrary node set and routing policies.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
NetworkAgent

Network pattern: define arbitrary node set and routing policies.

  • Nodes can be callables or ToolNode.
  • Edges can be static or conditional via a router function per node.
  • Entry point is explicit.

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/network.py
class NetworkAgent[StateT: AgentState]:
    """Network pattern: define arbitrary node set and routing policies.

    - Nodes can be callables or ToolNode.
    - Edges can be static or conditional via a router function per node.
    - Entry point is explicit.
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        nodes: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
        entry: str,
        static_edges: list[tuple[str, str]] | None = None,
        conditional_edges: list[tuple[str, Callable[[AgentState], str], dict[str, str]]]
        | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        if not nodes:
            raise ValueError("nodes must be a non-empty dict")

        # Add nodes
        for key, fn in nodes.items():
            if isinstance(fn, tuple):
                func, name = fn
            else:
                func, name = fn, key
            if not (callable(func) or isinstance(func, ToolNode)):
                raise ValueError(f"Node '{key}' must be a callable or ToolNode")
            self._graph.add_node(name, func)

        if entry not in self._graph.nodes:
            raise ValueError(f"entry node '{entry}' must be present in nodes")

        # Static edges
        for src, dst in static_edges or []:
            if src not in self._graph.nodes or dst not in self._graph.nodes:
                raise ValueError(f"Invalid static edge {src}->{dst}: unknown node")
            self._graph.add_edge(src, dst)

        # Conditional edges
        for src, cond, pmap in conditional_edges or []:
            if src not in self._graph.nodes:
                raise ValueError(f"Invalid conditional edge: unknown node '{src}'")
            self._graph.add_conditional_edges(src, cond, pmap)

        # Note: callers may include END in path maps; not enforced here.

        self._graph.set_entry_point(entry)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/network.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(nodes, entry, static_edges=None, conditional_edges=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/network.py
def compile(
    self,
    nodes: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
    entry: str,
    static_edges: list[tuple[str, str]] | None = None,
    conditional_edges: list[tuple[str, Callable[[AgentState], str], dict[str, str]]]
    | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    if not nodes:
        raise ValueError("nodes must be a non-empty dict")

    # Add nodes
    for key, fn in nodes.items():
        if isinstance(fn, tuple):
            func, name = fn
        else:
            func, name = fn, key
        if not (callable(func) or isinstance(func, ToolNode)):
            raise ValueError(f"Node '{key}' must be a callable or ToolNode")
        self._graph.add_node(name, func)

    if entry not in self._graph.nodes:
        raise ValueError(f"entry node '{entry}' must be present in nodes")

    # Static edges
    for src, dst in static_edges or []:
        if src not in self._graph.nodes or dst not in self._graph.nodes:
            raise ValueError(f"Invalid static edge {src}->{dst}: unknown node")
        self._graph.add_edge(src, dst)

    # Conditional edges
    for src, cond, pmap in conditional_edges or []:
        if src not in self._graph.nodes:
            raise ValueError(f"Invalid conditional edge: unknown node '{src}'")
        self._graph.add_conditional_edges(src, cond, pmap)

    # Note: callers may include END in path maps; not enforced here.

    self._graph.set_entry_point(entry)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
plan_act_reflect

Classes:

Name Description
PlanActReflectAgent

Plan -> Act -> Reflect looping agent.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
PlanActReflectAgent

Plan -> Act -> Reflect looping agent.

Pattern

PLAN -> (condition) -> ACT | REFLECT | END
ACT -> REFLECT
REFLECT -> PLAN

Default condition (_should_act):

  • If the last assistant message contains tool calls -> ACT
  • If the last message is from a tool -> REFLECT
  • Otherwise -> END

Provide a custom condition to override this heuristic and implement:
  • Budget / depth limiting
  • Confidence-based early stop
  • Dynamic branch selection (e.g., different tool nodes)

Parameters (constructor):

  • state: Optional initial state instance
  • context_manager: Custom context manager
  • publisher: Optional publisher for streaming / events
  • id_generator: ID generation strategy
  • container: InjectQ DI container

compile(...) arguments:

  • plan_node: Callable (state -> state). Produces the next thought / tool requests
  • tool_node: ToolNode executing declared tools
  • reflect_node: Callable (state -> state). Consumes tool results and may adjust the plan
  • condition: Optional Callable[[AgentState], str] returning the next node name or END
  • checkpointer / store / interrupt_before / interrupt_after / callback_manager: standard graph compilation options

Returns:

Type Description

CompiledGraph ready for invoke / ainvoke.

Notes
  • Node names can be customized via (callable, "NAME") tuples.
  • condition must return one of: tool_node_name, reflect_node_name, END.

Methods:

Name Description
__init__
compile

Compile the Plan-Act-Reflect loop.

Source code in pyagenity/prebuilt/agent/plan_act_reflect.py
class PlanActReflectAgent[StateT: AgentState]:
    """Plan -> Act -> Reflect looping agent.

    Pattern:
        PLAN -> (condition) -> ACT | REFLECT | END
        ACT -> REFLECT
        REFLECT -> PLAN

    Default condition (_should_act):
        - If last assistant message contains tool calls -> ACT
        - If last message is from a tool -> REFLECT
        - Else -> END

    Provide a custom condition to override this heuristic and implement:
        * Budget / depth limiting
        * Confidence-based early stop
        * Dynamic branch selection (e.g., different tool nodes)

    Parameters (constructor):
        state: Optional initial state instance
        context_manager: Custom context manager
        publisher: Optional publisher for streaming / events
        id_generator: ID generation strategy
        container: InjectQ DI container

    compile(...) arguments:
        plan_node: Callable (state -> state). Produces next thought / tool requests
        tool_node: ToolNode executing declared tools
        reflect_node: Callable (state -> state). Consumes tool results & may adjust plan
        condition: Optional Callable[[AgentState], str] returning next node name or END
        checkpointer/store/interrupt_before/interrupt_after/callback_manager:
            Standard graph compilation options

    Returns:
        CompiledGraph ready for invoke / ainvoke.

    Notes:
        - Node names can be customized via (callable, "NAME") tuples.
        - condition must return one of: tool_node_name, reflect_node_name, END.
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        plan_node: Callable | tuple[Callable, str],
        tool_node: ToolNode | tuple[ToolNode, str],
        reflect_node: Callable | tuple[Callable, str],
        *,
        condition: Callable[[AgentState], str] | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        """Compile the Plan-Act-Reflect loop.

        Args:
            plan_node: Callable or (callable, name)
            tool_node: ToolNode or (ToolNode, name)
            reflect_node: Callable or (callable, name)
            condition: Optional decision function. Defaults to internal heuristic.
            checkpointer/store/interrupt_* / callback_manager: Standard graph options.

        Returns:
            CompiledGraph
        """
        # PLAN
        if isinstance(plan_node, tuple):
            plan_func, plan_name = plan_node
            if not callable(plan_func):
                raise ValueError("plan_node[0] must be callable")
        else:
            plan_func = plan_node
            plan_name = "PLAN"
            if not callable(plan_func):
                raise ValueError("plan_node must be callable")

        # ACT
        if isinstance(tool_node, tuple):
            tool_func, tool_name = tool_node
            if not isinstance(tool_func, ToolNode):
                raise ValueError("tool_node[0] must be a ToolNode")
        else:
            tool_func = tool_node
            tool_name = "ACT"
            if not isinstance(tool_func, ToolNode):
                raise ValueError("tool_node must be a ToolNode")

        # REFLECT
        if isinstance(reflect_node, tuple):
            reflect_func, reflect_name = reflect_node
            if not callable(reflect_func):
                raise ValueError("reflect_node[0] must be callable")
        else:
            reflect_func = reflect_node
            reflect_name = "REFLECT"
            if not callable(reflect_func):
                raise ValueError("reflect_node must be callable")

        # Register nodes
        self._graph.add_node(plan_name, plan_func)
        self._graph.add_node(tool_name, tool_func)
        self._graph.add_node(reflect_name, reflect_func)

        # Decision
        decision_fn = condition or _should_act
        self._graph.add_conditional_edges(
            plan_name,
            decision_fn,
            {tool_name: tool_name, reflect_name: reflect_name, END: END},
        )

        # Loop edges
        self._graph.add_edge(tool_name, reflect_name)
        self._graph.add_edge(reflect_name, plan_name)

        # Entry
        self._graph.set_entry_point(plan_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/plan_act_reflect.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(plan_node, tool_node, reflect_node, *, condition=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())

Compile the Plan-Act-Reflect loop.

Parameters:

Name Type Description Default
plan_node Callable | tuple[Callable, str]

Callable or (callable, name)

required
tool_node ToolNode | tuple[ToolNode, str]

ToolNode or (ToolNode, name)

required
reflect_node Callable | tuple[Callable, str]

Callable or (callable, name)

required
condition Callable[[AgentState], str] | None

Optional decision function. Defaults to internal heuristic.

None
checkpointer/store/interrupt_* / callback_manager

Standard graph options.

required

Returns:

Type Description
CompiledGraph

CompiledGraph

Source code in pyagenity/prebuilt/agent/plan_act_reflect.py
def compile(
    self,
    plan_node: Callable | tuple[Callable, str],
    tool_node: ToolNode | tuple[ToolNode, str],
    reflect_node: Callable | tuple[Callable, str],
    *,
    condition: Callable[[AgentState], str] | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    """Compile the Plan-Act-Reflect loop.

    Args:
        plan_node: Callable or (callable, name)
        tool_node: ToolNode or (ToolNode, name)
        reflect_node: Callable or (callable, name)
        condition: Optional decision function. Defaults to internal heuristic.
        checkpointer/store/interrupt_* / callback_manager: Standard graph options.

    Returns:
        CompiledGraph
    """
    # PLAN
    if isinstance(plan_node, tuple):
        plan_func, plan_name = plan_node
        if not callable(plan_func):
            raise ValueError("plan_node[0] must be callable")
    else:
        plan_func = plan_node
        plan_name = "PLAN"
        if not callable(plan_func):
            raise ValueError("plan_node must be callable")

    # ACT
    if isinstance(tool_node, tuple):
        tool_func, tool_name = tool_node
        if not isinstance(tool_func, ToolNode):
            raise ValueError("tool_node[0] must be a ToolNode")
    else:
        tool_func = tool_node
        tool_name = "ACT"
        if not isinstance(tool_func, ToolNode):
            raise ValueError("tool_node must be a ToolNode")

    # REFLECT
    if isinstance(reflect_node, tuple):
        reflect_func, reflect_name = reflect_node
        if not callable(reflect_func):
            raise ValueError("reflect_node[0] must be callable")
    else:
        reflect_func = reflect_node
        reflect_name = "REFLECT"
        if not callable(reflect_func):
            raise ValueError("reflect_node must be callable")

    # Register nodes
    self._graph.add_node(plan_name, plan_func)
    self._graph.add_node(tool_name, tool_func)
    self._graph.add_node(reflect_name, reflect_func)

    # Decision
    decision_fn = condition or _should_act
    self._graph.add_conditional_edges(
        plan_name,
        decision_fn,
        {tool_name: tool_name, reflect_name: reflect_name, END: END},
    )

    # Loop edges
    self._graph.add_edge(tool_name, reflect_name)
    self._graph.add_edge(reflect_name, plan_name)

    # Entry
    self._graph.set_entry_point(plan_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
rag

Classes:

Name Description
RAGAgent

Simple RAG: retrieve -> synthesize; optional follow-up.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
RAGAgent

Simple RAG: retrieve -> synthesize; optional follow-up.

Nodes:

  • RETRIEVE: uses a retriever (callable or ToolNode) to fetch context
  • SYNTHESIZE: LLM/composer builds an answer
  • Optional condition: loop back to RETRIEVE for follow-up queries; else END
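The optional follow-up condition is a plain callable evaluated after SYNTHESIZE. A minimal sketch, assuming a hypothetical `needs_more_context` flag on a custom state class:

```python
# Hypothetical `followup_condition` for RAGAgent.compile(): loop back to the
# retriever while the synthesized answer still flags missing context.
# `state.needs_more_context` is an assumed attribute, not a framework field.
END = "END"  # stand-in for pyagenity's END constant


def followup(state) -> str:
    return "RETRIEVE" if state.needs_more_context else END
```

If no condition is supplied, `compile()` installs a default that returns END after the first synthesis.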

Methods:

Name Description
__init__
compile
compile_advanced

Advanced RAG wiring with hybrid retrieval and optional stages.

Source code in pyagenity/prebuilt/agent/rag.py
class RAGAgent[StateT: AgentState]:
    """Simple RAG: retrieve -> synthesize; optional follow-up.

    Nodes:
    - RETRIEVE: uses a retriever (callable or ToolNode) to fetch context
    - SYNTHESIZE: LLM/composer builds an answer
    - Optional condition: loop back to RETRIEVE for follow-up queries; else END
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        retriever_node: Callable | ToolNode | tuple[Callable | ToolNode, str],
        synthesize_node: Callable | tuple[Callable, str],
        followup_condition: Callable[[AgentState], str] | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Nodes
        # Handle retriever_node
        if isinstance(retriever_node, tuple):
            retriever_func, retriever_name = retriever_node
            if not (callable(retriever_func) or isinstance(retriever_func, ToolNode)):
                raise ValueError("retriever_node[0] must be callable or ToolNode")
        else:
            retriever_func = retriever_node
            retriever_name = "RETRIEVE"
            if not (callable(retriever_func) or isinstance(retriever_func, ToolNode)):
                raise ValueError("retriever_node must be callable or ToolNode")

        # Handle synthesize_node
        if isinstance(synthesize_node, tuple):
            synthesize_func, synthesize_name = synthesize_node
            if not callable(synthesize_func):
                raise ValueError("synthesize_node[0] must be callable")
        else:
            synthesize_func = synthesize_node
            synthesize_name = "SYNTHESIZE"
            if not callable(synthesize_func):
                raise ValueError("synthesize_node must be callable")

        self._graph.add_node(retriever_name, retriever_func)  # type: ignore[arg-type]
        self._graph.add_node(synthesize_name, synthesize_func)

        # Edges
        self._graph.add_edge(retriever_name, synthesize_name)
        self._graph.set_entry_point(retriever_name)

        if followup_condition is None:
            # default: END after synthesize
            def _cond(_: AgentState) -> str:
                return END

            followup_condition = _cond

        self._graph.add_conditional_edges(
            synthesize_name,
            followup_condition,
            {retriever_name: retriever_name, END: END},
        )

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )

    def compile_advanced(
        self,
        retriever_nodes: list[Callable | ToolNode | tuple[Callable | ToolNode, str]],
        synthesize_node: Callable | tuple[Callable, str],
        options: dict | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        """Advanced RAG wiring with hybrid retrieval and optional stages.

        Chain:
          (QUERY_PLAN?) -> R1 -> (MERGE?) -> R2 -> (MERGE?) -> ...
          -> (RERANK?) -> (COMPRESS?) -> SYNTHESIZE -> cond
        Each retriever may be a different modality (sparse, dense, self-query, MMR, etc.).
        """

        options = options or {}
        query_plan_node = options.get("query_plan")
        merger_node = options.get("merge")
        rerank_node = options.get("rerank")
        compress_node = options.get("compress")
        followup_condition = options.get("followup_condition")

        qname = self._add_optional_node(
            query_plan_node,
            default_name="QUERY_PLAN",
            label="query_plan",
        )

        # Add retrievers
        r_names = self._add_retriever_nodes(retriever_nodes)

        # Optional stages
        mname = self._add_optional_node(merger_node, default_name="MERGE", label="merge")
        rrname = self._add_optional_node(rerank_node, default_name="RERANK", label="rerank")
        cname = self._add_optional_node(
            compress_node,
            default_name="COMPRESS",
            label="compress",
        )

        # Synthesize
        sname = self._add_synthesize_node(synthesize_node)

        # Wire edges end-to-end and follow-up
        self._wire_advanced_edges(qname, r_names, mname, rrname, cname, sname, followup_condition)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )

    # ---- helpers ----
    def _add_optional_node(
        self,
        node: Callable | tuple[Callable, str] | None,
        *,
        default_name: str,
        label: str,
    ) -> str | None:
        if not node:
            return None
        if isinstance(node, tuple):
            func, name = node
        else:
            func, name = node, default_name
        if not callable(func):
            raise ValueError(f"{label} node must be callable")
        self._graph.add_node(name, func)
        return name

    def _add_retriever_nodes(
        self,
        retriever_nodes: list[Callable | ToolNode | tuple[Callable | ToolNode, str]],
    ) -> list[str]:
        if not retriever_nodes:
            raise ValueError("retriever_nodes must be non-empty")
        names: list[str] = []
        for idx, rn in enumerate(retriever_nodes):
            if isinstance(rn, tuple):
                rfunc, rname = rn
            else:
                rfunc, rname = rn, f"RETRIEVE_{idx + 1}"
            if not (callable(rfunc) or isinstance(rfunc, ToolNode)):
                raise ValueError("retriever must be callable or ToolNode")
            self._graph.add_node(rname, rfunc)  # type: ignore[arg-type]
            names.append(rname)
        return names

    def _add_synthesize_node(self, synthesize_node: Callable | tuple[Callable, str]) -> str:
        if isinstance(synthesize_node, tuple):
            sfunc, sname = synthesize_node
        else:
            sfunc, sname = synthesize_node, "SYNTHESIZE"
        if not callable(sfunc):
            raise ValueError("synthesize_node must be callable")
        self._graph.add_node(sname, sfunc)
        return sname

    def _wire_advanced_edges(
        self,
        qname: str | None,
        r_names: list[str],
        mname: str | None,
        rrname: str | None,
        cname: str | None,
        sname: str,
        followup_condition: Callable[[AgentState], str] | None = None,
    ) -> None:
        entry = qname or r_names[0]
        self._graph.set_entry_point(entry)
        if qname:
            self._graph.add_edge(qname, r_names[0])

        tail_target = rrname or cname or sname
        for i, rname in enumerate(r_names):
            is_last = i == len(r_names) - 1
            nxt = r_names[i + 1] if not is_last else tail_target
            if mname:
                self._graph.add_edge(rname, mname)
                self._graph.add_edge(mname, nxt)
            else:
                self._graph.add_edge(rname, nxt)

        # Tail wiring
        if rrname and cname:
            self._graph.add_edge(rrname, cname)
            self._graph.add_edge(cname, sname)
        elif rrname:
            self._graph.add_edge(rrname, sname)
        elif cname:
            self._graph.add_edge(cname, sname)

        # default follow-up to END
        if followup_condition is None:

            def _cond(_: AgentState) -> str:
                return END

            followup_condition = _cond

        entry_node = qname or r_names[0]
        path_map = {entry_node: entry_node, END: END}
        self._graph.add_conditional_edges(sname, followup_condition, path_map)
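The edge wiring above can be traced with a small stand-in. The following sketch (a hypothetical `wire_edges` helper, not part of PyAgenity) reproduces the edge list that `_wire_advanced_edges` emits for a given combination of optional stages, so you can see where the retriever chain hands off to the tail.

```python
END = "__end__"  # stand-in for pyagenity's END sentinel


def wire_edges(qname, r_names, mname, rrname, cname, sname):
    """Mirror of the edge wiring in _wire_advanced_edges (sketch only)."""
    edges = []
    if qname:
        edges.append((qname, r_names[0]))
    # Last retriever hands off to rerank, else compress, else synthesize.
    tail_target = rrname or cname or sname
    for i, rname in enumerate(r_names):
        nxt = r_names[i + 1] if i < len(r_names) - 1 else tail_target
        if mname:
            edges.append((rname, mname))
            edges.append((mname, nxt))
        else:
            edges.append((rname, nxt))
    # Tail wiring between optional rerank/compress and synthesize.
    if rrname and cname:
        edges += [(rrname, cname), (cname, sname)]
    elif rrname:
        edges.append((rrname, sname))
    elif cname:
        edges.append((cname, sname))
    return edges


# Two retrievers, rerank enabled, no query plan/merge/compress:
print(wire_edges(None, ["R1", "R2"], None, "RERANK", None, "SYNTH"))
```

Running this shows the chain `R1 -> R2 -> RERANK -> SYNTH`, matching the documented `(QUERY_PLAN?) -> R1 -> ... -> SYNTHESIZE` pipeline.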
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/rag.py

def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(retriever_node, synthesize_node, followup_condition=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/rag.py
def compile(
    self,
    retriever_node: Callable | ToolNode | tuple[Callable | ToolNode, str],
    synthesize_node: Callable | tuple[Callable, str],
    followup_condition: Callable[[AgentState], str] | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Nodes
    # Handle retriever_node
    if isinstance(retriever_node, tuple):
        retriever_func, retriever_name = retriever_node
        if not (callable(retriever_func) or isinstance(retriever_func, ToolNode)):
            raise ValueError("retriever_node[0] must be callable or ToolNode")
    else:
        retriever_func = retriever_node
        retriever_name = "RETRIEVE"
        if not (callable(retriever_func) or isinstance(retriever_func, ToolNode)):
            raise ValueError("retriever_node must be callable or ToolNode")

    # Handle synthesize_node
    if isinstance(synthesize_node, tuple):
        synthesize_func, synthesize_name = synthesize_node
        if not callable(synthesize_func):
            raise ValueError("synthesize_node[0] must be callable")
    else:
        synthesize_func = synthesize_node
        synthesize_name = "SYNTHESIZE"
        if not callable(synthesize_func):
            raise ValueError("synthesize_node must be callable")

    self._graph.add_node(retriever_name, retriever_func)  # type: ignore[arg-type]
    self._graph.add_node(synthesize_name, synthesize_func)

    # Edges
    self._graph.add_edge(retriever_name, synthesize_name)
    self._graph.set_entry_point(retriever_name)

    if followup_condition is None:
        # default: END after synthesize
        def _cond(_: AgentState) -> str:
            return END

        followup_condition = _cond

    self._graph.add_conditional_edges(
        synthesize_name,
        followup_condition,
        {retriever_name: retriever_name, END: END},
    )

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
compile_advanced
compile_advanced(retriever_nodes, synthesize_node, options=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())

Advanced RAG wiring with hybrid retrieval and optional stages.

Chain

(QUERY_PLAN?) -> R1 -> (MERGE?) -> R2 -> (MERGE?) -> ... -> (RERANK?) -> (COMPRESS?) -> SYNTHESIZE -> cond

Each retriever may be a different modality (sparse, dense, self-query, MMR, etc.).

Source code in pyagenity/prebuilt/agent/rag.py
def compile_advanced(
    self,
    retriever_nodes: list[Callable | ToolNode | tuple[Callable | ToolNode, str]],
    synthesize_node: Callable | tuple[Callable, str],
    options: dict | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    """Advanced RAG wiring with hybrid retrieval and optional stages.

    Chain:
      (QUERY_PLAN?) -> R1 -> (MERGE?) -> R2 -> (MERGE?) -> ...
      -> (RERANK?) -> (COMPRESS?) -> SYNTHESIZE -> cond
    Each retriever may be a different modality (sparse, dense, self-query, MMR, etc.).
    """

    options = options or {}
    query_plan_node = options.get("query_plan")
    merger_node = options.get("merge")
    rerank_node = options.get("rerank")
    compress_node = options.get("compress")
    followup_condition = options.get("followup_condition")

    qname = self._add_optional_node(
        query_plan_node,
        default_name="QUERY_PLAN",
        label="query_plan",
    )

    # Add retrievers
    r_names = self._add_retriever_nodes(retriever_nodes)

    # Optional stages
    mname = self._add_optional_node(merger_node, default_name="MERGE", label="merge")
    rrname = self._add_optional_node(rerank_node, default_name="RERANK", label="rerank")
    cname = self._add_optional_node(
        compress_node,
        default_name="COMPRESS",
        label="compress",
    )

    # Synthesize
    sname = self._add_synthesize_node(synthesize_node)

    # Wire edges end-to-end and follow-up
    self._wire_advanced_edges(qname, r_names, mname, rrname, cname, sname, followup_condition)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
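The `options` dict toggles the optional stages by key (`"query_plan"`, `"merge"`, `"rerank"`, `"compress"`, `"followup_condition"`). The following sketch (a hypothetical `pipeline_order` function, not the library API) approximates the node order that results from a given configuration, using the default stage names:

```python
def pipeline_order(n_retrievers: int, options: dict) -> list[str]:
    """Approximate node order compile_advanced wires, given its options."""
    stages = []
    if options.get("query_plan"):
        stages.append("QUERY_PLAN")
    for i in range(n_retrievers):
        stages.append(f"RETRIEVE_{i + 1}")
        if options.get("merge"):
            # Each retriever routes through MERGE before the next stage.
            stages.append("MERGE")
    if options.get("rerank"):
        stages.append("RERANK")
    if options.get("compress"):
        stages.append("COMPRESS")
    stages.append("SYNTHESIZE")
    return stages


print(pipeline_order(2, {"query_plan": True, "rerank": True}))
```

With no options at all, the pipeline collapses to the simple `RETRIEVE -> SYNTHESIZE` chain that `compile` builds.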
react

Classes:

Name Description
ReactAgent

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
ReactAgent

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/react.py
class ReactAgent[StateT: AgentState]:
    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        main_node: tuple[Callable, str] | Callable,
        tool_node: tuple[Callable, str] | Callable,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Determine main node function and name
        if isinstance(main_node, tuple):
            main_func, main_name = main_node
            if not callable(main_func):
                raise ValueError("main_node[0] must be a callable function")
        else:
            main_func = main_node
            main_name = "MAIN"
            if not callable(main_func):
                raise ValueError("main_node must be a callable function")

        # Determine tool node function and name
        if isinstance(tool_node, tuple):
            tool_func, tool_name = tool_node
            # Accept both callable functions and ToolNode instances
            if not callable(tool_func) and not hasattr(tool_func, "invoke"):
                raise ValueError("tool_node[0] must be a callable function or ToolNode")
        else:
            tool_func = tool_node
            tool_name = "TOOL"
            # Accept both callable functions and ToolNode instances
            # ToolNode instances have an 'invoke' method but are not callable
            if not callable(tool_func) and not hasattr(tool_func, "invoke"):
                raise ValueError("tool_node must be a callable function or ToolNode instance")

        self._graph.add_node(main_name, main_func)
        self._graph.add_node(tool_name, tool_func)

        # Now create edges
        self._graph.add_conditional_edges(
            main_name,
            _should_use_tools,
            {tool_name: tool_name, END: END},
        )

        self._graph.add_edge(tool_name, main_name)
        self._graph.set_entry_point(main_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/react.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(main_node, tool_node, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/react.py
def compile(
    self,
    main_node: tuple[Callable, str] | Callable,
    tool_node: tuple[Callable, str] | Callable,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Determine main node function and name
    if isinstance(main_node, tuple):
        main_func, main_name = main_node
        if not callable(main_func):
            raise ValueError("main_node[0] must be a callable function")
    else:
        main_func = main_node
        main_name = "MAIN"
        if not callable(main_func):
            raise ValueError("main_node must be a callable function")

    # Determine tool node function and name
    if isinstance(tool_node, tuple):
        tool_func, tool_name = tool_node
        # Accept both callable functions and ToolNode instances
        if not callable(tool_func) and not hasattr(tool_func, "invoke"):
            raise ValueError("tool_node[0] must be a callable function or ToolNode")
    else:
        tool_func = tool_node
        tool_name = "TOOL"
        # Accept both callable functions and ToolNode instances
        # ToolNode instances have an 'invoke' method but are not callable
        if not callable(tool_func) and not hasattr(tool_func, "invoke"):
            raise ValueError("tool_node must be a callable function or ToolNode instance")

    self._graph.add_node(main_name, main_func)
    self._graph.add_node(tool_name, tool_func)

    # Now create edges
    self._graph.add_conditional_edges(
        main_name,
        _should_use_tools,
        {tool_name: tool_name, END: END},
    )

    self._graph.add_edge(tool_name, main_name)
    self._graph.set_entry_point(main_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
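The wiring above produces the classic ReAct loop: MAIN runs, `_should_use_tools` routes either to TOOL or to END, and TOOL always returns to MAIN. A toy simulation of that control flow (the dict state and `should_use_tools` condition are stand-ins, not the library's internals):

```python
END = "__end__"  # stand-in for pyagenity's END sentinel


def should_use_tools(state: dict) -> str:
    # Stand-in condition: route to TOOL while pending tool calls remain.
    return "TOOL" if state["pending_tool_calls"] > 0 else END


def run_react_loop(state: dict) -> list[str]:
    """Trace node visits for the MAIN <-> TOOL loop until END."""
    visits, node = [], "MAIN"
    while node != END:
        visits.append(node)
        if node == "MAIN":
            node = should_use_tools(state)
        else:  # TOOL executes one call, then the static edge returns to MAIN
            state["pending_tool_calls"] -= 1
            node = "MAIN"
    return visits


print(run_react_loop({"pending_tool_calls": 2}))
```

With two pending tool calls the trace is `MAIN, TOOL, MAIN, TOOL, MAIN`: every tool execution is followed by another MAIN pass that decides whether to continue.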
router

Classes:

Name Description
RouterAgent

A configurable router-style agent.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
RouterAgent

A configurable router-style agent.

Pattern:

- A router node runs (LLM or custom logic) and may update state/messages
- A condition function inspects the state and returns a route key
- Edges route to the matching node; each route returns back to ROUTER
- Return END (via condition) to finish

Usage

router = RouterAgent()
app = router.compile(
    router_node=my_router_func,  # def my_router_func(state, config, ...)
    routes={
        "search": search_node,
        "summarize": summarize_node,
    },
    # Condition inspects state and returns one of the keys above or END
    condition=my_condition,  # def my_condition(state) -> str
    # Optional explicit path map if returned keys differ from node names
    # path_map={"SEARCH": "search", "SUM": "summarize", END: END}
)

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/router.py
class RouterAgent[StateT: AgentState]:
    """A configurable router-style agent.

    Pattern:
    - A router node runs (LLM or custom logic) and may update state/messages
    - A condition function inspects the state and returns a route key
    - Edges route to the matching node; each route returns back to ROUTER
    - Return END (via condition) to finish

    Usage:
        router = RouterAgent()
        app = router.compile(
            router_node=my_router_func,  # def my_router_func(state, config, ...)
            routes={
                "search": search_node,
                "summarize": summarize_node,
            },
            # Condition inspects state and returns one of the keys above or END
            condition=my_condition,  # def my_condition(state) -> str
            # Optional explicit path map if returned keys differ from node names
            # path_map={"SEARCH": "search", "SUM": "summarize", END: END}
        )
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(  # noqa: PLR0912
        self,
        router_node: Callable | tuple[Callable, str],
        routes: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
        condition: Callable[[AgentState], str] | None = None,
        path_map: dict[str, str] | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Handle router_node
        if isinstance(router_node, tuple):
            router_func, router_name = router_node
            if not callable(router_func):
                raise ValueError("router_node[0] must be callable")
        else:
            router_func = router_node
            router_name = "ROUTER"
            if not callable(router_func):
                raise ValueError("router_node must be callable")

        if not routes:
            raise ValueError("routes must be a non-empty dict of name -> callable/ToolNode/tuple")

        # Add route nodes
        route_names = []
        for key, func in routes.items():
            if isinstance(func, tuple):
                route_func, route_name = func
                if not (callable(route_func) or isinstance(route_func, ToolNode)):
                    raise ValueError(f"Route '{key}'[0] must be callable or ToolNode")
            else:
                route_func = func
                route_name = key
                if not (callable(route_func) or isinstance(route_func, ToolNode)):
                    raise ValueError(f"Route '{key}' must be callable or ToolNode")
            self._graph.add_node(route_name, route_func)
            route_names.append(route_name)

        # Add router node as entry
        self._graph.add_node(router_name, router_func)

        # Build default condition/path_map if needed
        if condition is None and len(route_names) == 1:
            only = route_names[0]

            def _always(_: AgentState) -> str:
                return only

            condition = _always
            path_map = {only: only, END: END}

        if condition is None and len(route_names) > 1:
            raise ValueError("condition must be provided when multiple routes are defined")

        # If path_map is not provided, assume router returns exact route names
        if path_map is None:
            path_map = {k: k for k in route_names}
            path_map[END] = END

        # Conditional edges from router node based on condition results
        self._graph.add_conditional_edges(
            router_name,
            condition,  # type: ignore[arg-type]
            path_map,
        )

        # Loop back to router node from each route node
        for name in route_names:
            self._graph.add_edge(name, router_name)

        # Entry
        self._graph.set_entry_point(router_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/router.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(router_node, routes, condition=None, path_map=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/router.py
def compile(  # noqa: PLR0912
    self,
    router_node: Callable | tuple[Callable, str],
    routes: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
    condition: Callable[[AgentState], str] | None = None,
    path_map: dict[str, str] | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Handle router_node
    if isinstance(router_node, tuple):
        router_func, router_name = router_node
        if not callable(router_func):
            raise ValueError("router_node[0] must be callable")
    else:
        router_func = router_node
        router_name = "ROUTER"
        if not callable(router_func):
            raise ValueError("router_node must be callable")

    if not routes:
        raise ValueError("routes must be a non-empty dict of name -> callable/ToolNode/tuple")

    # Add route nodes
    route_names = []
    for key, func in routes.items():
        if isinstance(func, tuple):
            route_func, route_name = func
            if not (callable(route_func) or isinstance(route_func, ToolNode)):
                raise ValueError(f"Route '{key}'[0] must be callable or ToolNode")
        else:
            route_func = func
            route_name = key
            if not (callable(route_func) or isinstance(route_func, ToolNode)):
                raise ValueError(f"Route '{key}' must be callable or ToolNode")
        self._graph.add_node(route_name, route_func)
        route_names.append(route_name)

    # Add router node as entry
    self._graph.add_node(router_name, router_func)

    # Build default condition/path_map if needed
    if condition is None and len(route_names) == 1:
        only = route_names[0]

        def _always(_: AgentState) -> str:
            return only

        condition = _always
        path_map = {only: only, END: END}

    if condition is None and len(route_names) > 1:
        raise ValueError("condition must be provided when multiple routes are defined")

    # If path_map is not provided, assume router returns exact route names
    if path_map is None:
        path_map = {k: k for k in route_names}
        path_map[END] = END

    # Conditional edges from router node based on condition results
    self._graph.add_conditional_edges(
        router_name,
        condition,  # type: ignore[arg-type]
        path_map,
    )

    # Loop back to router node from each route node
    for name in route_names:
        self._graph.add_edge(name, router_name)

    # Entry
    self._graph.set_entry_point(router_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
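When `path_map` is omitted, the router assumes the condition returns exact route node names and builds an identity map plus the END sentinel. A sketch mirroring that default (the `default_path_map` name is illustrative):

```python
END = "__end__"  # stand-in for pyagenity's END sentinel


def default_path_map(route_names: list[str]) -> dict[str, str]:
    """Identity map over route names, plus END -> END, as RouterAgent builds."""
    path_map = {k: k for k in route_names}
    path_map[END] = END
    return path_map


print(default_path_map(["search", "summarize"]))
```

Supply an explicit `path_map` only when your condition returns keys (e.g. `"SEARCH"`) that differ from the node names you registered.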
sequential

Classes:

Name Description
SequentialAgent

A simple sequential agent that executes a fixed pipeline of nodes.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
SequentialAgent

A simple sequential agent that executes a fixed pipeline of nodes.

Pattern:

- Nodes run in the provided order: step1 -> step2 -> ... -> stepN
- After the last step, the graph ends

Usage

seq = SequentialAgent()
app = seq.compile([
    ("ingest", ingest_node),
    ("plan", plan_node),
    ("execute", execute_node),
])

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/sequential.py
class SequentialAgent[StateT: AgentState]:
    """A simple sequential agent that executes a fixed pipeline of nodes.

    Pattern:
    - Nodes run in the provided order: step1 -> step2 -> ... -> stepN
    - After the last step, the graph ends

    Usage:
        seq = SequentialAgent()
        app = seq.compile([
            ("ingest", ingest_node),
            ("plan", plan_node),
            ("execute", execute_node),
        ])
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        steps: Sequence[tuple[str, Callable | ToolNode] | tuple[Callable | ToolNode, str]],
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        if not steps or len(steps) == 0:
            raise ValueError(
                "steps must be a non-empty sequence of (name, callable/ToolNode) "
                "or (callable/ToolNode, name)"
            )

        # Add nodes
        step_names = []
        for step in steps:
            if isinstance(step[0], str):
                name, func = step
            else:
                func, name = step
            if not (callable(func) or isinstance(func, ToolNode)):
                raise ValueError(f"Step '{name}' must be a callable or ToolNode")
            self._graph.add_node(name, func)  # type: ignore[arg-type]
            step_names.append(name)

        # Static edges in order
        for i in range(len(step_names) - 1):
            self._graph.add_edge(step_names[i], step_names[i + 1])

        # Entry is the first step
        self._graph.set_entry_point(step_names[0])

        # No explicit edge to END needed; the engine will end if no outgoing edges remain.
        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
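Each step may be given as `(name, func)` or `(func, name)`; the first element's type decides the order. A sketch of that normalization (the `normalize_step` helper and `ingest` function are illustrative stand-ins):

```python
from collections.abc import Callable


def normalize_step(step) -> tuple[str, Callable]:
    """Accept (name, func) or (func, name), as SequentialAgent.compile does."""
    if isinstance(step[0], str):
        name, func = step
    else:
        func, name = step
    if not callable(func):
        raise ValueError(f"Step '{name}' must be a callable or ToolNode")
    return name, func


def ingest(state):  # placeholder step function
    return state


print(normalize_step(("ingest", ingest))[0])  # name-first form
print(normalize_step((ingest, "ingest"))[0])  # func-first form
```

Both forms resolve to the same `(name, func)` pair, so pipelines can mix them freely.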
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/sequential.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(steps, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/sequential.py
def compile(
    self,
    steps: Sequence[tuple[str, Callable | ToolNode] | tuple[Callable | ToolNode, str]],
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    if not steps:
        raise ValueError(
            "steps must be a non-empty sequence of (name, callable/ToolNode) "
            "or (callable/ToolNode, name)"
        )

    # Add nodes
    step_names = []
    for step in steps:
        if isinstance(step[0], str):
            name, func = step
        else:
            func, name = step
        if not (callable(func) or isinstance(func, ToolNode)):
            raise ValueError(f"Step '{name}' must be a callable or ToolNode")
        self._graph.add_node(name, func)  # type: ignore[arg-type]
        step_names.append(name)

    # Static edges in order
    for i in range(len(step_names) - 1):
        self._graph.add_edge(step_names[i], step_names[i + 1])

    # Entry is the first step
    self._graph.set_entry_point(step_names[0])

    # No explicit edge to END needed; the engine will end if no outgoing edges remain.
    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
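Both tuple orderings are accepted by compile; the normalization branch on isinstance(step[0], str) can be sketched as a standalone helper. This is a hedged stand-in for illustration, not the PyAgenity implementation itself (the ToolNode branch of the check is omitted):

```python
def normalize_steps(steps):
    """Return (name, func) pairs regardless of tuple ordering."""
    normalized = []
    for step in steps:
        if isinstance(step[0], str):
            name, func = step   # (name, func) ordering
        else:
            func, name = step   # (func, name) ordering
        if not callable(func):
            raise ValueError(f"Step '{name}' must be a callable or ToolNode")
        normalized.append((name, func))
    return normalized

def fetch(state):
    return state

def summarize(state):
    return state

pairs = normalize_steps([("FETCH", fetch), (summarize, "SUMMARIZE")])
```

Each normalized name becomes a node, and static edges connect the steps in order, with the first step as the entry point.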
supervisor_team

Classes:

Name Description
SupervisorTeamAgent

Supervisor routes tasks to worker nodes and aggregates results.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
SupervisorTeamAgent

Supervisor routes tasks to worker nodes and aggregates results.

Nodes: - SUPERVISOR: decides which worker to call (by returning a worker key) or END - Multiple WORKER nodes: functions or ToolNode instances - AGGREGATE: optional aggregator node after worker runs; loops back to SUPERVISOR

The compile method requires:

- supervisor_node: Callable
- workers: dict[str, Callable | ToolNode]
- aggregate_node: Callable | None
- condition: Callable[[AgentState], str] that returns a worker key or END

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/supervisor_team.py
class SupervisorTeamAgent[StateT: AgentState]:
    """Supervisor routes tasks to worker nodes and aggregates results.

    Nodes:
    - SUPERVISOR: decides which worker to call (by returning a worker key) or END
    - Multiple WORKER nodes: functions or ToolNode instances
    - AGGREGATE: optional aggregator node after worker runs; loops back to SUPERVISOR

    The compile method requires:
      supervisor_node: Callable
      workers: dict[str, Callable|ToolNode]
      aggregate_node: Callable | None
      condition: Callable[[AgentState], str] returns worker key or END
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(  # noqa: PLR0912
        self,
        supervisor_node: Callable | tuple[Callable, str],
        workers: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
        condition: Callable[[AgentState], str],
        aggregate_node: Callable | tuple[Callable, str] | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Handle supervisor_node
        if isinstance(supervisor_node, tuple):
            supervisor_func, supervisor_name = supervisor_node
            if not callable(supervisor_func):
                raise ValueError("supervisor_node[0] must be callable")
        else:
            supervisor_func = supervisor_node
            supervisor_name = "SUPERVISOR"
            if not callable(supervisor_func):
                raise ValueError("supervisor_node must be callable")

        if not workers:
            raise ValueError("workers must be a non-empty dict")

        # Add worker nodes
        worker_names = []
        for key, fn in workers.items():
            if isinstance(fn, tuple):
                worker_func, worker_name = fn
                if not (callable(worker_func) or isinstance(worker_func, ToolNode)):
                    raise ValueError(f"Worker '{key}'[0] must be callable or ToolNode")
            else:
                worker_func = fn
                worker_name = key
                if not (callable(worker_func) or isinstance(worker_func, ToolNode)):
                    raise ValueError(f"Worker '{key}' must be callable or ToolNode")
            self._graph.add_node(worker_name, worker_func)
            worker_names.append(worker_name)

        # Handle aggregate_node
        aggregate_name = "AGGREGATE"
        if aggregate_node:
            if isinstance(aggregate_node, tuple):
                aggregate_func, aggregate_name = aggregate_node
                if not callable(aggregate_func):
                    raise ValueError("aggregate_node[0] must be callable")
            else:
                aggregate_func = aggregate_node
                aggregate_name = "AGGREGATE"
                if not callable(aggregate_func):
                    raise ValueError("aggregate_node must be callable")
            self._graph.add_node(aggregate_name, aggregate_func)

        # SUPERVISOR decides next worker
        path_map = {k: k for k in worker_names}
        path_map[END] = END
        self._graph.add_conditional_edges(supervisor_name, condition, path_map)

        # After worker, go to AGGREGATE if present, else back to SUPERVISOR
        for name in worker_names:
            self._graph.add_edge(name, aggregate_name if aggregate_node else supervisor_name)

        if aggregate_node:
            self._graph.add_edge(aggregate_name, supervisor_name)

        self._graph.set_entry_point(supervisor_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )
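The condition callable drives the conditional edges out of the supervisor: it receives the current state and returns either a worker key from the path map or END. A hedged sketch of such a condition, using a string stand-in for the END constant (at runtime it comes from pyagenity.utils) and an assumed state shape:

```python
END = "END"  # stand-in; use the END constant from pyagenity.utils at runtime

def route(state) -> str:
    """Supervisor condition: pick the next worker key, or END when done."""
    pending = getattr(state, "pending", [])
    return pending[0] if pending else END

class FakeState:
    """Minimal stand-in for AgentState carrying a pending-task list."""

    def __init__(self, pending):
        self.pending = pending
```

Pass a function like this as the condition argument to compile; the path map built from the worker names routes the returned key to the matching node.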
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/supervisor_team.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(supervisor_node, workers, condition, aggregate_node=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/supervisor_team.py
def compile(  # noqa: PLR0912
    self,
    supervisor_node: Callable | tuple[Callable, str],
    workers: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
    condition: Callable[[AgentState], str],
    aggregate_node: Callable | tuple[Callable, str] | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    # Handle supervisor_node
    if isinstance(supervisor_node, tuple):
        supervisor_func, supervisor_name = supervisor_node
        if not callable(supervisor_func):
            raise ValueError("supervisor_node[0] must be callable")
    else:
        supervisor_func = supervisor_node
        supervisor_name = "SUPERVISOR"
        if not callable(supervisor_func):
            raise ValueError("supervisor_node must be callable")

    if not workers:
        raise ValueError("workers must be a non-empty dict")

    # Add worker nodes
    worker_names = []
    for key, fn in workers.items():
        if isinstance(fn, tuple):
            worker_func, worker_name = fn
            if not (callable(worker_func) or isinstance(worker_func, ToolNode)):
                raise ValueError(f"Worker '{key}'[0] must be callable or ToolNode")
        else:
            worker_func = fn
            worker_name = key
            if not (callable(worker_func) or isinstance(worker_func, ToolNode)):
                raise ValueError(f"Worker '{key}' must be callable or ToolNode")
        self._graph.add_node(worker_name, worker_func)
        worker_names.append(worker_name)

    # Handle aggregate_node
    aggregate_name = "AGGREGATE"
    if aggregate_node:
        if isinstance(aggregate_node, tuple):
            aggregate_func, aggregate_name = aggregate_node
            if not callable(aggregate_func):
                raise ValueError("aggregate_node[0] must be callable")
        else:
            aggregate_func = aggregate_node
            aggregate_name = "AGGREGATE"
            if not callable(aggregate_func):
                raise ValueError("aggregate_node must be callable")
        self._graph.add_node(aggregate_name, aggregate_func)

    # SUPERVISOR decides next worker
    path_map = {k: k for k in worker_names}
    path_map[END] = END
    self._graph.add_conditional_edges(supervisor_name, condition, path_map)

    # After worker, go to AGGREGATE if present, else back to SUPERVISOR
    for name in worker_names:
        self._graph.add_edge(name, aggregate_name if aggregate_node else supervisor_name)

    if aggregate_node:
        self._graph.add_edge(aggregate_name, supervisor_name)

    self._graph.set_entry_point(supervisor_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )
swarm

Classes:

Name Description
SwarmAgent

Swarm pattern: dispatch to many workers, collect, then reach consensus.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound=AgentState)
Classes
SwarmAgent

Swarm pattern: dispatch to many workers, collect, then reach consensus.

Notes: - The underlying engine executes nodes sequentially; true parallelism isn't performed at the graph level. For concurrency, worker/collector nodes can internally use BackgroundTaskManager or async to fan-out. - This pattern wires a linear broadcast-collect chain ending in CONSENSUS.

Nodes: - optional DISPATCH: prepare/plan the swarm task - WORKER_i: a set of worker nodes (Callable or ToolNode) - optional COLLECT: consolidate each worker's result into shared state - CONSENSUS: aggregate all collected results and produce final output
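Since the engine runs nodes sequentially, concurrency has to happen inside a node. A hedged sketch of fanning out with asyncio inside a single worker node (the sleep stands in for an LLM or tool call; names here are illustrative, not PyAgenity API):

```python
import asyncio

async def fan_out_worker(prompts):
    """Run one sub-call per prompt concurrently inside a single node."""

    async def one(prompt):
        await asyncio.sleep(0)  # stand-in for an LLM or tool call
        return f"result:{prompt}"

    # gather preserves input order, so results line up with prompts
    return await asyncio.gather(*(one(p) for p in prompts))

results = asyncio.run(fan_out_worker(["a", "b", "c"]))
```

A collector or consensus node can then merge the ordered results into shared state.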

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/swarm.py
class SwarmAgent[StateT: AgentState]:
    """Swarm pattern: dispatch to many workers, collect, then reach consensus.

    Notes:
    - The underlying engine executes nodes sequentially; true parallelism isn't
      performed at the graph level. For concurrency, worker/collector nodes can
      internally use BackgroundTaskManager or async to fan-out.
    - This pattern wires a linear broadcast-collect chain ending in CONSENSUS.

    Nodes:
    - optional DISPATCH: prepare/plan the swarm task
    - WORKER_i: a set of worker nodes (Callable or ToolNode)
    - optional COLLECT: consolidate each worker's result into shared state
    - CONSENSUS: aggregate all collected results and produce final output
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        workers: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
        consensus_node: Callable | tuple[Callable, str],
        options: dict | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        resolved_workers = self._add_worker_nodes(workers)
        worker_sequence = resolved_workers

        options = options or {}
        dispatch_node = options.get("dispatch")
        collect_node = options.get("collect")
        followup_condition = options.get("followup_condition")

        dispatch_name = self._resolve_dispatch(dispatch_node)
        collect_info = self._resolve_collect(collect_node)
        consensus_name = self._resolve_consensus(consensus_node)

        entry = dispatch_name or worker_sequence[0]
        self._graph.set_entry_point(entry)
        if dispatch_name:
            self._graph.add_edge(dispatch_name, worker_sequence[0])

        self._wire_edges(worker_sequence, collect_info, consensus_name)

        if followup_condition is None:

            def _cond(_: AgentState) -> str:
                return END

            followup_condition = _cond

        self._graph.add_conditional_edges(
            consensus_name,
            followup_condition,
            {entry: entry, END: END},
        )

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )

    # ---- helpers ----
    def _add_worker_nodes(
        self,
        workers: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
    ) -> list[str]:
        if not workers:
            raise ValueError("workers must be a non-empty dict")

        names: list[str] = []
        for key, fn in workers.items():
            if isinstance(fn, tuple):
                func, name = fn
            else:
                func, name = fn, key
            if not (callable(func) or isinstance(func, ToolNode)):
                raise ValueError(f"Worker '{key}' must be a callable or ToolNode")
            self._graph.add_node(name, func)
            names.append(name)
        return names

    def _resolve_dispatch(self, node: Callable | tuple[Callable, str] | None) -> str | None:
        if not node:
            return None
        if isinstance(node, tuple):
            func, name = node
        else:
            func, name = node, "DISPATCH"
        if not callable(func):
            raise ValueError("dispatch node must be callable")
        self._graph.add_node(name, func)
        return name

    def _resolve_collect(
        self,
        node: Callable | tuple[Callable, str] | None,
    ) -> tuple[Callable, str] | None:
        if not node:
            return None
        if isinstance(node, tuple):
            func, name = node
        else:
            func, name = node, "COLLECT"
        if not callable(func):
            raise ValueError("collect node must be callable")
        # Do not add a single shared collect node to avoid ambiguous routing.
        # We'll create per-worker collect nodes during wiring using this (func, base_name).
        return func, name

    def _resolve_consensus(self, node: Callable | tuple[Callable, str]) -> str:
        if isinstance(node, tuple):
            func, name = node
        else:
            func, name = node, "CONSENSUS"
        if not callable(func):
            raise ValueError("consensus node must be callable")
        self._graph.add_node(name, func)
        return name

    def _wire_edges(
        self,
        worker_sequence: list[str],
        collect_info: tuple[Callable, str] | None,
        consensus_name: str,
    ) -> None:
        for i, wname in enumerate(worker_sequence):
            is_last = i == len(worker_sequence) - 1
            target = consensus_name if is_last else worker_sequence[i + 1]
            if collect_info:
                cfunc, base = collect_info
                cname = f"{base}_{i + 1}"
                # Create a dedicated collect node per worker to prevent loops
                self._graph.add_node(cname, cfunc)
                self._graph.add_edge(wname, cname)
                self._graph.add_edge(cname, target)
            else:
                self._graph.add_edge(wname, target)
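The wiring in _wire_edges can be previewed without a StateGraph: compute the edge list for a worker sequence, optionally threading each worker through its own collect node (named f"{base}_{i + 1}" to avoid routing loops). A hedged stand-in mirroring the logic above:

```python
def plan_edges(worker_sequence, collect_base=None, consensus="CONSENSUS"):
    """Return the broadcast-collect edge list ending at the consensus node."""
    edges = []
    for i, wname in enumerate(worker_sequence):
        is_last = i == len(worker_sequence) - 1
        target = consensus if is_last else worker_sequence[i + 1]
        if collect_base:
            cname = f"{collect_base}_{i + 1}"  # dedicated collect node per worker
            edges += [(wname, cname), (cname, target)]
        else:
            edges.append((wname, target))
    return edges

edges = plan_edges(["W1", "W2"], collect_base="COLLECT")
```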
Functions
__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/swarm.py
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator = DefaultIDGenerator(),
    container: InjectQ | None = None,
):
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(workers, consensus_node, options=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/swarm.py
def compile(
    self,
    workers: dict[str, Callable | ToolNode | tuple[Callable | ToolNode, str]],
    consensus_node: Callable | tuple[Callable, str],
    options: dict | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager = CallbackManager(),
) -> CompiledGraph:
    resolved_workers = self._add_worker_nodes(workers)
    worker_sequence = resolved_workers

    options = options or {}
    dispatch_node = options.get("dispatch")
    collect_node = options.get("collect")
    followup_condition = options.get("followup_condition")

    dispatch_name = self._resolve_dispatch(dispatch_node)
    collect_info = self._resolve_collect(collect_node)
    consensus_name = self._resolve_consensus(consensus_node)

    entry = dispatch_name or worker_sequence[0]
    self._graph.set_entry_point(entry)
    if dispatch_name:
        self._graph.add_edge(dispatch_name, worker_sequence[0])

    self._wire_edges(worker_sequence, collect_info, consensus_name)

    if followup_condition is None:

        def _cond(_: AgentState) -> str:
            return END

        followup_condition = _cond

    self._graph.add_conditional_edges(
        consensus_name,
        followup_condition,
        {entry: entry, END: END},
    )

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )

publisher

Publisher module for PyAgenity events.

This module provides publishers that handle the delivery of events to various outputs, such as console, Redis, Kafka, and RabbitMQ. Publishers are primarily used for logging and monitoring agent behavior, enabling real-time tracking of performance, usage, and debugging in agent graphs.

Key components: - BasePublisher: Abstract base class for all publishers, defining the interface for publishing events - ConsolePublisher: Default publisher that outputs events to the console for development and debugging - Optional publishers: RedisPublisher, KafkaPublisher, RabbitMQPublisher, which are available only if their dependencies are installed

Usage: - Import publishers: from pyagenity.publisher import ConsolePublisher, RedisPublisher (if available) - Instantiate and use in CompiledGraph: graph.compile(publisher=ConsolePublisher()). - Events are emitted as EventModel instances during graph execution, including node starts, completions, and errors.

Dependencies for optional publishers: - RedisPublisher: Requires 'redis.asyncio' (install via pip install redis). - KafkaPublisher: Requires 'aiokafka' (install via pip install aiokafka). - RabbitMQPublisher: Requires 'aio_pika' (install via pip install aio-pika).

For more details, see the individual publisher classes and the PyAgenity documentation.
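A custom publisher only needs to honor the publish/close/sync_close interface. A minimal in-memory sketch for illustration — a standalone stand-in that does not inherit the real BasePublisher and takes plain dicts instead of EventModel instances:

```python
import asyncio

class ListPublisher:
    """Collects events in memory -- handy for tests and demos."""

    def __init__(self, config=None):
        self.config = config or {}
        self.events = []

    async def publish(self, event):
        self.events.append(event)
        return event

    async def close(self):
        self.events.clear()

    def sync_close(self):
        self.events.clear()

pub = ListPublisher()
asyncio.run(pub.publish({"event_type": "node_start", "node_name": "MAIN"}))
```

A real subclass would inherit pyagenity.publisher.BasePublisher and be passed to graph compilation the same way as ConsolePublisher.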

Modules:

Name Description
base_publisher
console_publisher

Console publisher implementation for debugging and testing.

events

Event and streaming primitives for agent graph execution.

kafka_publisher

Kafka publisher implementation (optional dependency).

publish
rabbitmq_publisher

RabbitMQ publisher implementation (optional dependency).

redis_publisher

Redis publisher implementation (optional dependency).

Classes:

Name Description
BasePublisher

Abstract base class for event publishers.

ConsolePublisher

Publisher that prints events to the console for debugging and testing.

Attributes

KafkaPublisher module-attribute
KafkaPublisher = None
RabbitMQPublisher module-attribute
RabbitMQPublisher = None
RedisPublisher module-attribute
RedisPublisher = None
__all__ module-attribute
__all__ = ['BasePublisher', 'ConsolePublisher']

Classes

BasePublisher

Bases: ABC

Abstract base class for event publishers.

This class defines the interface for publishing events. Subclasses should implement the publish, close, and sync_close methods to provide specific publishing logic.

Attributes:

Name Type Description
config

Configuration dictionary for the publisher.

Methods:

Name Description
__init__

Initialize the publisher with the given configuration.

close

Close the publisher and release any resources.

publish

Publish an event.

sync_close

Close the publisher and release any resources (synchronous version).

Source code in pyagenity/publisher/base_publisher.py
class BasePublisher(ABC):
    """Abstract base class for event publishers.

    This class defines the interface for publishing events. Subclasses should implement
    the publish, close, and sync_close methods to provide specific publishing logic.

    Attributes:
        config: Configuration dictionary for the publisher.
    """

    def __init__(self, config: dict[str, Any]):
        """Initialize the publisher with the given configuration.

        Args:
            config: Configuration dictionary for the publisher.
        """
        self.config = config

    @abstractmethod
    async def publish(self, event: EventModel) -> Any:
        """Publish an event.

        Args:
            event: The event to publish.

        Returns:
            The result of the publish operation.
        """
        raise NotImplementedError

    @abstractmethod
    async def close(self):
        """Close the publisher and release any resources.

        This method should be overridden by subclasses to provide specific cleanup logic.
        It will be called externally.
        """
        raise NotImplementedError

    @abstractmethod
    def sync_close(self):
        """Close the publisher and release any resources (synchronous version).

        This method should be overridden by subclasses to provide specific cleanup logic.
        It will be called externally.
        """
        raise NotImplementedError
Attributes
config instance-attribute
config = config
Functions
__init__
__init__(config)

Initialize the publisher with the given configuration.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dictionary for the publisher.

required
Source code in pyagenity/publisher/base_publisher.py
def __init__(self, config: dict[str, Any]):
    """Initialize the publisher with the given configuration.

    Args:
        config: Configuration dictionary for the publisher.
    """
    self.config = config
close abstractmethod async
close()

Close the publisher and release any resources.

This method should be overridden by subclasses to provide specific cleanup logic. It will be called externally.

Source code in pyagenity/publisher/base_publisher.py
@abstractmethod
async def close(self):
    """Close the publisher and release any resources.

    This method should be overridden by subclasses to provide specific cleanup logic.
    It will be called externally.
    """
    raise NotImplementedError
publish abstractmethod async
publish(event)

Publish an event.

Parameters:

Name Type Description Default
event EventModel

The event to publish.

required

Returns:

Type Description
Any

The result of the publish operation.

Source code in pyagenity/publisher/base_publisher.py
@abstractmethod
async def publish(self, event: EventModel) -> Any:
    """Publish an event.

    Args:
        event: The event to publish.

    Returns:
        The result of the publish operation.
    """
    raise NotImplementedError
sync_close abstractmethod
sync_close()

Close the publisher and release any resources (synchronous version).

This method should be overridden by subclasses to provide specific cleanup logic. It will be called externally.

Source code in pyagenity/publisher/base_publisher.py
@abstractmethod
def sync_close(self):
    """Close the publisher and release any resources (synchronous version).

    This method should be overridden by subclasses to provide specific cleanup logic.
    It will be called externally.
    """
    raise NotImplementedError
ConsolePublisher

Bases: BasePublisher

Publisher that prints events to the console for debugging and testing.

This publisher is useful for development and debugging purposes, as it outputs event information to the standard output.

Attributes:

Name Type Description
format

Output format ('json' by default).

include_timestamp

Whether to include timestamp (True by default).

indent

Indentation for output (2 by default).

Methods:

Name Description
__init__

Initialize the ConsolePublisher with the given configuration.

close

Close the publisher and release any resources.

publish

Publish an event to the console.

sync_close

Synchronously close the publisher and release any resources.

Source code in pyagenity/publisher/console_publisher.py
class ConsolePublisher(BasePublisher):
    """Publisher that prints events to the console for debugging and testing.

    This publisher is useful for development and debugging purposes, as it outputs event information
    to the standard output.

    Attributes:
        format: Output format ('json' by default).
        include_timestamp: Whether to include timestamp (True by default).
        indent: Indentation for output (2 by default).
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize the ConsolePublisher with the given configuration.

        Args:
            config: Configuration dictionary. Supported keys:
                - format: Output format (default: 'json').
                - include_timestamp: Whether to include timestamp (default: True).
                - indent: Indentation for output (default: 2).
        """
        super().__init__(config or {})
        self.format = config.get("format", "json") if config else "json"
        self.include_timestamp = config.get("include_timestamp", True) if config else True
        self.indent = config.get("indent", 2) if config else 2

    async def publish(self, event: EventModel) -> Any:
        """Publish an event to the console.

        Args:
            event: The event to publish.

        Returns:
            None
        """
        msg = f"{event.timestamp} -> Source: {event.node_name}.{event.event_type}"
        msg += f" -> Payload: {event.data}"
        msg += f" -> {event.metadata}"
        print(msg)  # noqa: T201

    async def close(self):
        """Close the publisher and release any resources.

        ConsolePublisher does not require cleanup, but this method is provided for
        interface compatibility.
        """
        logger.debug("ConsolePublisher closed")

    def sync_close(self):
        """Synchronously close the publisher and release any resources.

        ConsolePublisher does not require cleanup, but this method is provided for
        interface compatibility.
        """
        logger.debug("ConsolePublisher sync closed")
Attributes
config instance-attribute
config = config
format instance-attribute
format = get('format', 'json') if config else 'json'
include_timestamp instance-attribute
include_timestamp = get('include_timestamp', True) if config else True
indent instance-attribute
indent = get('indent', 2) if config else 2
Functions
__init__
__init__(config=None)

Initialize the ConsolePublisher with the given configuration.

Parameters:

Name Type Description Default
config dict[str, Any] | None

Configuration dictionary. Supported keys:
- format: Output format (default: 'json').
- include_timestamp: Whether to include timestamp (default: True).
- indent: Indentation for output (default: 2).

None
Source code in pyagenity/publisher/console_publisher.py
def __init__(self, config: dict[str, Any] | None = None):
    """Initialize the ConsolePublisher with the given configuration.

    Args:
        config: Configuration dictionary. Supported keys:
            - format: Output format (default: 'json').
            - include_timestamp: Whether to include timestamp (default: True).
            - indent: Indentation for output (default: 2).
    """
    super().__init__(config or {})
    self.format = config.get("format", "json") if config else "json"
    self.include_timestamp = config.get("include_timestamp", True) if config else True
    self.indent = config.get("indent", 2) if config else 2
close async
close()

Close the publisher and release any resources.

ConsolePublisher does not require cleanup, but this method is provided for interface compatibility.

Source code in pyagenity/publisher/console_publisher.py
async def close(self):
    """Close the publisher and release any resources.

    ConsolePublisher does not require cleanup, but this method is provided for
    interface compatibility.
    """
    logger.debug("ConsolePublisher closed")
publish async
publish(event)

Publish an event to the console.

Parameters:

Name Type Description Default
event EventModel

The event to publish.

required

Returns:

Type Description
Any

None

Source code in pyagenity/publisher/console_publisher.py
async def publish(self, event: EventModel) -> Any:
    """Publish an event to the console.

    Args:
        event: The event to publish.

    Returns:
        None
    """
    msg = f"{event.timestamp} -> Source: {event.node_name}.{event.event_type}:"
    msg += f"-> Payload: {event.data}"
    msg += f" -> {event.metadata}"
    print(msg)  # noqa: T201
sync_close
sync_close()

Synchronously close the publisher and release any resources.

ConsolePublisher does not require cleanup, but this method is provided for interface compatibility.

Source code in pyagenity/publisher/console_publisher.py
def sync_close(self):
    """Synchronously close the publisher and release any resources.

    ConsolePublisher does not require cleanup, but this method is provided for
    interface compatibility.
    """
    logger.debug("ConsolePublisher sync closed")
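The publish method above emits one arrow-delimited line per event. The sketch below reproduces that formatting with a hypothetical stand-in event (a plain dataclass, not the real EventModel) so the resulting console line is easy to see:

```python
import time
from dataclasses import dataclass, field
from typing import Any


# Hypothetical stand-in for pyagenity's EventModel, carrying only the
# fields that ConsolePublisher.publish reads.
@dataclass
class StubEvent:
    timestamp: float = field(default_factory=time.time)
    node_name: str = "planner"
    event_type: str = "start"
    data: dict[str, Any] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)


event = StubEvent(timestamp=1700000000.0, data={"step": 1})

# Same three-part formatting as ConsolePublisher.publish:
msg = f"{event.timestamp} -> Source: {event.node_name}.{event.event_type}:"
msg += f"-> Payload: {event.data}"
msg += f" -> {event.metadata}"
print(msg)
# -> 1700000000.0 -> Source: planner.start:-> Payload: {'step': 1} -> {}
```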

Modules

base_publisher

Classes:

Name Description
BasePublisher

Abstract base class for event publishers.

Classes
BasePublisher

Bases: ABC

Abstract base class for event publishers.

This class defines the interface for publishing events. Subclasses should implement the publish, close, and sync_close methods to provide specific publishing logic.

Attributes:

Name Type Description
config

Configuration dictionary for the publisher.

Methods:

Name Description
__init__

Initialize the publisher with the given configuration.

close

Close the publisher and release any resources.

publish

Publish an event.

sync_close

Close the publisher and release any resources (synchronous version).

Source code in pyagenity/publisher/base_publisher.py
class BasePublisher(ABC):
    """Abstract base class for event publishers.

    This class defines the interface for publishing events. Subclasses should implement
    the publish, close, and sync_close methods to provide specific publishing logic.

    Attributes:
        config: Configuration dictionary for the publisher.
    """

    def __init__(self, config: dict[str, Any]):
        """Initialize the publisher with the given configuration.

        Args:
            config: Configuration dictionary for the publisher.
        """
        self.config = config

    @abstractmethod
    async def publish(self, event: EventModel) -> Any:
        """Publish an event.

        Args:
            event: The event to publish.

        Returns:
            The result of the publish operation.
        """
        raise NotImplementedError

    @abstractmethod
    async def close(self):
        """Close the publisher and release any resources.

        This method should be overridden by subclasses to provide specific cleanup logic.
        It will be called externally.
        """
        raise NotImplementedError

    @abstractmethod
    def sync_close(self):
        """Close the publisher and release any resources (synchronous version).

        This method should be overridden by subclasses to provide specific cleanup logic.
        It will be called externally.
        """
        raise NotImplementedError
Attributes
config instance-attribute
config = config
Functions
__init__
__init__(config)

Initialize the publisher with the given configuration.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dictionary for the publisher.

required
Source code in pyagenity/publisher/base_publisher.py
def __init__(self, config: dict[str, Any]):
    """Initialize the publisher with the given configuration.

    Args:
        config: Configuration dictionary for the publisher.
    """
    self.config = config
close abstractmethod async
close()

Close the publisher and release any resources.

This method should be overridden by subclasses to provide specific cleanup logic. It will be called externally.

Source code in pyagenity/publisher/base_publisher.py
@abstractmethod
async def close(self):
    """Close the publisher and release any resources.

    This method should be overridden by subclasses to provide specific cleanup logic.
    It will be called externally.
    """
    raise NotImplementedError
publish abstractmethod async
publish(event)

Publish an event.

Parameters:

Name Type Description Default
event EventModel

The event to publish.

required

Returns:

Type Description
Any

The result of the publish operation.

Source code in pyagenity/publisher/base_publisher.py
@abstractmethod
async def publish(self, event: EventModel) -> Any:
    """Publish an event.

    Args:
        event: The event to publish.

    Returns:
        The result of the publish operation.
    """
    raise NotImplementedError
sync_close abstractmethod
sync_close()

Close the publisher and release any resources (synchronous version).

This method should be overridden by subclasses to provide specific cleanup logic. It will be called externally.

Source code in pyagenity/publisher/base_publisher.py
@abstractmethod
def sync_close(self):
    """Close the publisher and release any resources (synchronous version).

    This method should be overridden by subclasses to provide specific cleanup logic.
    It will be called externally.
    """
    raise NotImplementedError
console_publisher

Console publisher implementation for debugging and testing.

This module provides a publisher that outputs events to the console for development and debugging purposes.

Classes:

Name Description
ConsolePublisher

Publisher that prints events to the console for debugging and testing.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
ConsolePublisher

Bases: BasePublisher

Publisher that prints events to the console for debugging and testing.

This publisher is useful for development and debugging purposes, as it outputs event information to the standard output.

Attributes:

Name Type Description
format

Output format ('json' by default).

include_timestamp

Whether to include timestamp (True by default).

indent

Indentation for output (2 by default).

Methods:

Name Description
__init__

Initialize the ConsolePublisher with the given configuration.

close

Close the publisher and release any resources.

publish

Publish an event to the console.

sync_close

Synchronously close the publisher and release any resources.

events

Event and streaming primitives for agent graph execution.

This module defines event types, content types, and the EventModel for structured streaming of execution updates, tool calls, state changes, messages, and errors in agent graphs.

Classes:

Name Description
Event

Enum for event sources (graph, node, tool, streaming).

EventType

Enum for event phases (start, progress, result, end, etc.).

ContentType

Enum for semantic content types (text, message, tool_call, etc.).

EventModel

Structured event chunk for streaming agent graph execution.

Classes
ContentType

Bases: str, Enum

Enum for semantic content types in agent graph streaming.

Values

TEXT: Textual content.
MESSAGE: Message content.
REASONING: Reasoning content.
TOOL_CALL: Tool call content.
TOOL_RESULT: Tool result content.
IMAGE: Image content.
AUDIO: Audio content.
VIDEO: Video content.
DOCUMENT: Document content.
DATA: Data content.
STATE: State content.
UPDATE: Update content.
ERROR: Error content.

Attributes:

Name Type Description
AUDIO
DATA
DOCUMENT
ERROR
IMAGE
MESSAGE
REASONING
STATE
TEXT
TOOL_CALL
TOOL_RESULT
UPDATE
VIDEO
Source code in pyagenity/publisher/events.py
class ContentType(str, enum.Enum):
    """Enum for semantic content types in agent graph streaming.

    Values:
        TEXT: Textual content.
        MESSAGE: Message content.
        REASONING: Reasoning content.
        TOOL_CALL: Tool call content.
        TOOL_RESULT: Tool result content.
        IMAGE: Image content.
        AUDIO: Audio content.
        VIDEO: Video content.
        DOCUMENT: Document content.
        DATA: Data content.
        STATE: State content.
        UPDATE: Update content.
        ERROR: Error content.
    """

    TEXT = "text"
    MESSAGE = "message"
    REASONING = "reasoning"
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    DOCUMENT = "document"
    DATA = "data"
    STATE = "state"
    UPDATE = "update"
    ERROR = "error"
Attributes
AUDIO class-attribute instance-attribute
AUDIO = 'audio'
DATA class-attribute instance-attribute
DATA = 'data'
DOCUMENT class-attribute instance-attribute
DOCUMENT = 'document'
ERROR class-attribute instance-attribute
ERROR = 'error'
IMAGE class-attribute instance-attribute
IMAGE = 'image'
MESSAGE class-attribute instance-attribute
MESSAGE = 'message'
REASONING class-attribute instance-attribute
REASONING = 'reasoning'
STATE class-attribute instance-attribute
STATE = 'state'
TEXT class-attribute instance-attribute
TEXT = 'text'
TOOL_CALL class-attribute instance-attribute
TOOL_CALL = 'tool_call'
TOOL_RESULT class-attribute instance-attribute
TOOL_RESULT = 'tool_result'
UPDATE class-attribute instance-attribute
UPDATE = 'update'
VIDEO class-attribute instance-attribute
VIDEO = 'video'
Event

Bases: str, Enum

Enum for event sources in agent graph execution.

Values

GRAPH_EXECUTION: Event from graph execution.
NODE_EXECUTION: Event from node execution.
TOOL_EXECUTION: Event from tool execution.
STREAMING: Event from streaming updates.

Attributes:

Name Type Description
GRAPH_EXECUTION
NODE_EXECUTION
STREAMING
TOOL_EXECUTION
Source code in pyagenity/publisher/events.py
class Event(str, enum.Enum):
    """Enum for event sources in agent graph execution.

    Values:
        GRAPH_EXECUTION: Event from graph execution.
        NODE_EXECUTION: Event from node execution.
        TOOL_EXECUTION: Event from tool execution.
        STREAMING: Event from streaming updates.
    """

    GRAPH_EXECUTION = "graph_execution"
    NODE_EXECUTION = "node_execution"
    TOOL_EXECUTION = "tool_execution"
    STREAMING = "streaming"
Attributes
GRAPH_EXECUTION class-attribute instance-attribute
GRAPH_EXECUTION = 'graph_execution'
NODE_EXECUTION class-attribute instance-attribute
NODE_EXECUTION = 'node_execution'
STREAMING class-attribute instance-attribute
STREAMING = 'streaming'
TOOL_EXECUTION class-attribute instance-attribute
TOOL_EXECUTION = 'tool_execution'
EventModel

Bases: BaseModel

Structured event chunk for streaming agent graph execution.

Represents a chunk of streamed data with event and content semantics, supporting both delta (incremental) and full content. Used for real-time streaming of execution updates, tool calls, state changes, messages, and errors.

Attributes:

Name Type Description
event Event

Type of the event source.

event_type EventType

Phase of the event (start, progress, end, update).

content str

Streamed textual content.

content_blocks list[ContentBlock] | None

Structured content blocks for multimodal streaming.

delta bool

True if this is a delta update (incremental).

delta_type Literal['text', 'json', 'binary'] | None

Type of delta when delta=True.

block_index int | None

Index of the content block this chunk applies to.

chunk_index int | None

Per-block chunk index for ordering.

byte_offset int | None

Byte offset for binary/media streaming.

data dict[str, Any]

Additional structured data.

content_type list[ContentType] | None

Semantic type of content.

sequence_id int

Monotonic sequence ID for stream ordering.

node_name str

Name of the node producing this chunk.

run_id str

Unique ID for this stream/run.

thread_id str | int

Thread ID for this execution.

timestamp float

UNIX timestamp of when chunk was created.

is_error bool

Marks this chunk as representing an error state.

metadata dict[str, Any]

Optional metadata for consumers.

Classes:

Name Description
Config

Pydantic configuration for EventModel.

Methods:

Name Description
default

Create a default EventModel instance with minimal required fields.

stream

Create a default EventModel instance for streaming updates.

Source code in pyagenity/publisher/events.py
class EventModel(BaseModel):
    """
    Structured event chunk for streaming agent graph execution.

    Represents a chunk of streamed data with event and content semantics, supporting both delta
    (incremental) and full content. Used for real-time streaming of execution updates, tool calls,
    state changes, messages, and errors.

    Attributes:
        event: Type of the event source.
        event_type: Phase of the event (start, progress, end, update).
        content: Streamed textual content.
        content_blocks: Structured content blocks for multimodal streaming.
        delta: True if this is a delta update (incremental).
        delta_type: Type of delta when delta=True.
        block_index: Index of the content block this chunk applies to.
        chunk_index: Per-block chunk index for ordering.
        byte_offset: Byte offset for binary/media streaming.
        data: Additional structured data.
        content_type: Semantic type of content.
        sequence_id: Monotonic sequence ID for stream ordering.
        node_name: Name of the node producing this chunk.
        run_id: Unique ID for this stream/run.
        thread_id: Thread ID for this execution.
        timestamp: UNIX timestamp of when chunk was created.
        is_error: Marks this chunk as representing an error state.
        metadata: Optional metadata for consumers.
    """

    # Event metadata
    event: Event = Field(..., description="Type of the event source")
    event_type: EventType = Field(
        ..., description="Phase of the event (start, progress, end, update)"
    )

    # Streamed content
    content: str = Field(default="", description="Streamed textual content")
    # Structured content blocks for multimodal/structured streaming
    content_blocks: list[ContentBlock] | None = Field(
        default=None, description="Structured content blocks carried by this event"
    )
    # Delta controls
    delta: bool = Field(default=False, description="True if this is a delta update (incremental)")
    delta_type: Literal["text", "json", "binary"] | None = Field(
        default=None, description="Type of delta when delta=True"
    )
    block_index: int | None = Field(
        default=None, description="Index of the content block this chunk applies to"
    )
    chunk_index: int | None = Field(default=None, description="Per-block chunk index for ordering")
    byte_offset: int | None = Field(
        default=None, description="Byte offset for binary/media streaming"
    )

    # Data payload
    data: dict[str, Any] = Field(default_factory=dict, description="Additional structured data")

    # Metadata
    content_type: list[ContentType] | None = Field(
        default=None, description="Semantic type of content"
    )
    sequence_id: int = Field(default=0, description="Monotonic sequence ID for stream ordering")
    node_name: str = Field(default="", description="Name of the node producing this chunk")
    run_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Unique ID for this stream/run"
    )
    thread_id: str | int = Field(default="", description="Thread ID for this execution")
    timestamp: float = Field(
        default_factory=time.time, description="UNIX timestamp of when chunk was created"
    )
    is_error: bool = Field(
        default=False, description="Marks this chunk as representing an error state"
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Optional metadata for consumers"
    )

    class Config:
        """Pydantic configuration for EventModel.

        Attributes:
            use_enum_values: Output enums as strings.
        """

        use_enum_values = True  # Output enums as strings

    @classmethod
    def default(
        cls,
        base_config: dict,
        data: dict[str, Any],
        content_type: list[ContentType],
        event: Event = Event.GRAPH_EXECUTION,
        event_type=EventType.START,
        node_name: str = "",
        extra: dict[str, Any] | None = None,
    ) -> "EventModel":
        """Create a default EventModel instance with minimal required fields.

        Args:
            base_config: Base configuration for the event (thread/run/timestamp/user).
            data: Structured data payload.
            content_type: Semantic type(s) of content.
            event: Event source type (default: GRAPH_EXECUTION).
            event_type: Event phase (default: START).
            node_name: Name of the node producing the event.
            extra: Additional metadata.

        Returns:
            EventModel: The created event model instance.
        """
        thread_id = base_config.get("thread_id", "")
        run_id = base_config.get("run_id", "")

        metadata = {
            "run_timestamp": base_config.get("timestamp", ""),
            "user_id": base_config.get("user_id"),
            "is_stream": base_config.get("is_stream", False),
        }
        if extra:
            metadata.update(extra)
        return cls(
            event=event,
            event_type=event_type,
            delta=False,
            content_type=content_type,
            data=data,
            thread_id=thread_id,
            node_name=node_name,
            run_id=run_id,
            metadata=metadata,
        )

    @classmethod
    def stream(
        cls,
        base_config: dict,
        node_name: str = "",
        extra: dict[str, Any] | None = None,
    ) -> "EventModel":
        """Create a default EventModel instance for streaming updates.

        Args:
            base_config: Base configuration for the event (thread/run/timestamp/user).
            node_name: Name of the node producing the event.
            extra: Additional metadata.

        Returns:
            EventModel: The created event model instance for streaming.
        """
        thread_id = base_config.get("thread_id", "")
        run_id = base_config.get("run_id", "")

        metadata = {
            "run_timestamp": base_config.get("timestamp", ""),
            "user_id": base_config.get("user_id"),
            "is_stream": base_config.get("is_stream", False),
        }
        if extra:
            metadata.update(extra)
        return cls(
            event=Event.STREAMING,
            event_type=EventType.UPDATE,
            delta=True,
            content_type=[ContentType.TEXT, ContentType.REASONING],
            data={},
            thread_id=thread_id,
            node_name=node_name,
            run_id=run_id,
            metadata=metadata,
        )
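On the consuming side, delta chunks can be reassembled by ordering on sequence_id and keeping only incremental updates. A minimal sketch, using plain dicts as stand-ins for EventModel instances (only the sequence_id, delta, and content fields of the model above are mirrored):

```python
# Plain dicts stand in for EventModel instances; field names mirror the model.
chunks = [
    {"sequence_id": 2, "delta": True, "content": " world"},
    {"sequence_id": 1, "delta": True, "content": "Hello"},
    {"sequence_id": 3, "delta": False, "content": "Hello world"},  # final full chunk
]

# Keep only incremental (delta) chunks, restore stream order, concatenate.
text = "".join(
    c["content"]
    for c in sorted(chunks, key=lambda c: c["sequence_id"])
    if c["delta"]
)
print(text)
# -> Hello world
```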
Attributes
block_index class-attribute instance-attribute
block_index = Field(default=None, description='Index of the content block this chunk applies to')
byte_offset class-attribute instance-attribute
byte_offset = Field(default=None, description='Byte offset for binary/media streaming')
chunk_index class-attribute instance-attribute
chunk_index = Field(default=None, description='Per-block chunk index for ordering')
content class-attribute instance-attribute
content = Field(default='', description='Streamed textual content')
content_blocks class-attribute instance-attribute
content_blocks = Field(default=None, description='Structured content blocks carried by this event')
content_type class-attribute instance-attribute
content_type = Field(default=None, description='Semantic type of content')
data class-attribute instance-attribute
data = Field(default_factory=dict, description='Additional structured data')
delta class-attribute instance-attribute
delta = Field(default=False, description='True if this is a delta update (incremental)')
delta_type class-attribute instance-attribute
delta_type = Field(default=None, description='Type of delta when delta=True')
event class-attribute instance-attribute
event = Field(..., description='Type of the event source')
event_type class-attribute instance-attribute
event_type = Field(..., description='Phase of the event (start, progress, end, update)')
is_error class-attribute instance-attribute
is_error = Field(default=False, description='Marks this chunk as representing an error state')
metadata class-attribute instance-attribute
metadata = Field(default_factory=dict, description='Optional metadata for consumers')
node_name class-attribute instance-attribute
node_name = Field(default='', description='Name of the node producing this chunk')
run_id class-attribute instance-attribute
run_id = Field(default_factory=lambda: str(uuid4()), description='Unique ID for this stream/run')
sequence_id class-attribute instance-attribute
sequence_id = Field(default=0, description='Monotonic sequence ID for stream ordering')
thread_id class-attribute instance-attribute
thread_id = Field(default='', description='Thread ID for this execution')
timestamp class-attribute instance-attribute
timestamp = Field(default_factory=time, description='UNIX timestamp of when chunk was created')
Classes
Config

Pydantic configuration for EventModel.

Attributes:

Name Type Description
use_enum_values

Output enums as strings.

Source code in pyagenity/publisher/events.py
class Config:
    """Pydantic configuration for EventModel.

    Attributes:
        use_enum_values: Output enums as strings.
    """

    use_enum_values = True  # Output enums as strings
Attributes
use_enum_values class-attribute instance-attribute
use_enum_values = True
Functions
default classmethod
default(base_config, data, content_type, event=Event.GRAPH_EXECUTION, event_type=EventType.START, node_name='', extra=None)

Create a default EventModel instance with minimal required fields.

Parameters:

Name Type Description Default
base_config dict

Base configuration for the event (thread/run/timestamp/user).

required
data dict[str, Any]

Structured data payload.

required
content_type list[ContentType]

Semantic type(s) of content.

required
event Event

Event source type (default: GRAPH_EXECUTION).

GRAPH_EXECUTION
event_type

Event phase (default: START).

START
node_name str

Name of the node producing the event.

''
extra dict[str, Any] | None

Additional metadata.

None

Returns:

Name Type Description
EventModel EventModel

The created event model instance.

Source code in pyagenity/publisher/events.py
@classmethod
def default(
    cls,
    base_config: dict,
    data: dict[str, Any],
    content_type: list[ContentType],
    event: Event = Event.GRAPH_EXECUTION,
    event_type=EventType.START,
    node_name: str = "",
    extra: dict[str, Any] | None = None,
) -> "EventModel":
    """Create a default EventModel instance with minimal required fields.

    Args:
        base_config: Base configuration for the event (thread/run/timestamp/user).
        data: Structured data payload.
        content_type: Semantic type(s) of content.
        event: Event source type (default: GRAPH_EXECUTION).
        event_type: Event phase (default: START).
        node_name: Name of the node producing the event.
        extra: Additional metadata.

    Returns:
        EventModel: The created event model instance.
    """
    thread_id = base_config.get("thread_id", "")
    run_id = base_config.get("run_id", "")

    metadata = {
        "run_timestamp": base_config.get("timestamp", ""),
        "user_id": base_config.get("user_id"),
        "is_stream": base_config.get("is_stream", False),
    }
    if extra:
        metadata.update(extra)
    return cls(
        event=event,
        event_type=event_type,
        delta=False,
        content_type=content_type,
        data=data,
        thread_id=thread_id,
        node_name=node_name,
        run_id=run_id,
        metadata=metadata,
    )
stream classmethod
stream(base_config, node_name='', extra=None)

Create a default EventModel instance for streaming updates.

Parameters:

Name Type Description Default
base_config dict

Base configuration for the event (thread/run/timestamp/user).

required
node_name str

Name of the node producing the event.

''
extra dict[str, Any] | None

Additional metadata.

None

Returns:

Name Type Description
EventModel EventModel

The created event model instance for streaming.

Source code in pyagenity/publisher/events.py
@classmethod
def stream(
    cls,
    base_config: dict,
    node_name: str = "",
    extra: dict[str, Any] | None = None,
) -> "EventModel":
    """Create a default EventModel instance for streaming updates.

    Args:
        base_config: Base configuration for the event (thread/run/timestamp/user).
        node_name: Name of the node producing the event.
        extra: Additional metadata.

    Returns:
        EventModel: The created event model instance for streaming.
    """
    thread_id = base_config.get("thread_id", "")
    run_id = base_config.get("run_id", "")

    metadata = {
        "run_timestamp": base_config.get("timestamp", ""),
        "user_id": base_config.get("user_id"),
        "is_stream": base_config.get("is_stream", False),
    }
    if extra:
        metadata.update(extra)
    return cls(
        event=Event.STREAMING,
        event_type=EventType.UPDATE,
        delta=True,
        content_type=[ContentType.TEXT, ContentType.REASONING],
        data={},
        thread_id=thread_id,
        node_name=node_name,
        run_id=run_id,
        metadata=metadata,
    )
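Because stream() marks events with delta=True, consumers are expected to accumulate chunks rather than replace state. A minimal sketch (the chunk shape here is hypothetical, not the EventModel schema):

```python
# Hypothetical delta chunks; real streaming events are EventModel instances.
chunks = [
    {"delta": True, "text": "Hel"},
    {"delta": True, "text": "lo, "},
    {"delta": True, "text": "world"},
]

# Consumers concatenate delta payloads in arrival order.
text = "".join(c["text"] for c in chunks if c["delta"])
print(text)  # Hello, world
```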
EventType

Bases: str, Enum

Enum for event phases in agent graph execution.

Values

START: Event marks start of execution.
PROGRESS: Event marks progress update.
RESULT: Event marks result produced.
END: Event marks end of execution.
UPDATE: Event marks update.
ERROR: Event marks error.
INTERRUPTED: Event marks interruption.

Attributes:

Name Type Description
END
ERROR
INTERRUPTED
PROGRESS
RESULT
START
UPDATE
Source code in pyagenity/publisher/events.py
class EventType(str, enum.Enum):
    """Enum for event phases in agent graph execution.

    Values:
        START: Event marks start of execution.
        PROGRESS: Event marks progress update.
        RESULT: Event marks result produced.
        END: Event marks end of execution.
        UPDATE: Event marks update.
        ERROR: Event marks error.
        INTERRUPTED: Event marks interruption.
    """

    START = "start"
    PROGRESS = "progress"
    RESULT = "result"
    END = "end"
    UPDATE = "update"
    ERROR = "error"
    INTERRUPTED = "interrupted"
Attributes
END class-attribute instance-attribute
END = 'end'
ERROR class-attribute instance-attribute
ERROR = 'error'
INTERRUPTED class-attribute instance-attribute
INTERRUPTED = 'interrupted'
PROGRESS class-attribute instance-attribute
PROGRESS = 'progress'
RESULT class-attribute instance-attribute
RESULT = 'result'
START class-attribute instance-attribute
START = 'start'
UPDATE class-attribute instance-attribute
UPDATE = 'update'
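Since EventType subclasses both str and Enum, members compare equal to their plain-string values, which is convenient when matching serialized events. A small self-contained illustration:

```python
import enum


class EventType(str, enum.Enum):
    START = "start"
    END = "end"


# str-enum members compare equal to plain strings...
assert EventType.START == "start"
# ...and round-trip from their serialized value.
print(EventType("end").name)  # END
```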
kafka_publisher

Kafka publisher implementation (optional dependency).

Uses aiokafka to publish events to a Kafka topic.

Dependency: aiokafka.
Not installed by default; install the extra: pip install pyagenity[kafka].

Classes:

Name Description
KafkaPublisher

Publish events to a Kafka topic using aiokafka.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
KafkaPublisher

Bases: BasePublisher

Publish events to a Kafka topic using aiokafka.

This class provides an asynchronous interface for publishing events to a Kafka topic. It uses the aiokafka library to handle the producer operations. The publisher is lazily initialized and can be reused for multiple publishes.

Attributes:

Name Type Description
bootstrap_servers str

Kafka bootstrap servers.

topic str

Kafka topic to publish to.

client_id str | None

Client ID for the producer.

_producer

Lazy-loaded Kafka producer instance.

Methods:

Name Description
__init__

Initialize the KafkaPublisher.

close

Close the Kafka producer.

publish

Publish an event to the Kafka topic.

sync_close

Synchronously close the Kafka producer.

Source code in pyagenity/publisher/kafka_publisher.py
class KafkaPublisher(BasePublisher):
    """Publish events to a Kafka topic using aiokafka.

    This class provides an asynchronous interface for publishing events to a Kafka topic.
    It uses the aiokafka library to handle the producer operations. The publisher is
    lazily initialized and can be reused for multiple publishes.

    Attributes:
        bootstrap_servers: Kafka bootstrap servers.
        topic: Kafka topic to publish to.
        client_id: Client ID for the producer.
        _producer: Lazy-loaded Kafka producer instance.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize the KafkaPublisher.

        Args:
            config: Configuration dictionary. Supported keys:
                - bootstrap_servers: Kafka bootstrap servers (default: "localhost:9092").
                - topic: Kafka topic to publish to (default: "pyagenity.events").
                - client_id: Client ID for the producer.
        """
        super().__init__(config or {})
        self.bootstrap_servers: str = self.config.get("bootstrap_servers", "localhost:9092")
        self.topic: str = self.config.get("topic", "pyagenity.events")
        self.client_id: str | None = self.config.get("client_id")
        self._producer = None  # type: ignore[var-annotated]

    async def _get_producer(self):
        """Get or create the Kafka producer instance.

        This method lazily initializes the producer if it hasn't been created yet.
        It imports aiokafka and starts the producer.

        Returns:
            The initialized producer instance.

        Raises:
            RuntimeError: If the 'aiokafka' package is not installed.
        """
        if self._producer is not None:
            return self._producer

        try:
            aiokafka = importlib.import_module("aiokafka")
        except Exception as exc:
            raise RuntimeError(
                "KafkaPublisher requires the 'aiokafka' package. Install with "
                "'pip install pyagenity[kafka]' or 'pip install aiokafka'."
            ) from exc

        producer_cls = aiokafka.AIOKafkaProducer
        self._producer = producer_cls(
            bootstrap_servers=self.bootstrap_servers,
            client_id=self.client_id,
        )
        await self._producer.start()
        return self._producer

    async def publish(self, event: EventModel) -> Any:
        """Publish an event to the Kafka topic.

        Args:
            event: The event to publish.

        Returns:
            The result of the send_and_wait operation.
        """
        producer = await self._get_producer()
        payload = json.dumps(event.model_dump()).encode("utf-8")
        return await producer.send_and_wait(self.topic, payload)

    async def close(self):
        """Close the Kafka producer.

        Stops the producer and cleans up resources. Errors during stopping are logged
        but do not raise exceptions.
        """
        if self._producer is None:
            return

        try:
            await self._producer.stop()
        except Exception:
            logger.debug("KafkaPublisher close encountered an error", exc_info=True)
        finally:
            self._producer = None

    def sync_close(self):
        """Synchronously close the Kafka producer.

        This method runs the async close in a new event loop. If called within an
        active event loop, it logs a warning and skips the operation.
        """
        try:
            asyncio.run(self.close())
        except RuntimeError:
            logger.warning("sync_close called within an active event loop; skipping.")
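The lazy-import guard in _get_producer can be isolated as a small helper (require_optional is a hypothetical name; the module name in the demo is deliberately nonexistent to exercise the error path):

```python
import importlib


def require_optional(module_name: str, extra: str):
    """Import an optional dependency or raise a helpful RuntimeError,
    mirroring the guard in KafkaPublisher._get_producer."""
    try:
        return importlib.import_module(module_name)
    except Exception as exc:
        raise RuntimeError(
            f"This publisher requires the '{module_name}' package. "
            f"Install with 'pip install pyagenity[{extra}]' or "
            f"'pip install {module_name}'."
        ) from exc


try:
    require_optional("aiokafka_missing_for_demo", "kafka")
except RuntimeError as err:
    print("caught:", err)
```

Raising RuntimeError (instead of letting ImportError escape) lets the class be imported and constructed without the dependency; the failure only surfaces on first use.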
Attributes
bootstrap_servers instance-attribute
bootstrap_servers = get('bootstrap_servers', 'localhost:9092')
client_id instance-attribute
client_id = get('client_id')
config instance-attribute
config = config
topic instance-attribute
topic = get('topic', 'pyagenity.events')
Functions
__init__
__init__(config=None)

Initialize the KafkaPublisher.

Parameters:

Name Type Description Default
config dict[str, Any] | None

Configuration dictionary. Supported keys:

- bootstrap_servers: Kafka bootstrap servers (default: "localhost:9092").
- topic: Kafka topic to publish to (default: "pyagenity.events").
- client_id: Client ID for the producer.

None
Source code in pyagenity/publisher/kafka_publisher.py
def __init__(self, config: dict[str, Any] | None = None):
    """Initialize the KafkaPublisher.

    Args:
        config: Configuration dictionary. Supported keys:
            - bootstrap_servers: Kafka bootstrap servers (default: "localhost:9092").
            - topic: Kafka topic to publish to (default: "pyagenity.events").
            - client_id: Client ID for the producer.
    """
    super().__init__(config or {})
    self.bootstrap_servers: str = self.config.get("bootstrap_servers", "localhost:9092")
    self.topic: str = self.config.get("topic", "pyagenity.events")
    self.client_id: str | None = self.config.get("client_id")
    self._producer = None  # type: ignore[var-annotated]
close async
close()

Close the Kafka producer.

Stops the producer and cleans up resources. Errors during stopping are logged but do not raise exceptions.

Source code in pyagenity/publisher/kafka_publisher.py
async def close(self):
    """Close the Kafka producer.

    Stops the producer and cleans up resources. Errors during stopping are logged
    but do not raise exceptions.
    """
    if self._producer is None:
        return

    try:
        await self._producer.stop()
    except Exception:
        logger.debug("KafkaPublisher close encountered an error", exc_info=True)
    finally:
        self._producer = None
publish async
publish(event)

Publish an event to the Kafka topic.

Parameters:

Name Type Description Default
event EventModel

The event to publish.

required

Returns:

Type Description
Any

The result of the send_and_wait operation.

Source code in pyagenity/publisher/kafka_publisher.py
async def publish(self, event: EventModel) -> Any:
    """Publish an event to the Kafka topic.

    Args:
        event: The event to publish.

    Returns:
        The result of the send_and_wait operation.
    """
    producer = await self._get_producer()
    payload = json.dumps(event.model_dump()).encode("utf-8")
    return await producer.send_and_wait(self.topic, payload)
sync_close
sync_close()

Synchronously close the Kafka producer.

This method runs the async close in a new event loop. If called within an active event loop, it logs a warning and skips the operation.

Source code in pyagenity/publisher/kafka_publisher.py
def sync_close(self):
    """Synchronously close the Kafka producer.

    This method runs the async close in a new event loop. If called within an
    active event loop, it logs a warning and skips the operation.
    """
    try:
        asyncio.run(self.close())
    except RuntimeError:
        logger.warning("sync_close called within an active event loop; skipping.")
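The RuntimeError that sync_close catches is the one asyncio.run raises when invoked from inside an already-running loop, as this self-contained sketch shows:

```python
import asyncio


async def inside_loop() -> str:
    # asyncio.run() refuses to nest inside a running event loop; this is
    # exactly the RuntimeError that sync_close catches and logs.
    try:
        asyncio.run(asyncio.sleep(0))
    except RuntimeError as err:
        return str(err)
    return "no error"


msg = asyncio.run(inside_loop())
print(msg)
```

Callers already inside an event loop should therefore prefer `await publisher.close()` over sync_close.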
publish

Functions:

Name Description
publish_event

Publish an event asynchronously using the background task manager.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
Functions
publish_event
publish_event(event, publisher=Inject[BasePublisher], task_manager=Inject[BackgroundTaskManager])

Publish an event asynchronously using the background task manager.

Parameters:

Name Type Description Default
event EventModel

The event to publish.

required
publisher BasePublisher | None

The publisher instance (injected).

Inject[BasePublisher]
task_manager BackgroundTaskManager

The background task manager (injected).

Inject[BackgroundTaskManager]
Source code in pyagenity/publisher/publish.py
def publish_event(
    event: EventModel,
    publisher: BasePublisher | None = Inject[BasePublisher],
    task_manager: BackgroundTaskManager = Inject[BackgroundTaskManager],
) -> None:
    """Publish an event asynchronously using the background task manager.

    Args:
        event: The event to publish.
        publisher: The publisher instance (injected).
        task_manager: The background task manager (injected).
    """
    # Store the task to prevent it from being garbage collected
    task_manager.create_task(_publish_event_task(event, publisher))
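The comment about garbage collection matters: a bare create_task result can be collected before it finishes if nothing holds a reference. A simplified stand-in for BackgroundTaskManager (not the PyAgenity class) shows the keep-a-reference pattern:

```python
import asyncio


class DemoTaskManager:
    """Simplified stand-in for PyAgenity's BackgroundTaskManager."""

    def __init__(self):
        self._tasks: set[asyncio.Task] = set()

    def create_task(self, coro) -> asyncio.Task:
        task = asyncio.get_running_loop().create_task(coro)
        self._tasks.add(task)                        # hold a strong reference
        task.add_done_callback(self._tasks.discard)  # drop it once finished
        return task


async def main() -> str:
    mgr = DemoTaskManager()

    async def fake_publish() -> str:
        await asyncio.sleep(0)
        return "published"

    task = mgr.create_task(fake_publish())
    return await task


print(asyncio.run(main()))  # published
```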
rabbitmq_publisher

RabbitMQ publisher implementation (optional dependency).

Uses aio-pika to publish events to an exchange with a routing key.

Dependency: aio-pika.
Not installed by default; install the extra: pip install pyagenity[rabbitmq].

Classes:

Name Description
RabbitMQPublisher

Publish events to RabbitMQ using aio-pika.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
RabbitMQPublisher

Bases: BasePublisher

Publish events to RabbitMQ using aio-pika.

Attributes:

Name Type Description
url str

RabbitMQ connection URL.

exchange str

Exchange name.

routing_key str

Routing key for messages.

exchange_type str

Type of exchange.

declare bool

Whether to declare the exchange.

durable bool

Whether the exchange is durable.

_conn

Connection instance.

_channel

Channel instance.

_exchange

Exchange instance.

Methods:

Name Description
__init__

Initialize the RabbitMQPublisher.

close

Close the RabbitMQ connection and channel.

publish

Publish an event to RabbitMQ.

sync_close

Synchronously close the RabbitMQ connection.

Source code in pyagenity/publisher/rabbitmq_publisher.py
class RabbitMQPublisher(BasePublisher):
    """Publish events to RabbitMQ using aio-pika.

    Attributes:
        url: RabbitMQ connection URL.
        exchange: Exchange name.
        routing_key: Routing key for messages.
        exchange_type: Type of exchange.
        declare: Whether to declare the exchange.
        durable: Whether the exchange is durable.
        _conn: Connection instance.
        _channel: Channel instance.
        _exchange: Exchange instance.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize the RabbitMQPublisher.

        Args:
            config: Configuration dictionary. Supported keys:
                - url: RabbitMQ URL (default: "amqp://guest:guest@localhost/").
                - exchange: Exchange name (default: "pyagenity.events").
                - routing_key: Routing key (default: "pyagenity.events").
                - exchange_type: Exchange type (default: "topic").
                - declare: Whether to declare exchange (default: True).
                - durable: Whether exchange is durable (default: True).
        """
        super().__init__(config or {})
        self.url: str = self.config.get("url", "amqp://guest:guest@localhost/")
        self.exchange: str = self.config.get("exchange", "pyagenity.events")
        self.routing_key: str = self.config.get("routing_key", "pyagenity.events")
        self.exchange_type: str = self.config.get("exchange_type", "topic")
        self.declare: bool = self.config.get("declare", True)
        self.durable: bool = self.config.get("durable", True)

        self._conn = None  # type: ignore[var-annotated]
        self._channel = None  # type: ignore[var-annotated]
        self._exchange = None  # type: ignore[var-annotated]

    async def _ensure(self):
        """Ensure the connection, channel, and exchange are initialized."""
        if self._exchange is not None:
            return

        try:
            aio_pika = importlib.import_module("aio_pika")
        except Exception as exc:
            raise RuntimeError(
                "RabbitMQPublisher requires the 'aio-pika' package. Install with "
                "'pip install pyagenity[rabbitmq]' or 'pip install aio-pika'."
            ) from exc

        # Connect and declare exchange if needed
        self._conn = await aio_pika.connect_robust(self.url)
        self._channel = await self._conn.channel()

        if self.declare:
            ex_type = getattr(
                aio_pika.ExchangeType,
                self.exchange_type.upper(),
                aio_pika.ExchangeType.TOPIC,
            )
            self._exchange = await self._channel.declare_exchange(
                self.exchange, ex_type, durable=self.durable
            )
        else:
            # Fall back to default exchange
            self._exchange = self._channel.default_exchange

    async def publish(self, event: EventModel) -> Any:
        """Publish an event to RabbitMQ.

        Args:
            event: The event to publish.

        Returns:
            True on success.
        """
        await self._ensure()
        payload = json.dumps(event.model_dump()).encode("utf-8")

        aio_pika = importlib.import_module("aio_pika")
        message = aio_pika.Message(body=payload)
        if self._exchange is None:
            raise RuntimeError("RabbitMQPublisher exchange not initialized")
        await self._exchange.publish(message, routing_key=self.routing_key)
        return True

    async def close(self):
        """Close the RabbitMQ connection and channel."""
        try:
            if self._channel is not None:
                await self._channel.close()
        except Exception:
            logger.debug("RabbitMQPublisher channel close error", exc_info=True)
        finally:
            self._channel = None

        try:
            if self._conn is not None:
                await self._conn.close()
        except Exception:
            logger.debug("RabbitMQPublisher connection close error", exc_info=True)
        finally:
            self._conn = None
            self._exchange = None

    def sync_close(self):
        """Synchronously close the RabbitMQ connection."""
        try:
            asyncio.run(self.close())
        except RuntimeError:
            logger.warning("sync_close called within an active event loop; skipping.")
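The getattr fallback in _ensure means an unrecognized exchange_type string silently resolves to TOPIC. A sketch with a stand-in enum (aio_pika.ExchangeType is not imported here):

```python
import enum


class ExchangeType(enum.Enum):
    """Stand-in for aio_pika.ExchangeType."""
    TOPIC = "topic"
    FANOUT = "fanout"
    DIRECT = "direct"


def resolve_exchange_type(name: str) -> ExchangeType:
    # Mirror RabbitMQPublisher._ensure: unknown names fall back to TOPIC.
    return getattr(ExchangeType, name.upper(), ExchangeType.TOPIC)


print(resolve_exchange_type("fanout").value)    # fanout
print(resolve_exchange_type("nonsense").value)  # topic
```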
Attributes
config instance-attribute
config = config
declare instance-attribute
declare = get('declare', True)
durable instance-attribute
durable = get('durable', True)
exchange instance-attribute
exchange = get('exchange', 'pyagenity.events')
exchange_type instance-attribute
exchange_type = get('exchange_type', 'topic')
routing_key instance-attribute
routing_key = get('routing_key', 'pyagenity.events')
url instance-attribute
url = get('url', 'amqp://guest:guest@localhost/')
Functions
__init__
__init__(config=None)

Initialize the RabbitMQPublisher.

Parameters:

Name Type Description Default
config dict[str, Any] | None

Configuration dictionary. Supported keys:

- url: RabbitMQ URL (default: "amqp://guest:guest@localhost/").
- exchange: Exchange name (default: "pyagenity.events").
- routing_key: Routing key (default: "pyagenity.events").
- exchange_type: Exchange type (default: "topic").
- declare: Whether to declare exchange (default: True).
- durable: Whether exchange is durable (default: True).

None
Source code in pyagenity/publisher/rabbitmq_publisher.py
def __init__(self, config: dict[str, Any] | None = None):
    """Initialize the RabbitMQPublisher.

    Args:
        config: Configuration dictionary. Supported keys:
            - url: RabbitMQ URL (default: "amqp://guest:guest@localhost/").
            - exchange: Exchange name (default: "pyagenity.events").
            - routing_key: Routing key (default: "pyagenity.events").
            - exchange_type: Exchange type (default: "topic").
            - declare: Whether to declare exchange (default: True).
            - durable: Whether exchange is durable (default: True).
    """
    super().__init__(config or {})
    self.url: str = self.config.get("url", "amqp://guest:guest@localhost/")
    self.exchange: str = self.config.get("exchange", "pyagenity.events")
    self.routing_key: str = self.config.get("routing_key", "pyagenity.events")
    self.exchange_type: str = self.config.get("exchange_type", "topic")
    self.declare: bool = self.config.get("declare", True)
    self.durable: bool = self.config.get("durable", True)

    self._conn = None  # type: ignore[var-annotated]
    self._channel = None  # type: ignore[var-annotated]
    self._exchange = None  # type: ignore[var-annotated]
close async
close()

Close the RabbitMQ connection and channel.

Source code in pyagenity/publisher/rabbitmq_publisher.py
async def close(self):
    """Close the RabbitMQ connection and channel."""
    try:
        if self._channel is not None:
            await self._channel.close()
    except Exception:
        logger.debug("RabbitMQPublisher channel close error", exc_info=True)
    finally:
        self._channel = None

    try:
        if self._conn is not None:
            await self._conn.close()
    except Exception:
        logger.debug("RabbitMQPublisher connection close error", exc_info=True)
    finally:
        self._conn = None
        self._exchange = None
publish async
publish(event)

Publish an event to RabbitMQ.

Parameters:

Name Type Description Default
event EventModel

The event to publish.

required

Returns:

Type Description
Any

True on success.

Source code in pyagenity/publisher/rabbitmq_publisher.py
async def publish(self, event: EventModel) -> Any:
    """Publish an event to RabbitMQ.

    Args:
        event: The event to publish.

    Returns:
        True on success.
    """
    await self._ensure()
    payload = json.dumps(event.model_dump()).encode("utf-8")

    aio_pika = importlib.import_module("aio_pika")
    message = aio_pika.Message(body=payload)
    if self._exchange is None:
        raise RuntimeError("RabbitMQPublisher exchange not initialized")
    await self._exchange.publish(message, routing_key=self.routing_key)
    return True
sync_close
sync_close()

Synchronously close the RabbitMQ connection.

Source code in pyagenity/publisher/rabbitmq_publisher.py
def sync_close(self):
    """Synchronously close the RabbitMQ connection."""
    try:
        asyncio.run(self.close())
    except RuntimeError:
        logger.warning("sync_close called within an active event loop; skipping.")
redis_publisher

Redis publisher implementation (optional dependency).

This publisher uses the redis-py asyncio client to publish events via:

- Pub/Sub channels (default), or
- Redis Streams (XADD) when configured with mode="stream".

Dependency: redis>=4.2 (provides redis.asyncio).
Not installed by default; install the extra: pip install pyagenity[redis].

Classes:

Name Description
RedisPublisher

Publish events to Redis via Pub/Sub channel or Stream.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
RedisPublisher

Bases: BasePublisher

Publish events to Redis via Pub/Sub channel or Stream.

Attributes:

Name Type Description
url str

Redis URL.

mode str

Publishing mode ('pubsub' or 'stream').

channel str

Pub/Sub channel name.

stream str

Stream name.

maxlen int | None

Max length for streams.

encoding str

Encoding for messages.

_redis

Redis client instance.

Methods:

Name Description
__init__

Initialize the RedisPublisher.

close

Close the Redis client.

publish

Publish an event to Redis.

sync_close

Synchronously close the Redis client.

Source code in pyagenity/publisher/redis_publisher.py
class RedisPublisher(BasePublisher):
    """Publish events to Redis via Pub/Sub channel or Stream.

    Attributes:
        url: Redis URL.
        mode: Publishing mode ('pubsub' or 'stream').
        channel: Pub/Sub channel name.
        stream: Stream name.
        maxlen: Max length for streams.
        encoding: Encoding for messages.
        _redis: Redis client instance.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize the RedisPublisher.

        Args:
            config: Configuration dictionary. Supported keys:
                - url: Redis URL (default: "redis://localhost:6379/0").
                - mode: Publishing mode ('pubsub' or 'stream', default: 'pubsub').
                - channel: Pub/Sub channel name (default: "pyagenity.events").
                - stream: Stream name (default: "pyagenity.events").
                - maxlen: Max length for streams.
                - encoding: Encoding (default: "utf-8").
        """
        super().__init__(config or {})
        self.url: str = self.config.get("url", "redis://localhost:6379/0")
        self.mode: str = self.config.get("mode", "pubsub")
        self.channel: str = self.config.get("channel", "pyagenity.events")
        self.stream: str = self.config.get("stream", "pyagenity.events")
        self.maxlen: int | None = self.config.get("maxlen")
        self.encoding: str = self.config.get("encoding", "utf-8")

        # Lazy import & connect on first use to avoid ImportError at import-time.
        self._redis = None  # type: ignore[var-annotated]

    async def _get_client(self):
        """Get or create the Redis client.

        Returns:
            The Redis client instance.

        Raises:
            RuntimeError: If connection fails.
        """
        if self._redis is not None:
            return self._redis

        try:
            redis_asyncio = importlib.import_module("redis.asyncio")
        except Exception as exc:  # ImportError and others
            raise RuntimeError(
                "RedisPublisher requires the 'redis' package. Install with "
                "'pip install pyagenity[redis]' or 'pip install redis'."
            ) from exc

        try:
            self._redis = redis_asyncio.from_url(
                self.url, encoding=self.encoding, decode_responses=False
            )
        except Exception as exc:
            raise RuntimeError(f"RedisPublisher failed to connect to Redis at {self.url}") from exc

        return self._redis

    async def publish(self, event: EventModel) -> Any:
        """Publish an event to Redis.

        Args:
            event: The event to publish.

        Returns:
            The result of the publish operation.
        """
        client = await self._get_client()
        payload = json.dumps(event.model_dump()).encode(self.encoding)

        if self.mode == "stream":
            # XADD to stream
            fields = {"data": payload}
            if self.maxlen is not None:
                return await client.xadd(self.stream, fields, maxlen=self.maxlen, approximate=True)
            return await client.xadd(self.stream, fields)

        # Default: Pub/Sub channel
        return await client.publish(self.channel, payload)

    async def close(self):
        """Close the Redis client."""
        if self._redis is not None:
            try:
                await self._redis.close()
                await self._redis.connection_pool.disconnect(inuse_connections=True)
            except Exception:  # best-effort close
                logger.debug("RedisPublisher close encountered an error", exc_info=True)
            finally:
                self._redis = None

    def sync_close(self):
        """Synchronously close the Redis client."""
        try:
            asyncio.run(self.close())
        except RuntimeError:
            # Already in an event loop; fall back to scheduling close
            logger.warning("sync_close called within an active event loop; skipping.")
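Both modes share the same wire format: the event is dumped to JSON and encoded before PUBLISH or XADD. A sketch with a plain dict standing in for EventModel.model_dump():

```python
import json


def encode_payload(event_dict: dict, encoding: str = "utf-8") -> bytes:
    """Mirror the payload encoding in RedisPublisher.publish."""
    return json.dumps(event_dict).encode(encoding)


payload = encode_payload({"event": "streaming", "event_type": "update"})
# In "pubsub" mode these bytes are the PUBLISH body; in "stream" mode they
# become the value of the single "data" field passed to XADD.
print(payload.decode("utf-8"))
```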
Attributes
channel instance-attribute
channel = get('channel', 'pyagenity.events')
config instance-attribute
config = config
encoding instance-attribute
encoding = get('encoding', 'utf-8')
maxlen instance-attribute
maxlen = get('maxlen')
mode instance-attribute
mode = get('mode', 'pubsub')
stream instance-attribute
stream = get('stream', 'pyagenity.events')
url instance-attribute
url = get('url', 'redis://localhost:6379/0')
Functions
__init__
__init__(config=None)

Initialize the RedisPublisher.

Parameters:

Name Type Description Default
config dict[str, Any] | None

Configuration dictionary. Supported keys:

- url: Redis URL (default: "redis://localhost:6379/0").
- mode: Publishing mode ('pubsub' or 'stream', default: 'pubsub').
- channel: Pub/Sub channel name (default: "pyagenity.events").
- stream: Stream name (default: "pyagenity.events").
- maxlen: Max length for streams.
- encoding: Encoding (default: "utf-8").

None
Source code in pyagenity/publisher/redis_publisher.py
def __init__(self, config: dict[str, Any] | None = None):
    """Initialize the RedisPublisher.

    Args:
        config: Configuration dictionary. Supported keys:
            - url: Redis URL (default: "redis://localhost:6379/0").
            - mode: Publishing mode ('pubsub' or 'stream', default: 'pubsub').
            - channel: Pub/Sub channel name (default: "pyagenity.events").
            - stream: Stream name (default: "pyagenity.events").
            - maxlen: Max length for streams.
            - encoding: Encoding (default: "utf-8").
    """
    super().__init__(config or {})
    self.url: str = self.config.get("url", "redis://localhost:6379/0")
    self.mode: str = self.config.get("mode", "pubsub")
    self.channel: str = self.config.get("channel", "pyagenity.events")
    self.stream: str = self.config.get("stream", "pyagenity.events")
    self.maxlen: int | None = self.config.get("maxlen")
    self.encoding: str = self.config.get("encoding", "utf-8")

    # Lazy import & connect on first use to avoid ImportError at import-time.
    self._redis = None  # type: ignore[var-annotated]
close async
close()

Close the Redis client.

Source code in pyagenity/publisher/redis_publisher.py
async def close(self):
    """Close the Redis client."""
    if self._redis is not None:
        try:
            await self._redis.close()
            await self._redis.connection_pool.disconnect(inuse_connections=True)
        except Exception:  # best-effort close
            logger.debug("RedisPublisher close encountered an error", exc_info=True)
        finally:
            self._redis = None
publish async
publish(event)

Publish an event to Redis.

Parameters:

Name Type Description Default
event EventModel

The event to publish.

required

Returns:

Type Description
Any

The result of the publish operation.

Source code in pyagenity/publisher/redis_publisher.py
async def publish(self, event: EventModel) -> Any:
    """Publish an event to Redis.

    Args:
        event: The event to publish.

    Returns:
        The result of the publish operation.
    """
    client = await self._get_client()
    payload = json.dumps(event.model_dump()).encode(self.encoding)

    if self.mode == "stream":
        # XADD to stream
        fields = {"data": payload}
        if self.maxlen is not None:
            return await client.xadd(self.stream, fields, maxlen=self.maxlen, approximate=True)
        return await client.xadd(self.stream, fields)

    # Default: Pub/Sub channel
    return await client.publish(self.channel, payload)
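The publish path above branches on `mode`. The following is a minimal standalone sketch of how the configuration keys resolve against their defaults (plain dicts, no Redis connection; the config values are hypothetical):

```python
# Standalone sketch of RedisPublisher's config resolution (no Redis required).
# The config dict below is hypothetical; key names mirror __init__'s supported keys.
config = {"mode": "stream", "maxlen": 10_000}

url = config.get("url", "redis://localhost:6379/0")  # falls back to default
mode = config.get("mode", "pubsub")                  # explicitly "stream"
stream = config.get("stream", "pyagenity.events")    # falls back to default
maxlen = config.get("maxlen")                        # explicitly 10_000

# In stream mode the publisher calls XADD (with approximate maxlen trimming
# when maxlen is set); in the default pubsub mode it publishes to `channel`.
use_xadd = mode == "stream"
```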
sync_close
sync_close()

Synchronously close the Redis client.

Source code in pyagenity/publisher/redis_publisher.py
def sync_close(self):
    """Synchronously close the Redis client."""
    try:
        asyncio.run(self.close())
    except RuntimeError:
        # Already in an event loop; fall back to scheduling close
        logger.warning("sync_close called within an active event loop; skipping.")
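The RuntimeError fallback above exists because asyncio.run raises RuntimeError when invoked from inside an already-running event loop. A minimal sketch of the same pattern (the `_close` coroutine is a stand-in, not the real publisher's close):

```python
import asyncio

async def _close():
    # Stand-in for the publisher's async close(); just reports success here.
    return "closed"

def sync_close():
    """Run the async close from sync code; skip if a loop is already running."""
    try:
        return asyncio.run(_close())
    except RuntimeError:
        # asyncio.run() refuses to nest inside an active event loop,
        # which is why the publisher logs a warning and skips instead.
        return None
```

Called from ordinary synchronous code (no active loop), this runs the coroutine to completion.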

state

State management for PyAgenity agent graphs.

This package provides schemas and context managers for agent state, execution tracking, and message context management. All core state classes are exported for use in agent workflows and custom state extensions.

Modules:

Name Description
agent_state

Agent state schema for PyAgenity agent graphs.

base_context

Abstract base class for context management in PyAgenity agent graphs.

execution_state

Execution state management for graph execution in PyAgenity.

message_context_manager

Message context management for agent state in PyAgenity.

Classes:

Name Description
AgentState

Common state schema that includes messages, context and internal execution metadata.

BaseContextManager

Abstract base class for context management in AI interactions.

ExecutionState

Tracks the internal execution state of a graph.

ExecutionStatus

Status of graph execution.

MessageContextManager

Manages the context field for AI interactions.

Attributes

__all__ module-attribute
__all__ = ['AgentState', 'BaseContextManager', 'ExecutionState', 'ExecutionStatus', 'MessageContextManager']

Classes

AgentState

Bases: BaseModel

Common state schema that includes messages, context and internal execution metadata.

This class can be subclassed to add application-specific fields while maintaining compatibility with the PyAgenity framework. All internal execution metadata is preserved through subclassing.

Notes:
- execution_meta contains internal-only execution progress and interrupt info.
- Users may subclass AgentState to add application fields; internal exec meta remains available to the runtime and will be persisted with the state.
- When subclassing, add your fields but keep the core fields intact.

Example

class MyCustomState(AgentState):
    user_data: dict = Field(default_factory=dict)
    custom_field: str = "default"

Methods:

Name Description
advance_step

Advance the execution step in the metadata.

clear_interrupt

Clear any interrupt in the execution metadata.

complete

Mark the agent state as completed.

error

Mark the agent state as errored.

is_interrupted

Check if the agent state is currently interrupted.

is_running

Check if the agent state is currently running.

is_stopped_requested

Check if a stop has been requested for the agent state.

set_current_node

Set the current node in the execution metadata.

set_interrupt

Set an interrupt in the execution metadata.

Attributes:

Name Type Description
context Annotated[list[Message], add_messages]
context_summary str | None
execution_meta ExecutionState
Source code in pyagenity/state/agent_state.py
class AgentState(BaseModel):
    """Common state schema that includes messages, context and internal execution metadata.

    This class can be subclassed to add application-specific fields while maintaining
    compatibility with the PyAgenity framework. All internal execution metadata
    is preserved through subclassing.

    Notes:
    - `execution_meta` contains internal-only execution progress and interrupt info.
    - Users may subclass `AgentState` to add application fields; internal exec meta remains
      available to the runtime and will be persisted with the state.
    - When subclassing, add your fields but keep the core fields intact.

    Example:
        class MyCustomState(AgentState):
            user_data: dict = Field(default_factory=dict)
            custom_field: str = "default"
    """

    context: Annotated[list[Message], add_messages] = Field(default_factory=list)
    context_summary: str | None = None
    # Internal execution metadata (kept private-ish but accessible to runtime)
    execution_meta: ExecMeta = Field(default_factory=lambda: ExecMeta(current_node=START))

    # Convenience delegation methods for execution meta so callers can use the same API
    def set_interrupt(self, node: str, reason: str, status, data: dict | None = None) -> None:
        """
        Set an interrupt in the execution metadata.

        Args:
            node (str): Node where the interrupt occurred.
            reason (str): Reason for the interrupt.
            status: Execution status to set.
            data (dict | None): Optional additional interrupt data.
        """
        logger.debug("Setting interrupt at node '%s' with reason: %s", node, reason)
        self.execution_meta.set_interrupt(node, reason, status, data)

    def clear_interrupt(self) -> None:
        """
        Clear any interrupt in the execution metadata.
        """
        logger.debug("Clearing interrupt")
        self.execution_meta.clear_interrupt()

    def is_running(self) -> bool:
        """
        Check if the agent state is currently running.

        Returns:
            bool: True if running, False otherwise.
        """
        running = self.execution_meta.is_running()
        logger.debug("State is_running: %s", running)
        return running

    def is_interrupted(self) -> bool:
        """
        Check if the agent state is currently interrupted.

        Returns:
            bool: True if interrupted, False otherwise.
        """
        interrupted = self.execution_meta.is_interrupted()
        logger.debug("State is_interrupted: %s", interrupted)
        return interrupted

    def advance_step(self) -> None:
        """
        Advance the execution step in the metadata.
        """
        old_step = self.execution_meta.step
        self.execution_meta.advance_step()
        logger.debug("Advanced step from %d to %d", old_step, self.execution_meta.step)

    def set_current_node(self, node: str) -> None:
        """
        Set the current node in the execution metadata.

        Args:
            node (str): Node to set as current.
        """
        old_node = self.execution_meta.current_node
        self.execution_meta.set_current_node(node)
        logger.debug("Changed current node from '%s' to '%s'", old_node, node)

    def complete(self) -> None:
        """
        Mark the agent state as completed.
        """
        logger.info("Marking state as completed")
        self.execution_meta.complete()

    def error(self, error_msg: str) -> None:
        """
        Mark the agent state as errored.

        Args:
            error_msg (str): Error message to record.
        """
        logger.error("Setting state error: %s", error_msg)
        self.execution_meta.error(error_msg)

    def is_stopped_requested(self) -> bool:
        """
        Check if a stop has been requested for the agent state.

        Returns:
            bool: True if stop requested, False otherwise.
        """
        stopped = self.execution_meta.is_stopped_requested()
        logger.debug("State is_stopped_requested: %s", stopped)
        return stopped
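As the listing shows, AgentState's methods simply delegate to `execution_meta`. A standalone sketch of that delegation pattern using dataclasses (stand-in names; the real classes are pydantic models):

```python
from dataclasses import dataclass, field

@dataclass
class ExecMetaSketch:
    # Stand-in for ExecutionState: tracks execution progress only.
    current_node: str = "__start__"
    step: int = 0

    def advance_step(self):
        self.step += 1

@dataclass
class StateSketch:
    # Stand-in for AgentState: exposes the same method names,
    # delegating each call to the embedded metadata object.
    execution_meta: ExecMetaSketch = field(default_factory=ExecMetaSketch)

    def advance_step(self):
        self.execution_meta.advance_step()

state = StateSketch()
state.advance_step()
```

The delegation keeps a single public API on the state object while the runtime persists the metadata alongside any application fields a subclass adds.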
Attributes
context class-attribute instance-attribute
context = Field(default_factory=list)
context_summary class-attribute instance-attribute
context_summary = None
execution_meta class-attribute instance-attribute
execution_meta = Field(default_factory=lambda: ExecutionState(current_node=START))
Functions
advance_step
advance_step()

Advance the execution step in the metadata.

Source code in pyagenity/state/agent_state.py
def advance_step(self) -> None:
    """
    Advance the execution step in the metadata.
    """
    old_step = self.execution_meta.step
    self.execution_meta.advance_step()
    logger.debug("Advanced step from %d to %d", old_step, self.execution_meta.step)
clear_interrupt
clear_interrupt()

Clear any interrupt in the execution metadata.

Source code in pyagenity/state/agent_state.py
def clear_interrupt(self) -> None:
    """
    Clear any interrupt in the execution metadata.
    """
    logger.debug("Clearing interrupt")
    self.execution_meta.clear_interrupt()
complete
complete()

Mark the agent state as completed.

Source code in pyagenity/state/agent_state.py
def complete(self) -> None:
    """
    Mark the agent state as completed.
    """
    logger.info("Marking state as completed")
    self.execution_meta.complete()
error
error(error_msg)

Mark the agent state as errored.

Parameters:

Name Type Description Default
error_msg str

Error message to record.

required
Source code in pyagenity/state/agent_state.py
def error(self, error_msg: str) -> None:
    """
    Mark the agent state as errored.

    Args:
        error_msg (str): Error message to record.
    """
    logger.error("Setting state error: %s", error_msg)
    self.execution_meta.error(error_msg)
is_interrupted
is_interrupted()

Check if the agent state is currently interrupted.

Returns:

Name Type Description
bool bool

True if interrupted, False otherwise.

Source code in pyagenity/state/agent_state.py
def is_interrupted(self) -> bool:
    """
    Check if the agent state is currently interrupted.

    Returns:
        bool: True if interrupted, False otherwise.
    """
    interrupted = self.execution_meta.is_interrupted()
    logger.debug("State is_interrupted: %s", interrupted)
    return interrupted
is_running
is_running()

Check if the agent state is currently running.

Returns:

Name Type Description
bool bool

True if running, False otherwise.

Source code in pyagenity/state/agent_state.py
def is_running(self) -> bool:
    """
    Check if the agent state is currently running.

    Returns:
        bool: True if running, False otherwise.
    """
    running = self.execution_meta.is_running()
    logger.debug("State is_running: %s", running)
    return running
is_stopped_requested
is_stopped_requested()

Check if a stop has been requested for the agent state.

Returns:

Name Type Description
bool bool

True if stop requested, False otherwise.

Source code in pyagenity/state/agent_state.py
def is_stopped_requested(self) -> bool:
    """
    Check if a stop has been requested for the agent state.

    Returns:
        bool: True if stop requested, False otherwise.
    """
    stopped = self.execution_meta.is_stopped_requested()
    logger.debug("State is_stopped_requested: %s", stopped)
    return stopped
set_current_node
set_current_node(node)

Set the current node in the execution metadata.

Parameters:

Name Type Description Default
node str

Node to set as current.

required
Source code in pyagenity/state/agent_state.py
def set_current_node(self, node: str) -> None:
    """
    Set the current node in the execution metadata.

    Args:
        node (str): Node to set as current.
    """
    old_node = self.execution_meta.current_node
    self.execution_meta.set_current_node(node)
    logger.debug("Changed current node from '%s' to '%s'", old_node, node)
set_interrupt
set_interrupt(node, reason, status, data=None)

Set an interrupt in the execution metadata.

Parameters:

Name Type Description Default
node str

Node where the interrupt occurred.

required
reason str

Reason for the interrupt.

required
status

Execution status to set.

required
data dict | None

Optional additional interrupt data.

None
Source code in pyagenity/state/agent_state.py
def set_interrupt(self, node: str, reason: str, status, data: dict | None = None) -> None:
    """
    Set an interrupt in the execution metadata.

    Args:
        node (str): Node where the interrupt occurred.
        reason (str): Reason for the interrupt.
        status: Execution status to set.
        data (dict | None): Optional additional interrupt data.
    """
    logger.debug("Setting interrupt at node '%s' with reason: %s", node, reason)
    self.execution_meta.set_interrupt(node, reason, status, data)
BaseContextManager

Bases: ABC

Abstract base class for context management in AI interactions.

Subclasses should implement trim_context as either a synchronous or asynchronous method. Generic over AgentState or its subclasses.

Methods:

Name Description
atrim_context

Trim context based on message count asynchronously.

trim_context

Trim context based on message count. Can be sync or async.

Source code in pyagenity/state/base_context.py
class BaseContextManager[S](ABC):
    """
    Abstract base class for context management in AI interactions.

    Subclasses should implement `trim_context` as either a synchronous or asynchronous method.
    Generic over AgentState or its subclasses.
    """

    @abstractmethod
    def trim_context(self, state: S) -> S:
        """
        Trim context based on message count. Can be sync or async.

        Subclasses may implement as either a synchronous or asynchronous method.

        Args:
            state: The state containing context to be trimmed.

        Returns:
            The state with trimmed context, either directly or as an awaitable.
        """
        raise NotImplementedError("Subclasses must implement this method (sync or async)")

    @abstractmethod
    async def atrim_context(self, state: S) -> S:
        """
        Trim context based on message count asynchronously.

        Args:
            state: The state containing context to be trimmed.

        Returns:
            The state with trimmed context.
        """
        raise NotImplementedError("Subclasses must implement this method")
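A concrete subclass might trim to the last N messages while always preserving the leading system prompt. A hedged sketch under simplifying assumptions — `LastNContextManager` is illustrative (not a PyAgenity class), and state is modeled as a plain list of messages rather than an AgentState:

```python
from abc import ABC, abstractmethod

class ContextManagerSketch(ABC):
    # Minimal stand-in for BaseContextManager's sync interface.
    @abstractmethod
    def trim_context(self, messages: list) -> list: ...

class LastNContextManager(ContextManagerSketch):
    """Keep the first message (system prompt) plus the last n messages."""

    def __init__(self, n: int):
        self.n = n

    def trim_context(self, messages: list) -> list:
        if len(messages) <= self.n + 1:
            return messages  # nothing to trim
        return [messages[0]] + messages[-self.n:]

trimmed = LastNContextManager(2).trim_context(["sys", "u1", "a1", "u2", "a2"])
```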
Functions
atrim_context abstractmethod async
atrim_context(state)

Trim context based on message count asynchronously.

Parameters:

Name Type Description Default
state S

The state containing context to be trimmed.

required

Returns:

Type Description
S

The state with trimmed context.

Source code in pyagenity/state/base_context.py
@abstractmethod
async def atrim_context(self, state: S) -> S:
    """
    Trim context based on message count asynchronously.

    Args:
        state: The state containing context to be trimmed.

    Returns:
        The state with trimmed context.
    """
    raise NotImplementedError("Subclasses must implement this method")
trim_context abstractmethod
trim_context(state)

Trim context based on message count. Can be sync or async.

Subclasses may implement as either a synchronous or asynchronous method.

Parameters:

Name Type Description Default
state S

The state containing context to be trimmed.

required

Returns:

Type Description
S

The state with trimmed context, either directly or as an awaitable.

Source code in pyagenity/state/base_context.py
@abstractmethod
def trim_context(self, state: S) -> S:
    """
    Trim context based on message count. Can be sync or async.

    Subclasses may implement as either a synchronous or asynchronous method.

    Args:
        state: The state containing context to be trimmed.

    Returns:
        The state with trimmed context, either directly or as an awaitable.
    """
    raise NotImplementedError("Subclasses must implement this method (sync or async)")
ExecutionState

Bases: BaseModel

Tracks the internal execution state of a graph.

This class manages the execution progress, interrupt status, and internal data that should not be exposed to users.

Methods:

Name Description
advance_step

Advance to the next execution step.

clear_interrupt

Clear the interrupt state and resume execution.

complete

Mark execution as completed.

error

Mark execution as errored.

from_dict

Create an ExecutionState instance from a dictionary.

is_interrupted

Check if execution is currently interrupted.

is_running

Check if execution is currently running.

is_stopped_requested

Check if a stop has been requested for execution.

set_current_node

Update the current node in execution state.

set_interrupt

Set the interrupt state for execution.

Attributes:

Name Type Description
current_node str
internal_data dict[str, Any]
interrupt_data dict[str, Any] | None
interrupt_reason str | None
interrupted_node str | None
status ExecutionStatus
step int
stop_current_execution StopRequestStatus
thread_id str | None
Source code in pyagenity/state/execution_state.py
class ExecutionState(BaseModel):
    """
    Tracks the internal execution state of a graph.

    This class manages the execution progress, interrupt status, and internal
    data that should not be exposed to users.
    """

    # Core execution tracking
    current_node: str
    step: int = 0
    status: ExecutionStatus = ExecutionStatus.RUNNING

    # Interrupt management
    interrupted_node: str | None = None
    interrupt_reason: str | None = None
    interrupt_data: dict[str, Any] | None = None

    # Thread/session identification
    thread_id: str | None = None

    # Stop Current Execution Flag
    stop_current_execution: StopRequestStatus = StopRequestStatus.NONE

    # Internal execution data (hidden from user)
    internal_data: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ExecutionState":
        """
        Create an ExecutionState instance from a dictionary.

        Args:
            data (dict[str, Any]): Dictionary containing execution state fields.

        Returns:
            ExecutionState: The deserialized execution state object.
        """
        return cls.model_validate(
            {
                "current_node": data["current_node"],
                "step": data.get("step", 0),
                "status": ExecutionStatus(data.get("status", "running")),
                "interrupted_node": data.get("interrupted_node"),
                "interrupt_reason": data.get("interrupt_reason"),
                "interrupt_data": data.get("interrupt_data"),
                "thread_id": data.get("thread_id"),
                "internal_data": data.get("_internal_data", {}),
            }
        )

    def set_interrupt(
        self, node: str, reason: str, status: ExecutionStatus, data: dict[str, Any] | None = None
    ) -> None:
        """
        Set the interrupt state for execution.

        Args:
            node (str): Node where the interrupt occurred.
            reason (str): Reason for the interrupt.
            status (ExecutionStatus): Status to set for the interrupt.
            data (dict[str, Any] | None): Optional additional interrupt data.
        """
        logger.debug(
            "Setting interrupt: node='%s', reason='%s', status='%s'",
            node,
            reason,
            status.value,
        )
        self.interrupted_node = node
        self.interrupt_reason = reason
        self.status = status
        self.interrupt_data = data

    def clear_interrupt(self) -> None:
        """
        Clear the interrupt state and resume execution.
        """
        logger.debug("Clearing interrupt, resuming execution")
        self.interrupted_node = None
        self.interrupt_reason = None
        self.interrupt_data = None
        self.status = ExecutionStatus.RUNNING

    def is_interrupted(self) -> bool:
        """
        Check if execution is currently interrupted.

        Returns:
            bool: True if interrupted, False otherwise.
        """
        interrupted = self.status in [
            ExecutionStatus.INTERRUPTED_BEFORE,
            ExecutionStatus.INTERRUPTED_AFTER,
        ]
        logger.debug("Execution is_interrupted: %s (status: %s)", interrupted, self.status.value)
        return interrupted

    def advance_step(self) -> None:
        """
        Advance to the next execution step.
        """
        old_step = self.step
        self.step += 1
        logger.debug("Advanced step from %d to %d", old_step, self.step)

    def set_current_node(self, node: str) -> None:
        """
        Update the current node in execution state.

        Args:
            node (str): Node to set as current.
        """
        old_node = self.current_node
        self.current_node = node
        logger.debug("Changed current node from '%s' to '%s'", old_node, node)

    def complete(self) -> None:
        """
        Mark execution as completed.
        """
        logger.info("Marking execution as completed")
        self.status = ExecutionStatus.COMPLETED

    def error(self, error_msg: str) -> None:
        """
        Mark execution as errored.

        Args:
            error_msg (str): Error message to record.
        """
        logger.error("Marking execution as errored: %s", error_msg)
        self.status = ExecutionStatus.ERROR
        self.internal_data["error"] = error_msg

    def is_running(self) -> bool:
        """
        Check if execution is currently running.

        Returns:
            bool: True if running, False otherwise.
        """
        running = self.status == ExecutionStatus.RUNNING
        logger.debug("Execution is_running: %s (status: %s)", running, self.status.value)
        return running

    def is_stopped_requested(self) -> bool:
        """
        Check if a stop has been requested for execution.

        Returns:
            bool: True if stop requested, False otherwise.
        """
        stopped = self.stop_current_execution == StopRequestStatus.STOP_REQUESTED
        logger.debug(
            "Execution is_stopped_requested: %s (stop_current_execution: %s)",
            stopped,
            self.stop_current_execution.value,
        )
        return stopped
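Note in `from_dict` above that the persisted payload stores internal data under the `_internal_data` key while the field is named `internal_data`, and that every key except `current_node` gets a default. A standalone sketch of that mapping (plain dicts, no pydantic; the payload values are hypothetical):

```python
# Hypothetical persisted payload, shaped like what from_dict consumes.
data = {
    "current_node": "plan",
    "step": 3,
    "status": "interrupted_before",
    "_internal_data": {"last_error": None},
}

# The same key mapping and defaults that from_dict applies:
restored = {
    "current_node": data["current_node"],             # required, no default
    "step": data.get("step", 0),
    "status": data.get("status", "running"),
    "thread_id": data.get("thread_id"),               # defaults to None
    "internal_data": data.get("_internal_data", {}),  # note the underscore key
}
```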
Attributes
current_node instance-attribute
current_node
internal_data class-attribute instance-attribute
internal_data = Field(default_factory=dict)
interrupt_data class-attribute instance-attribute
interrupt_data = None
interrupt_reason class-attribute instance-attribute
interrupt_reason = None
interrupted_node class-attribute instance-attribute
interrupted_node = None
status class-attribute instance-attribute
status = RUNNING
step class-attribute instance-attribute
step = 0
stop_current_execution class-attribute instance-attribute
stop_current_execution = NONE
thread_id class-attribute instance-attribute
thread_id = None
Functions
advance_step
advance_step()

Advance to the next execution step.

Source code in pyagenity/state/execution_state.py
def advance_step(self) -> None:
    """
    Advance to the next execution step.
    """
    old_step = self.step
    self.step += 1
    logger.debug("Advanced step from %d to %d", old_step, self.step)
clear_interrupt
clear_interrupt()

Clear the interrupt state and resume execution.

Source code in pyagenity/state/execution_state.py
def clear_interrupt(self) -> None:
    """
    Clear the interrupt state and resume execution.
    """
    logger.debug("Clearing interrupt, resuming execution")
    self.interrupted_node = None
    self.interrupt_reason = None
    self.interrupt_data = None
    self.status = ExecutionStatus.RUNNING
complete
complete()

Mark execution as completed.

Source code in pyagenity/state/execution_state.py
def complete(self) -> None:
    """
    Mark execution as completed.
    """
    logger.info("Marking execution as completed")
    self.status = ExecutionStatus.COMPLETED
error
error(error_msg)

Mark execution as errored.

Parameters:

Name Type Description Default
error_msg str

Error message to record.

required
Source code in pyagenity/state/execution_state.py
def error(self, error_msg: str) -> None:
    """
    Mark execution as errored.

    Args:
        error_msg (str): Error message to record.
    """
    logger.error("Marking execution as errored: %s", error_msg)
    self.status = ExecutionStatus.ERROR
    self.internal_data["error"] = error_msg
from_dict classmethod
from_dict(data)

Create an ExecutionState instance from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary containing execution state fields.

required

Returns:

Name Type Description
ExecutionState ExecutionState

The deserialized execution state object.

Source code in pyagenity/state/execution_state.py
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ExecutionState":
    """
    Create an ExecutionState instance from a dictionary.

    Args:
        data (dict[str, Any]): Dictionary containing execution state fields.

    Returns:
        ExecutionState: The deserialized execution state object.
    """
    return cls.model_validate(
        {
            "current_node": data["current_node"],
            "step": data.get("step", 0),
            "status": ExecutionStatus(data.get("status", "running")),
            "interrupted_node": data.get("interrupted_node"),
            "interrupt_reason": data.get("interrupt_reason"),
            "interrupt_data": data.get("interrupt_data"),
            "thread_id": data.get("thread_id"),
            "internal_data": data.get("_internal_data", {}),
        }
    )
is_interrupted
is_interrupted()

Check if execution is currently interrupted.

Returns:

Name Type Description
bool bool

True if interrupted, False otherwise.

Source code in pyagenity/state/execution_state.py
def is_interrupted(self) -> bool:
    """
    Check if execution is currently interrupted.

    Returns:
        bool: True if interrupted, False otherwise.
    """
    interrupted = self.status in [
        ExecutionStatus.INTERRUPTED_BEFORE,
        ExecutionStatus.INTERRUPTED_AFTER,
    ]
    logger.debug("Execution is_interrupted: %s (status: %s)", interrupted, self.status.value)
    return interrupted
is_running
is_running()

Check if execution is currently running.

Returns:

Name Type Description
bool bool

True if running, False otherwise.

Source code in pyagenity/state/execution_state.py
def is_running(self) -> bool:
    """
    Check if execution is currently running.

    Returns:
        bool: True if running, False otherwise.
    """
    running = self.status == ExecutionStatus.RUNNING
    logger.debug("Execution is_running: %s (status: %s)", running, self.status.value)
    return running
is_stopped_requested
is_stopped_requested()

Check if a stop has been requested for execution.

Returns:

Name Type Description
bool bool

True if stop requested, False otherwise.

Source code in pyagenity/state/execution_state.py
def is_stopped_requested(self) -> bool:
    """
    Check if a stop has been requested for execution.

    Returns:
        bool: True if stop requested, False otherwise.
    """
    stopped = self.stop_current_execution == StopRequestStatus.STOP_REQUESTED
    logger.debug(
        "Execution is_stopped_requested: %s (stop_current_execution: %s)",
        stopped,
        self.stop_current_execution.value,
    )
    return stopped
set_current_node
set_current_node(node)

Update the current node in execution state.

Parameters:

Name Type Description Default
node str

Node to set as current.

required
Source code in pyagenity/state/execution_state.py
def set_current_node(self, node: str) -> None:
    """
    Update the current node in execution state.

    Args:
        node (str): Node to set as current.
    """
    old_node = self.current_node
    self.current_node = node
    logger.debug("Changed current node from '%s' to '%s'", old_node, node)
set_interrupt
set_interrupt(node, reason, status, data=None)

Set the interrupt state for execution.

Parameters:

Name Type Description Default
node str

Node where the interrupt occurred.

required
reason str

Reason for the interrupt.

required
status ExecutionStatus

Status to set for the interrupt.

required
data dict[str, Any] | None

Optional additional interrupt data.

None
Source code in pyagenity/state/execution_state.py
def set_interrupt(
    self, node: str, reason: str, status: ExecutionStatus, data: dict[str, Any] | None = None
) -> None:
    """
    Set the interrupt state for execution.

    Args:
        node (str): Node where the interrupt occurred.
        reason (str): Reason for the interrupt.
        status (ExecutionStatus): Status to set for the interrupt.
        data (dict[str, Any] | None): Optional additional interrupt data.
    """
    logger.debug(
        "Setting interrupt: node='%s', reason='%s', status='%s'",
        node,
        reason,
        status.value,
    )
    self.interrupted_node = node
    self.interrupt_reason = reason
    self.status = status
    self.interrupt_data = data
ExecutionStatus

Bases: Enum

Status of graph execution.

Attributes:

Name Type Description
COMPLETED
ERROR
INTERRUPTED_AFTER
INTERRUPTED_BEFORE
RUNNING
Source code in pyagenity/state/execution_state.py
class ExecutionStatus(Enum):
    """Status of graph execution."""

    RUNNING = "running"
    INTERRUPTED_BEFORE = "interrupted_before"
    INTERRUPTED_AFTER = "interrupted_after"
    COMPLETED = "completed"
    ERROR = "error"
Attributes
COMPLETED class-attribute instance-attribute
COMPLETED = 'completed'
ERROR class-attribute instance-attribute
ERROR = 'error'
INTERRUPTED_AFTER class-attribute instance-attribute
INTERRUPTED_AFTER = 'interrupted_after'
INTERRUPTED_BEFORE class-attribute instance-attribute
INTERRUPTED_BEFORE = 'interrupted_before'
RUNNING class-attribute instance-attribute
RUNNING = 'running'
MessageContextManager

Bases: BaseContextManager[S]

Manages the context field for AI interactions.

This class trims the context (message history) based on a maximum number of user messages, ensuring system messages (usually the system prompt) are always preserved. Generic over AgentState or its subclasses.

Methods:

Name Description
__init__

Initialize the MessageContextManager.

atrim_context

Asynchronous version of trim_context.

trim_context

Trim the context in the given AgentState based on the maximum number of user messages.

Attributes:

Name Type Description
max_messages
Source code in pyagenity/state/message_context_manager.py
class MessageContextManager(BaseContextManager[S]):
    """
    Manages the context field for AI interactions.

    This class trims the context (message history) based on a maximum number of user messages,
    ensuring system messages (usually the system prompt) are always preserved.
    Generic over AgentState or its subclasses.
    """

    def __init__(self, max_messages: int = 10) -> None:
        """
        Initialize the MessageContextManager.

        Args:
            max_messages (int): Maximum number of
                user messages to keep in context. Default is 10.
        """
        self.max_messages = max_messages
        logger.debug("Initialized MessageContextManager with max_messages=%d", max_messages)

    def _trim(self, messages: list[Message]) -> list[Message] | None:
        """
        Trim messages keeping system messages and most recent user messages.

        Returns None if no trimming is needed, otherwise returns the trimmed list.
        """
        # check context is empty
        if not messages:
            logger.debug("No messages to trim; context is empty")
            return None

        # Count user messages
        user_message_count = sum(1 for msg in messages if msg.role == "user")

        if user_message_count <= self.max_messages:
            # no trimming needed
            logger.debug(
                "No trimming needed; context is within limits (%d user messages)",
                user_message_count,
            )
            return None

        # Separate system messages (usually at the beginning)
        system_messages = [msg for msg in messages if msg.role == "system"]
        non_system_messages = [msg for msg in messages if msg.role != "system"]

        # Keep only the most recent messages that include max_messages user messages
        final_non_system = []
        user_count = 0

        # Iterate from the end to keep most recent messages
        for msg in reversed(non_system_messages):
            if msg.role == "user":
                if user_count >= self.max_messages:
                    break
                user_count += 1
            final_non_system.insert(0, msg)  # Insert at beginning to maintain order

        # Combine system messages (at start) with trimmed conversation
        trimmed_messages = system_messages + final_non_system

        logger.debug(
            "Trimmed from %d to %d messages (%d user messages kept)",
            len(messages),
            len(trimmed_messages),
            user_count,
        )

        return trimmed_messages

    def trim_context(self, state: S) -> S:
        """
        Trim the context in the given AgentState based on the maximum number of user messages.

        System messages (typically the system prompt) are always preserved. Only the most
        recent messages spanning up to `max_messages` user messages are kept.

        Args:
            state (AgentState): The agent state containing the context to trim.

        Returns:
            S: The updated agent state with trimmed context.
        """
        messages = state.context
        trimmed_messages = self._trim(messages)
        if trimmed_messages is not None:
            state.context = trimmed_messages
        return state

    async def atrim_context(self, state: S) -> S:
        """
        Asynchronous version of trim_context.

        Args:
            state (AgentState): The agent state containing the context to trim.

        Returns:
            S: The updated agent state with trimmed context.
        """
        messages = state.context
        trimmed_messages = self._trim(messages)
        if trimmed_messages is not None:
            state.context = trimmed_messages
        return state
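The `_trim` algorithm above can be sketched standalone. This uses a simplified `Msg` dataclass as a stand-in for pyagenity's `Message` class (only the `role` field matters for trimming); it is a sketch of the same idea, not the library's implementation:

```python
from dataclasses import dataclass


@dataclass
class Msg:
    # Minimal stand-in for pyagenity's Message
    role: str
    content: str


def trim(messages: list[Msg], max_messages: int) -> list[Msg]:
    """Keep all system messages plus the most recent run of messages
    containing up to max_messages user messages (same idea as _trim)."""
    if sum(1 for m in messages if m.role == "user") <= max_messages:
        return messages  # within limits, nothing to trim

    system = [m for m in messages if m.role == "system"]
    rest = [m for m in messages if m.role != "system"]

    kept: list[Msg] = []
    users = 0
    for m in reversed(rest):  # walk from newest to oldest
        if m.role == "user":
            if users >= max_messages:
                break
            users += 1
        kept.insert(0, m)  # insert at front to preserve order
    return system + kept


history = [Msg("system", "You are helpful.")]
for i in range(4):
    history += [Msg("user", f"q{i}"), Msg("assistant", f"a{i}")]

trimmed = trim(history, max_messages=2)
assert trimmed[0].role == "system"
assert sum(1 for m in trimmed if m.role == "user") == 2
assert trimmed[-1].content == "a3"
```

Note that assistant messages adjacent to a kept user message survive too, so question/answer pairs stay intact after trimming.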
Attributes
max_messages instance-attribute
max_messages = max_messages
Functions
__init__
__init__(max_messages=10)

Initialize the MessageContextManager.

Parameters:

Name Type Description Default
max_messages int

Maximum number of user messages to keep in context. Default is 10.

10
Source code in pyagenity/state/message_context_manager.py
def __init__(self, max_messages: int = 10) -> None:
    """
    Initialize the MessageContextManager.

    Args:
        max_messages (int): Maximum number of
            user messages to keep in context. Default is 10.
    """
    self.max_messages = max_messages
    logger.debug("Initialized MessageContextManager with max_messages=%d", max_messages)
atrim_context async
atrim_context(state)

Asynchronous version of trim_context.

Parameters:

Name Type Description Default
state AgentState

The agent state containing the context to trim.

required

Returns:

Name Type Description
S S

The updated agent state with trimmed context.

Source code in pyagenity/state/message_context_manager.py
async def atrim_context(self, state: S) -> S:
    """
    Asynchronous version of trim_context.

    Args:
        state (AgentState): The agent state containing the context to trim.

    Returns:
        S: The updated agent state with trimmed context.
    """
    messages = state.context
    trimmed_messages = self._trim(messages)
    if trimmed_messages is not None:
        state.context = trimmed_messages
    return state
trim_context
trim_context(state)

Trim the context in the given AgentState based on the maximum number of user messages.

System messages (typically the system prompt) are always preserved. Only the most recent messages spanning up to max_messages user messages are kept.

Parameters:

Name Type Description Default
state AgentState

The agent state containing the context to trim.

required

Returns:

Name Type Description
S S

The updated agent state with trimmed context.

Source code in pyagenity/state/message_context_manager.py
def trim_context(self, state: S) -> S:
    """
    Trim the context in the given AgentState based on the maximum number of user messages.

    System messages (typically the system prompt) are always preserved. Only the most
    recent messages spanning up to `max_messages` user messages are kept.

    Args:
        state (AgentState): The agent state containing the context to trim.

    Returns:
        S: The updated agent state with trimmed context.
    """
    messages = state.context
    trimmed_messages = self._trim(messages)
    if trimmed_messages is not None:
        state.context = trimmed_messages
    return state

Modules

agent_state

Agent state schema for PyAgenity agent graphs.

This module provides the AgentState class, which tracks message context, context summaries, and internal execution metadata for agent workflows. Supports subclassing for custom application fields.

Classes:

Name Description
AgentState

Common state schema that includes messages, context and internal execution metadata.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
AgentState

Bases: BaseModel

Common state schema that includes messages, context and internal execution metadata.

This class can be subclassed to add application-specific fields while maintaining compatibility with the PyAgenity framework. All internal execution metadata is preserved through subclassing.

Notes:
- execution_meta contains internal-only execution progress and interrupt info.
- Users may subclass AgentState to add application fields; internal exec meta remains available to the runtime and will be persisted with the state.
- When subclassing, add your fields but keep the core fields intact.

Example

class MyCustomState(AgentState):
    user_data: dict = Field(default_factory=dict)
    custom_field: str = "default"

Methods:

Name Description
advance_step

Advance the execution step in the metadata.

clear_interrupt

Clear any interrupt in the execution metadata.

complete

Mark the agent state as completed.

error

Mark the agent state as errored.

is_interrupted

Check if the agent state is currently interrupted.

is_running

Check if the agent state is currently running.

is_stopped_requested

Check if a stop has been requested for the agent state.

set_current_node

Set the current node in the execution metadata.

set_interrupt

Set an interrupt in the execution metadata.

Attributes:

Name Type Description
context Annotated[list[Message], add_messages]
context_summary str | None
execution_meta ExecutionState
Source code in pyagenity/state/agent_state.py
class AgentState(BaseModel):
    """Common state schema that includes messages, context and internal execution metadata.

    This class can be subclassed to add application-specific fields while maintaining
    compatibility with the PyAgenity framework. All internal execution metadata
    is preserved through subclassing.

    Notes:
    - `execution_meta` contains internal-only execution progress and interrupt info.
    - Users may subclass `AgentState` to add application fields; internal exec meta remains
      available to the runtime and will be persisted with the state.
    - When subclassing, add your fields but keep the core fields intact.

    Example:
        class MyCustomState(AgentState):
            user_data: dict = Field(default_factory=dict)
            custom_field: str = "default"
    """

    context: Annotated[list[Message], add_messages] = Field(default_factory=list)
    context_summary: str | None = None
    # Internal execution metadata (kept private-ish but accessible to runtime)
    execution_meta: ExecMeta = Field(default_factory=lambda: ExecMeta(current_node=START))

    # Convenience delegation methods for execution meta so callers can use the same API
    def set_interrupt(self, node: str, reason: str, status, data: dict | None = None) -> None:
        """
        Set an interrupt in the execution metadata.

        Args:
            node (str): Node where the interrupt occurred.
            reason (str): Reason for the interrupt.
            status: Execution status to set.
            data (dict | None): Optional additional interrupt data.
        """
        logger.debug("Setting interrupt at node '%s' with reason: %s", node, reason)
        self.execution_meta.set_interrupt(node, reason, status, data)

    def clear_interrupt(self) -> None:
        """
        Clear any interrupt in the execution metadata.
        """
        logger.debug("Clearing interrupt")
        self.execution_meta.clear_interrupt()

    def is_running(self) -> bool:
        """
        Check if the agent state is currently running.

        Returns:
            bool: True if running, False otherwise.
        """
        running = self.execution_meta.is_running()
        logger.debug("State is_running: %s", running)
        return running

    def is_interrupted(self) -> bool:
        """
        Check if the agent state is currently interrupted.

        Returns:
            bool: True if interrupted, False otherwise.
        """
        interrupted = self.execution_meta.is_interrupted()
        logger.debug("State is_interrupted: %s", interrupted)
        return interrupted

    def advance_step(self) -> None:
        """
        Advance the execution step in the metadata.
        """
        old_step = self.execution_meta.step
        self.execution_meta.advance_step()
        logger.debug("Advanced step from %d to %d", old_step, self.execution_meta.step)

    def set_current_node(self, node: str) -> None:
        """
        Set the current node in the execution metadata.

        Args:
            node (str): Node to set as current.
        """
        old_node = self.execution_meta.current_node
        self.execution_meta.set_current_node(node)
        logger.debug("Changed current node from '%s' to '%s'", old_node, node)

    def complete(self) -> None:
        """
        Mark the agent state as completed.
        """
        logger.info("Marking state as completed")
        self.execution_meta.complete()

    def error(self, error_msg: str) -> None:
        """
        Mark the agent state as errored.

        Args:
            error_msg (str): Error message to record.
        """
        logger.error("Setting state error: %s", error_msg)
        self.execution_meta.error(error_msg)

    def is_stopped_requested(self) -> bool:
        """
        Check if a stop has been requested for the agent state.

        Returns:
            bool: True if stop requested, False otherwise.
        """
        stopped = self.execution_meta.is_stopped_requested()
        logger.debug("State is_stopped_requested: %s", stopped)
        return stopped
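The convenience methods above all follow one pattern: `AgentState` forwards calls to its `execution_meta` object so callers use a single API. That delegation pattern can be sketched with stdlib stand-ins (the real `AgentState` is a pydantic `BaseModel`, and `"__start__"` here is just an illustrative value for the `START` sentinel):

```python
from dataclasses import dataclass, field


@dataclass
class Meta:
    # Stand-in for ExecutionState
    current_node: str
    step: int = 0

    def advance_step(self) -> None:
        self.step += 1

    def set_current_node(self, node: str) -> None:
        self.current_node = node


@dataclass
class State:
    # Stand-in for AgentState: convenience methods delegate to execution_meta,
    # so callers never touch the metadata object directly
    execution_meta: Meta = field(default_factory=lambda: Meta(current_node="__start__"))

    def advance_step(self) -> None:
        self.execution_meta.advance_step()

    def set_current_node(self, node: str) -> None:
        self.execution_meta.set_current_node(node)


s = State()
s.set_current_node("plan")
s.advance_step()
assert s.execution_meta.current_node == "plan"
assert s.execution_meta.step == 1
```

Because the default factory builds a fresh `Meta` per instance, subclasses that add their own fields inherit working execution metadata without any extra wiring, which is the compatibility guarantee the class docstring describes.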
Attributes
context class-attribute instance-attribute
context = Field(default_factory=list)
context_summary class-attribute instance-attribute
context_summary = None
execution_meta class-attribute instance-attribute
execution_meta = Field(default_factory=lambda: ExecutionState(current_node=START))
Functions
advance_step
advance_step()

Advance the execution step in the metadata.

Source code in pyagenity/state/agent_state.py
def advance_step(self) -> None:
    """
    Advance the execution step in the metadata.
    """
    old_step = self.execution_meta.step
    self.execution_meta.advance_step()
    logger.debug("Advanced step from %d to %d", old_step, self.execution_meta.step)
clear_interrupt
clear_interrupt()

Clear any interrupt in the execution metadata.

Source code in pyagenity/state/agent_state.py
def clear_interrupt(self) -> None:
    """
    Clear any interrupt in the execution metadata.
    """
    logger.debug("Clearing interrupt")
    self.execution_meta.clear_interrupt()
complete
complete()

Mark the agent state as completed.

Source code in pyagenity/state/agent_state.py
def complete(self) -> None:
    """
    Mark the agent state as completed.
    """
    logger.info("Marking state as completed")
    self.execution_meta.complete()
error
error(error_msg)

Mark the agent state as errored.

Parameters:

Name Type Description Default
error_msg str

Error message to record.

required
Source code in pyagenity/state/agent_state.py
def error(self, error_msg: str) -> None:
    """
    Mark the agent state as errored.

    Args:
        error_msg (str): Error message to record.
    """
    logger.error("Setting state error: %s", error_msg)
    self.execution_meta.error(error_msg)
is_interrupted
is_interrupted()

Check if the agent state is currently interrupted.

Returns:

Name Type Description
bool bool

True if interrupted, False otherwise.

Source code in pyagenity/state/agent_state.py
def is_interrupted(self) -> bool:
    """
    Check if the agent state is currently interrupted.

    Returns:
        bool: True if interrupted, False otherwise.
    """
    interrupted = self.execution_meta.is_interrupted()
    logger.debug("State is_interrupted: %s", interrupted)
    return interrupted
is_running
is_running()

Check if the agent state is currently running.

Returns:

Name Type Description
bool bool

True if running, False otherwise.

Source code in pyagenity/state/agent_state.py
def is_running(self) -> bool:
    """
    Check if the agent state is currently running.

    Returns:
        bool: True if running, False otherwise.
    """
    running = self.execution_meta.is_running()
    logger.debug("State is_running: %s", running)
    return running
is_stopped_requested
is_stopped_requested()

Check if a stop has been requested for the agent state.

Returns:

Name Type Description
bool bool

True if stop requested, False otherwise.

Source code in pyagenity/state/agent_state.py
def is_stopped_requested(self) -> bool:
    """
    Check if a stop has been requested for the agent state.

    Returns:
        bool: True if stop requested, False otherwise.
    """
    stopped = self.execution_meta.is_stopped_requested()
    logger.debug("State is_stopped_requested: %s", stopped)
    return stopped
set_current_node
set_current_node(node)

Set the current node in the execution metadata.

Parameters:

Name Type Description Default
node str

Node to set as current.

required
Source code in pyagenity/state/agent_state.py
def set_current_node(self, node: str) -> None:
    """
    Set the current node in the execution metadata.

    Args:
        node (str): Node to set as current.
    """
    old_node = self.execution_meta.current_node
    self.execution_meta.set_current_node(node)
    logger.debug("Changed current node from '%s' to '%s'", old_node, node)
set_interrupt
set_interrupt(node, reason, status, data=None)

Set an interrupt in the execution metadata.

Parameters:

Name Type Description Default
node str

Node where the interrupt occurred.

required
reason str

Reason for the interrupt.

required
status

Execution status to set.

required
data dict | None

Optional additional interrupt data.

None
Source code in pyagenity/state/agent_state.py
def set_interrupt(self, node: str, reason: str, status, data: dict | None = None) -> None:
    """
    Set an interrupt in the execution metadata.

    Args:
        node (str): Node where the interrupt occurred.
        reason (str): Reason for the interrupt.
        status: Execution status to set.
        data (dict | None): Optional additional interrupt data.
    """
    logger.debug("Setting interrupt at node '%s' with reason: %s", node, reason)
    self.execution_meta.set_interrupt(node, reason, status, data)
Functions
base_context

Abstract base class for context management in PyAgenity agent graphs.

This module provides BaseContextManager, which defines the interface for trimming and managing message context in agent state objects.

Classes:

Name Description
BaseContextManager

Abstract base class for context management in AI interactions.

Attributes:

Name Type Description
S
logger
Attributes
S module-attribute
S = TypeVar('S', bound=AgentState)
logger module-attribute
logger = getLogger(__name__)
Classes
BaseContextManager

Bases: ABC

Abstract base class for context management in AI interactions.

Subclasses should implement trim_context as either a synchronous or asynchronous method. Generic over AgentState or its subclasses.

Methods:

Name Description
atrim_context

Trim context based on message count asynchronously.

trim_context

Trim context based on message count. Can be sync or async.

Source code in pyagenity/state/base_context.py
class BaseContextManager[S](ABC):
    """
    Abstract base class for context management in AI interactions.

    Subclasses should implement `trim_context` as either a synchronous or asynchronous method.
    Generic over AgentState or its subclasses.
    """

    @abstractmethod
    def trim_context(self, state: S) -> S:
        """
        Trim context based on message count. Can be sync or async.

        Subclasses may implement as either a synchronous or asynchronous method.

        Args:
            state: The state containing context to be trimmed.

        Returns:
            The state with trimmed context, either directly or as an awaitable.
        """
        raise NotImplementedError("Subclasses must implement this method (sync or async)")

    @abstractmethod
    async def atrim_context(self, state: S) -> S:
        """
        Trim context based on message count asynchronously.

        Args:
            state: The state containing context to be trimmed.

        Returns:
            The state with trimmed context.
        """
        raise NotImplementedError("Subclasses must implement this method")
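A concrete subclass of this interface can be sketched as follows. This is a stdlib-only stand-in (the state here is a plain list rather than an `AgentState`, and `LastN` is a hypothetical strategy name), showing how a sync `trim_context` can back the async `atrim_context`:

```python
import asyncio
from abc import ABC, abstractmethod


class BaseCtx(ABC):
    # Stand-in for BaseContextManager, generic state simplified to a list
    @abstractmethod
    def trim_context(self, state: list) -> list: ...

    @abstractmethod
    async def atrim_context(self, state: list) -> list: ...


class LastN(BaseCtx):
    """Keeps only the last n items of the context."""

    def __init__(self, n: int) -> None:
        self.n = n

    def trim_context(self, state: list) -> list:
        return state[-self.n:]

    async def atrim_context(self, state: list) -> list:
        # Trimming is pure CPU work, so the async variant can
        # simply reuse the sync implementation
        return self.trim_context(state)


mgr = LastN(2)
assert mgr.trim_context([1, 2, 3, 4]) == [3, 4]
assert asyncio.run(mgr.atrim_context([1, 2, 3])) == [2, 3]
```

An async override only earns its keep when trimming needs I/O, e.g. summarizing dropped messages via an LLM call; otherwise delegating to the sync method, as `MessageContextManager` effectively does, is the simplest correct choice.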
Functions
atrim_context abstractmethod async
atrim_context(state)

Trim context based on message count asynchronously.

Parameters:

Name Type Description Default
state S

The state containing context to be trimmed.

required

Returns:

Type Description
S

The state with trimmed context.

Source code in pyagenity/state/base_context.py
@abstractmethod
async def atrim_context(self, state: S) -> S:
    """
    Trim context based on message count asynchronously.

    Args:
        state: The state containing context to be trimmed.

    Returns:
        The state with trimmed context.
    """
    raise NotImplementedError("Subclasses must implement this method")
trim_context abstractmethod
trim_context(state)

Trim context based on message count. Can be sync or async.

Subclasses may implement as either a synchronous or asynchronous method.

Parameters:

Name Type Description Default
state S

The state containing context to be trimmed.

required

Returns:

Type Description
S

The state with trimmed context, either directly or as an awaitable.

Source code in pyagenity/state/base_context.py
@abstractmethod
def trim_context(self, state: S) -> S:
    """
    Trim context based on message count. Can be sync or async.

    Subclasses may implement as either a synchronous or asynchronous method.

    Args:
        state: The state containing context to be trimmed.

    Returns:
        The state with trimmed context, either directly or as an awaitable.
    """
    raise NotImplementedError("Subclasses must implement this method (sync or async)")
execution_state

Execution state management for graph execution in PyAgenity.

This module provides the ExecutionState class and related enums to track progress, interruptions, and pause/resume functionality for agent graph execution.

Classes:

Name Description
ExecutionState

Tracks the internal execution state of a graph.

ExecutionStatus

Status of graph execution.

StopRequestStatus

Status of stop requests for graph execution.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
ExecutionState

Bases: BaseModel

Tracks the internal execution state of a graph.

This class manages the execution progress, interrupt status, and internal data that should not be exposed to users.

Methods:

Name Description
advance_step

Advance to the next execution step.

clear_interrupt

Clear the interrupt state and resume execution.

complete

Mark execution as completed.

error

Mark execution as errored.

from_dict

Create an ExecutionState instance from a dictionary.

is_interrupted

Check if execution is currently interrupted.

is_running

Check if execution is currently running.

is_stopped_requested

Check if a stop has been requested for execution.

set_current_node

Update the current node in execution state.

set_interrupt

Set the interrupt state for execution.

Attributes:

Name Type Description
current_node str
internal_data dict[str, Any]
interrupt_data dict[str, Any] | None
interrupt_reason str | None
interrupted_node str | None
status ExecutionStatus
step int
stop_current_execution StopRequestStatus
thread_id str | None
Source code in pyagenity/state/execution_state.py
class ExecutionState(BaseModel):
    """
    Tracks the internal execution state of a graph.

    This class manages the execution progress, interrupt status, and internal
    data that should not be exposed to users.
    """

    # Core execution tracking
    current_node: str
    step: int = 0
    status: ExecutionStatus = ExecutionStatus.RUNNING

    # Interrupt management
    interrupted_node: str | None = None
    interrupt_reason: str | None = None
    interrupt_data: dict[str, Any] | None = None

    # Thread/session identification
    thread_id: str | None = None

    # Stop Current Execution Flag
    stop_current_execution: StopRequestStatus = StopRequestStatus.NONE

    # Internal execution data (hidden from user)
    internal_data: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ExecutionState":
        """
        Create an ExecutionState instance from a dictionary.

        Args:
            data (dict[str, Any]): Dictionary containing execution state fields.

        Returns:
            ExecutionState: The deserialized execution state object.
        """
        return cls.model_validate(
            {
                "current_node": data["current_node"],
                "step": data.get("step", 0),
                "status": ExecutionStatus(data.get("status", "running")),
                "interrupted_node": data.get("interrupted_node"),
                "interrupt_reason": data.get("interrupt_reason"),
                "interrupt_data": data.get("interrupt_data"),
                "thread_id": data.get("thread_id"),
                "internal_data": data.get("_internal_data", {}),
            }
        )

    def set_interrupt(
        self, node: str, reason: str, status: ExecutionStatus, data: dict[str, Any] | None = None
    ) -> None:
        """
        Set the interrupt state for execution.

        Args:
            node (str): Node where the interrupt occurred.
            reason (str): Reason for the interrupt.
            status (ExecutionStatus): Status to set for the interrupt.
            data (dict[str, Any] | None): Optional additional interrupt data.
        """
        logger.debug(
            "Setting interrupt: node='%s', reason='%s', status='%s'",
            node,
            reason,
            status.value,
        )
        self.interrupted_node = node
        self.interrupt_reason = reason
        self.status = status
        self.interrupt_data = data

    def clear_interrupt(self) -> None:
        """
        Clear the interrupt state and resume execution.
        """
        logger.debug("Clearing interrupt, resuming execution")
        self.interrupted_node = None
        self.interrupt_reason = None
        self.interrupt_data = None
        self.status = ExecutionStatus.RUNNING

    def is_interrupted(self) -> bool:
        """
        Check if execution is currently interrupted.

        Returns:
            bool: True if interrupted, False otherwise.
        """
        interrupted = self.status in [
            ExecutionStatus.INTERRUPTED_BEFORE,
            ExecutionStatus.INTERRUPTED_AFTER,
        ]
        logger.debug("Execution is_interrupted: %s (status: %s)", interrupted, self.status.value)
        return interrupted

    def advance_step(self) -> None:
        """
        Advance to the next execution step.
        """
        old_step = self.step
        self.step += 1
        logger.debug("Advanced step from %d to %d", old_step, self.step)

    def set_current_node(self, node: str) -> None:
        """
        Update the current node in execution state.

        Args:
            node (str): Node to set as current.
        """
        old_node = self.current_node
        self.current_node = node
        logger.debug("Changed current node from '%s' to '%s'", old_node, node)

    def complete(self) -> None:
        """
        Mark execution as completed.
        """
        logger.info("Marking execution as completed")
        self.status = ExecutionStatus.COMPLETED

    def error(self, error_msg: str) -> None:
        """
        Mark execution as errored.

        Args:
            error_msg (str): Error message to record.
        """
        logger.error("Marking execution as errored: %s", error_msg)
        self.status = ExecutionStatus.ERROR
        self.internal_data["error"] = error_msg

    def is_running(self) -> bool:
        """
        Check if execution is currently running.

        Returns:
            bool: True if running, False otherwise.
        """
        running = self.status == ExecutionStatus.RUNNING
        logger.debug("Execution is_running: %s (status: %s)", running, self.status.value)
        return running

    def is_stopped_requested(self) -> bool:
        """
        Check if a stop has been requested for execution.

        Returns:
            bool: True if stop requested, False otherwise.
        """
        stopped = self.stop_current_execution == StopRequestStatus.STOP_REQUESTED
        logger.debug(
            "Execution is_stopped_requested: %s (stop_current_execution: %s)",
            stopped,
            self.stop_current_execution.value,
        )
        return stopped
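The defaulting behavior of `from_dict` above (one required key, everything else falling back to defaults, status strings coerced back into the enum) can be sketched without pydantic. This is a simplified stand-in; the real class delegates validation to `model_validate`:

```python
from enum import Enum


class Status(Enum):
    RUNNING = "running"
    COMPLETED = "completed"


class MiniState:
    # Stand-in for ExecutionState with a subset of its fields
    def __init__(self, current_node, step=0, status=Status.RUNNING):
        self.current_node = current_node
        self.step = step
        self.status = status

    @classmethod
    def from_dict(cls, data: dict) -> "MiniState":
        # current_node is required: a missing key raises KeyError,
        # mirroring data["current_node"] in the real from_dict
        return cls(
            current_node=data["current_node"],
            step=data.get("step", 0),
            status=Status(data.get("status", "running")),
        )


s = MiniState.from_dict({"current_node": "plan", "status": "completed"})
assert s.current_node == "plan"
assert s.step == 0
assert s.status is Status.COMPLETED
```

Converting the stored string back through the enum constructor means an unknown status value fails loudly with a `ValueError` at deserialization time instead of surfacing later as an invalid state.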
Attributes
current_node instance-attribute
current_node
internal_data class-attribute instance-attribute
internal_data = Field(default_factory=dict)
interrupt_data class-attribute instance-attribute
interrupt_data = None
interrupt_reason class-attribute instance-attribute
interrupt_reason = None
interrupted_node class-attribute instance-attribute
interrupted_node = None
status class-attribute instance-attribute
status = RUNNING
step class-attribute instance-attribute
step = 0
stop_current_execution class-attribute instance-attribute
stop_current_execution = NONE
thread_id class-attribute instance-attribute
thread_id = None
Functions
advance_step
advance_step()

Advance to the next execution step.

Source code in pyagenity/state/execution_state.py
def advance_step(self) -> None:
    """
    Advance to the next execution step.
    """
    old_step = self.step
    self.step += 1
    logger.debug("Advanced step from %d to %d", old_step, self.step)
clear_interrupt
clear_interrupt()

Clear the interrupt state and resume execution.

Source code in pyagenity/state/execution_state.py
def clear_interrupt(self) -> None:
    """
    Clear the interrupt state and resume execution.
    """
    logger.debug("Clearing interrupt, resuming execution")
    self.interrupted_node = None
    self.interrupt_reason = None
    self.interrupt_data = None
    self.status = ExecutionStatus.RUNNING
complete
complete()

Mark execution as completed.

Source code in pyagenity/state/execution_state.py
def complete(self) -> None:
    """
    Mark execution as completed.
    """
    logger.info("Marking execution as completed")
    self.status = ExecutionStatus.COMPLETED
error
error(error_msg)

Mark execution as errored.

Parameters:

Name Type Description Default
error_msg str

Error message to record.

required
Source code in pyagenity/state/execution_state.py
def error(self, error_msg: str) -> None:
    """
    Mark execution as errored.

    Args:
        error_msg (str): Error message to record.
    """
    logger.error("Marking execution as errored: %s", error_msg)
    self.status = ExecutionStatus.ERROR
    self.internal_data["error"] = error_msg
from_dict classmethod
from_dict(data)

Create an ExecutionState instance from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary containing execution state fields.

required

Returns:

Name Type Description
ExecutionState ExecutionState

The deserialized execution state object.

Source code in pyagenity/state/execution_state.py
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ExecutionState":
    """
    Create an ExecutionState instance from a dictionary.

    Args:
        data (dict[str, Any]): Dictionary containing execution state fields.

    Returns:
        ExecutionState: The deserialized execution state object.
    """
    return cls.model_validate(
        {
            "current_node": data["current_node"],
            "step": data.get("step", 0),
            "status": ExecutionStatus(data.get("status", "running")),
            "interrupted_node": data.get("interrupted_node"),
            "interrupt_reason": data.get("interrupt_reason"),
            "interrupt_data": data.get("interrupt_data"),
            "thread_id": data.get("thread_id"),
            "internal_data": data.get("_internal_data", {}),
        }
    )
is_interrupted
is_interrupted()

Check if execution is currently interrupted.

Returns:

Name Type Description
bool bool

True if interrupted, False otherwise.

Source code in pyagenity/state/execution_state.py
def is_interrupted(self) -> bool:
    """
    Check if execution is currently interrupted.

    Returns:
        bool: True if interrupted, False otherwise.
    """
    interrupted = self.status in [
        ExecutionStatus.INTERRUPTED_BEFORE,
        ExecutionStatus.INTERRUPTED_AFTER,
    ]
    logger.debug("Execution is_interrupted: %s (status: %s)", interrupted, self.status.value)
    return interrupted
is_running
is_running()

Check if execution is currently running.

Returns:

Name Type Description
bool bool

True if running, False otherwise.

Source code in pyagenity/state/execution_state.py
def is_running(self) -> bool:
    """
    Check if execution is currently running.

    Returns:
        bool: True if running, False otherwise.
    """
    running = self.status == ExecutionStatus.RUNNING
    logger.debug("Execution is_running: %s (status: %s)", running, self.status.value)
    return running
is_stopped_requested
is_stopped_requested()

Check if a stop has been requested for execution.

Returns:

Name Type Description
bool bool

True if stop requested, False otherwise.

Source code in pyagenity/state/execution_state.py
def is_stopped_requested(self) -> bool:
    """
    Check if a stop has been requested for execution.

    Returns:
        bool: True if stop requested, False otherwise.
    """
    stopped = self.stop_current_execution == StopRequestStatus.STOP_REQUESTED
    logger.debug(
        "Execution is_stopped_requested: %s (stop_current_execution: %s)",
        stopped,
        self.stop_current_execution.value,
    )
    return stopped
set_current_node
set_current_node(node)

Update the current node in execution state.

Parameters:

Name Type Description Default
node str

Node to set as current.

required
Source code in pyagenity/state/execution_state.py
def set_current_node(self, node: str) -> None:
    """
    Update the current node in execution state.

    Args:
        node (str): Node to set as current.
    """
    old_node = self.current_node
    self.current_node = node
    logger.debug("Changed current node from '%s' to '%s'", old_node, node)
set_interrupt
set_interrupt(node, reason, status, data=None)

Set the interrupt state for execution.

Parameters:

Name Type Description Default
node str

Node where the interrupt occurred.

required
reason str

Reason for the interrupt.

required
status ExecutionStatus

Status to set for the interrupt.

required
data dict[str, Any] | None

Optional additional interrupt data.

None
Source code in pyagenity/state/execution_state.py
def set_interrupt(
    self, node: str, reason: str, status: ExecutionStatus, data: dict[str, Any] | None = None
) -> None:
    """
    Set the interrupt state for execution.

    Args:
        node (str): Node where the interrupt occurred.
        reason (str): Reason for the interrupt.
        status (ExecutionStatus): Status to set for the interrupt.
        data (dict[str, Any] | None): Optional additional interrupt data.
    """
    logger.debug(
        "Setting interrupt: node='%s', reason='%s', status='%s'",
        node,
        reason,
        status.value,
    )
    self.interrupted_node = node
    self.interrupt_reason = reason
    self.status = status
    self.interrupt_data = data
ExecutionStatus

Bases: Enum

Status of graph execution.

Attributes:

Name Type Description
COMPLETED
ERROR
INTERRUPTED_AFTER
INTERRUPTED_BEFORE
RUNNING
Source code in pyagenity/state/execution_state.py
class ExecutionStatus(Enum):
    """Status of graph execution."""

    RUNNING = "running"
    INTERRUPTED_BEFORE = "interrupted_before"
    INTERRUPTED_AFTER = "interrupted_after"
    COMPLETED = "completed"
    ERROR = "error"
Attributes
COMPLETED class-attribute instance-attribute
COMPLETED = 'completed'
ERROR class-attribute instance-attribute
ERROR = 'error'
INTERRUPTED_AFTER class-attribute instance-attribute
INTERRUPTED_AFTER = 'interrupted_after'
INTERRUPTED_BEFORE class-attribute instance-attribute
INTERRUPTED_BEFORE = 'interrupted_before'
RUNNING class-attribute instance-attribute
RUNNING = 'running'
StopRequestStatus

Bases: Enum

Status of stop requests during graph execution.

Attributes:

Name Type Description
NONE
STOPPED
STOP_REQUESTED
Source code in pyagenity/state/execution_state.py
class StopRequestStatus(Enum):
    """Status of graph execution."""

    NONE = "none"
    STOP_REQUESTED = "stop_requested"
    STOPPED = "stopped"
Attributes
NONE class-attribute instance-attribute
NONE = 'none'
STOPPED class-attribute instance-attribute
STOPPED = 'stopped'
STOP_REQUESTED class-attribute instance-attribute
STOP_REQUESTED = 'stop_requested'
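Taken together, ExecutionState and the two enums above form a small lifecycle: running, interrupted before or after a node, then running again, with a parallel stop-request flag. The sketch below re-implements that lifecycle with the standard library only, so it runs without PyAgenity installed; `MiniExecutionState` is a hypothetical stand-in for the real Pydantic model, mirroring the methods listed above.

```python
from enum import Enum


class ExecutionStatus(Enum):
    """Mirrors pyagenity.state.ExecutionStatus."""

    RUNNING = "running"
    INTERRUPTED_BEFORE = "interrupted_before"
    INTERRUPTED_AFTER = "interrupted_after"
    COMPLETED = "completed"
    ERROR = "error"


class StopRequestStatus(Enum):
    """Mirrors pyagenity.state.StopRequestStatus."""

    NONE = "none"
    STOP_REQUESTED = "stop_requested"
    STOPPED = "stopped"


class MiniExecutionState:
    """Stripped-down stand-in for ExecutionState (illustrative only)."""

    def __init__(self, current_node: str) -> None:
        self.current_node = current_node
        self.step = 0
        self.status = ExecutionStatus.RUNNING
        self.stop_current_execution = StopRequestStatus.NONE
        self.interrupted_node = None
        self.interrupt_reason = None
        self.interrupt_data = None

    def set_interrupt(self, node, reason, status, data=None) -> None:
        self.interrupted_node = node
        self.interrupt_reason = reason
        self.status = status
        self.interrupt_data = data

    def is_interrupted(self) -> bool:
        return self.status in (
            ExecutionStatus.INTERRUPTED_BEFORE,
            ExecutionStatus.INTERRUPTED_AFTER,
        )

    def clear_interrupt(self) -> None:
        # Clearing an interrupt drops all interrupt fields and resumes.
        self.interrupted_node = None
        self.interrupt_reason = None
        self.interrupt_data = None
        self.status = ExecutionStatus.RUNNING


state = MiniExecutionState("plan")
state.set_interrupt("plan", "human_approval", ExecutionStatus.INTERRUPTED_BEFORE)
assert state.is_interrupted()
state.clear_interrupt()
assert state.status is ExecutionStatus.RUNNING
```

The same pattern applies to stop requests: setting `stop_current_execution` to `STOP_REQUESTED` makes `is_stopped_requested()` return True without changing `status`.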
message_context_manager

Message context management for agent state in PyAgenity.

This module provides MessageContextManager, which trims and manages the message history (context) for agent interactions, ensuring efficient context window usage.

Classes:

Name Description
MessageContextManager

Manages the context field for AI interactions.

Attributes:

Name Type Description
S
logger
Attributes
S module-attribute
S = TypeVar('S', bound=AgentState)
logger module-attribute
logger = getLogger(__name__)
Classes
MessageContextManager

Bases: BaseContextManager[S]

Manages the context field for AI interactions.

This class trims the context (message history) based on a maximum number of user messages, ensuring system messages (such as the system prompt) are always preserved. Generic over AgentState or its subclasses.

Methods:

Name Description
__init__

Initialize the MessageContextManager.

atrim_context

Asynchronous version of trim_context.

trim_context

Trim the context in the given AgentState based on the maximum number of user messages.

Attributes:

Name Type Description
max_messages
Source code in pyagenity/state/message_context_manager.py
class MessageContextManager(BaseContextManager[S]):
    """
    Manages the context field for AI interactions.

    This class trims the context (message history) based on a maximum number of user messages,
    ensuring system messages (such as the system prompt) are always preserved.
    Generic over AgentState or its subclasses.
    """

    def __init__(self, max_messages: int = 10) -> None:
        """
        Initialize the MessageContextManager.

        Args:
            max_messages (int): Maximum number of
                user messages to keep in context. Default is 10.
        """
        self.max_messages = max_messages
        logger.debug("Initialized MessageContextManager with max_messages=%d", max_messages)

    def _trim(self, messages: list[Message]) -> list[Message] | None:
        """
        Trim messages keeping system messages and most recent user messages.

        Returns None if no trimming is needed, otherwise returns the trimmed list.
        """
        # check context is empty
        if not messages:
            logger.debug("No messages to trim; context is empty")
            return None

        # Count user messages
        user_message_count = sum(1 for msg in messages if msg.role == "user")

        if user_message_count <= self.max_messages:
            # no trimming needed
            logger.debug(
                "No trimming needed; context is within limits (%d user messages)",
                user_message_count,
            )
            return None

        # Separate system messages (usually at the beginning)
        system_messages = [msg for msg in messages if msg.role == "system"]
        non_system_messages = [msg for msg in messages if msg.role != "system"]

        # Keep only the most recent messages that include max_messages user messages
        final_non_system = []
        user_count = 0

        # Iterate from the end to keep most recent messages
        for msg in reversed(non_system_messages):
            if msg.role == "user":
                if user_count >= self.max_messages:
                    break
                user_count += 1
            final_non_system.insert(0, msg)  # Insert at beginning to maintain order

        # Combine system messages (at start) with trimmed conversation
        trimmed_messages = system_messages + final_non_system

        logger.debug(
            "Trimmed from %d to %d messages (%d user messages kept)",
            len(messages),
            len(trimmed_messages),
            user_count,
        )

        return trimmed_messages

    def trim_context(self, state: S) -> S:
        """
        Trim the context in the given AgentState based on the maximum number of user messages.

        System messages (typically the system prompt) are always preserved. Only the most
        recent user messages up to `max_messages` are kept, along with system messages.

        Args:
            state (AgentState): The agent state containing the context to trim.

        Returns:
            S: The updated agent state with trimmed context.
        """
        messages = state.context
        trimmed_messages = self._trim(messages)
        if trimmed_messages is not None:
            state.context = trimmed_messages
        return state

    async def atrim_context(self, state: S) -> S:
        """
        Asynchronous version of trim_context.

        Args:
            state (AgentState): The agent state containing the context to trim.

        Returns:
            S: The updated agent state with trimmed context.
        """
        messages = state.context
        trimmed_messages = self._trim(messages)
        if trimmed_messages is not None:
            state.context = trimmed_messages
        return state
Attributes
max_messages instance-attribute
max_messages = max_messages
Functions
__init__
__init__(max_messages=10)

Initialize the MessageContextManager.

Parameters:

Name Type Description Default
max_messages int

Maximum number of user messages to keep in context. Default is 10.

10
Source code in pyagenity/state/message_context_manager.py
def __init__(self, max_messages: int = 10) -> None:
    """
    Initialize the MessageContextManager.

    Args:
        max_messages (int): Maximum number of
            user messages to keep in context. Default is 10.
    """
    self.max_messages = max_messages
    logger.debug("Initialized MessageContextManager with max_messages=%d", max_messages)
atrim_context async
atrim_context(state)

Asynchronous version of trim_context.

Parameters:

Name Type Description Default
state AgentState

The agent state containing the context to trim.

required

Returns:

Name Type Description
S S

The updated agent state with trimmed context.

Source code in pyagenity/state/message_context_manager.py
async def atrim_context(self, state: S) -> S:
    """
    Asynchronous version of trim_context.

    Args:
        state (AgentState): The agent state containing the context to trim.

    Returns:
        S: The updated agent state with trimmed context.
    """
    messages = state.context
    trimmed_messages = self._trim(messages)
    if trimmed_messages is not None:
        state.context = trimmed_messages
    return state
trim_context
trim_context(state)

Trim the context in the given AgentState based on the maximum number of user messages.

System messages (typically the system prompt) are always preserved. Only the most recent user messages up to max_messages are kept, along with system messages.

Parameters:

Name Type Description Default
state AgentState

The agent state containing the context to trim.

required

Returns:

Name Type Description
S S

The updated agent state with trimmed context.

Source code in pyagenity/state/message_context_manager.py
def trim_context(self, state: S) -> S:
    """
    Trim the context in the given AgentState based on the maximum number of user messages.

    System messages (typically the system prompt) are always preserved. Only the most
    recent user messages up to `max_messages` are kept, along with system messages.

    Args:
        state (AgentState): The agent state containing the context to trim.

    Returns:
        S: The updated agent state with trimmed context.
    """
    messages = state.context
    trimmed_messages = self._trim(messages)
    if trimmed_messages is not None:
        state.context = trimmed_messages
    return state

store

Modules:

Name Description
base_store

Simplified Async-First Base Store for PyAgenity Framework

embedding
mem0_store

Mem0 Long-Term Memory Store

qdrant_store

Qdrant Vector Store Implementation for PyAgenity Framework

store_schema

Classes:

Name Description
BaseEmbedding
BaseStore

Simplified async-first base class for memory stores in PyAgenity.

DistanceMetric

Supported distance metrics for vector similarity.

MemoryRecord

Comprehensive memory record for storage (Pydantic model).

MemorySearchResult

Result from a memory search operation (Pydantic model).

MemoryType

Types of memories that can be stored.

OpenAIEmbedding

Attributes

__all__ module-attribute
__all__ = ['BaseEmbedding', 'BaseStore', 'DistanceMetric', 'MemoryRecord', 'MemorySearchResult', 'MemoryType', 'OpenAIEmbedding']

Classes

BaseEmbedding

Bases: ABC

Methods:

Name Description
aembed

Generate embedding for a single text.

aembed_batch

Generate embeddings for a list of texts.

embed

Synchronous wrapper for aembed that runs the async implementation.

embed_batch

Synchronous wrapper for aembed_batch that runs the async implementation.

Attributes:

Name Type Description
dimension int

Dimensionality of the embedding vectors produced by this model.

Source code in pyagenity/store/embedding/base_embedding.py
class BaseEmbedding(ABC):
    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Synchronous wrapper for `aembed_batch` that runs the async implementation."""
        return run_coroutine(self.aembed_batch(texts))

    @abstractmethod
    async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for a list of texts."""
        # pragma: no cover

    def embed(self, text: str) -> list[float]:
        """Synchronous wrapper for `aembed` that runs the async implementation."""
        return run_coroutine(self.aembed(text))

    @abstractmethod
    async def aembed(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        raise NotImplementedError

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Synchronous wrapper for that runs the async implementation."""
        raise NotImplementedError
Attributes
dimension abstractmethod property
dimension

Dimensionality of the embedding vectors produced by this model.

Functions
aembed abstractmethod async
aembed(text)

Generate embedding for a single text.

Source code in pyagenity/store/embedding/base_embedding.py
@abstractmethod
async def aembed(self, text: str) -> list[float]:
    """Generate embedding for a single text."""
    raise NotImplementedError
aembed_batch abstractmethod async
aembed_batch(texts)

Generate embeddings for a list of texts.

Source code in pyagenity/store/embedding/base_embedding.py
@abstractmethod
async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
    """Generate embeddings for a list of texts."""
embed
embed(text)

Synchronous wrapper for aembed that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed(self, text: str) -> list[float]:
    """Synchronous wrapper for `aembed` that runs the async implementation."""
    return run_coroutine(self.aembed(text))
embed_batch
embed_batch(texts)

Synchronous wrapper for aembed_batch that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed_batch(self, texts: list[str]) -> list[list[float]]:
    """Synchronous wrapper for `aembed_batch` that runs the async implementation."""
    return run_coroutine(self.aembed_batch(texts))
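A concrete subclass only needs the two async methods and the `dimension` property. The sketch below mirrors the interface with the standard library (`asyncio.run` stands in for PyAgenity's `run_coroutine` helper) and plugs in a deterministic hash-based toy embedding; `HashEmbedding` is a hypothetical class for illustration, not a semantic model.

```python
import asyncio
import hashlib
from abc import ABC, abstractmethod


class BaseEmbedding(ABC):
    """Standalone mirror of the interface above; the real base class
    uses pyagenity's run_coroutine for its sync wrappers."""

    def embed(self, text: str) -> list[float]:
        return asyncio.run(self.aembed(text))

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        return asyncio.run(self.aembed_batch(texts))

    @abstractmethod
    async def aembed(self, text: str) -> list[float]: ...

    @abstractmethod
    async def aembed_batch(self, texts: list[str]) -> list[list[float]]: ...

    @property
    @abstractmethod
    def dimension(self) -> int: ...


class HashEmbedding(BaseEmbedding):
    """Deterministic toy embedding for tests -- not semantically meaningful."""

    def __init__(self, dim: int = 8) -> None:
        self._dim = dim

    async def aembed(self, text: str) -> list[float]:
        digest = hashlib.sha256(text.encode()).digest()
        return [b / 255.0 for b in digest[: self._dim]]

    async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
        return [await self.aembed(t) for t in texts]

    @property
    def dimension(self) -> int:
        return self._dim


emb = HashEmbedding()
vec = emb.embed("hello")
assert len(vec) == emb.dimension == 8
assert emb.embed("hello") == vec  # deterministic
```

A real implementation (such as OpenAIEmbedding) would call the provider's API inside `aembed`/`aembed_batch` and report the model's native vector dimension.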
BaseStore

Bases: ABC

Simplified async-first base class for memory stores in PyAgenity.

This class provides a clean interface that supports:

- Vector stores (Qdrant, Pinecone, Chroma, etc.)
- Managed memory services (mem0, Zep, etc.)
- Graph databases (Neo4j, etc.)

Key Design Principles:

- Async-first for better performance
- Core CRUD operations only
- User and agent-centric operations
- Extensible filtering and metadata

Methods:

Name Description
adelete

Delete a memory by ID.

aforget_memory

Delete memories for a user or agent.

aget

Get a specific memory by ID.

aget_all

Get all memories for a user.

arelease

Clean up any resources used by the store (override in subclasses if needed).

asearch

Search memories by content similarity.

asetup

Asynchronous setup method for the store.

astore

Add a new memory.

aupdate

Update an existing memory.

delete

Synchronous wrapper for adelete that runs the async implementation.

forget_memory

Delete memories for a user or agent.

get

Synchronous wrapper for aget that runs the async implementation.

get_all

Synchronous wrapper for aget_all that runs the async implementation.

release

Clean up any resources used by the store (override in subclasses if needed).

search

Synchronous wrapper for asearch that runs the async implementation.

setup

Synchronous setup method for the store.

store

Synchronous wrapper for astore that runs the async implementation.

update

Synchronous wrapper for aupdate that runs the async implementation.

Source code in pyagenity/store/base_store.py
class BaseStore(ABC):
    """
    Simplified async-first base class for memory stores in PyAgenity.

    This class provides a clean interface that supports:
    - Vector stores (Qdrant, Pinecone, Chroma, etc.)
    - Managed memory services (mem0, Zep, etc.)
    - Graph databases (Neo4j, etc.)

    Key Design Principles:
    - Async-first for better performance
    - Core CRUD operations only
    - User and agent-centric operations
    - Extensible filtering and metadata
    """

    def setup(self) -> Any:
        """
        Synchronous setup method for the store.

        Returns:
            Any: Implementation-defined setup result.
        """
        return run_coroutine(self.asetup())

    async def asetup(self) -> Any:
        """
        Asynchronous setup method for the store.

        Returns:
            Any: Implementation-defined setup result.
        """
        raise NotImplementedError

    # --- Core Memory Operations ---

    @abstractmethod
    async def astore(
        self,
        config: dict[str, Any],
        content: str | Message,
        memory_type: MemoryType = MemoryType.EPISODIC,
        category: str = "general",
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> str:
        """
        Add a new memory.

        Args:
            config: Runtime configuration (typically includes user/agent identifiers)
            content: The memory content
            memory_type: Type of memory (episodic, semantic, etc.)
            category: Memory category for organization
            metadata: Additional metadata
            **kwargs: Store-specific parameters

        Returns:
            Memory ID
        """
        raise NotImplementedError

    # --- Sync wrappers ---
    def store(
        self,
        config: dict[str, Any],
        content: str | Message,
        memory_type: MemoryType = MemoryType.EPISODIC,
        category: str = "general",
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> str:
        """Synchronous wrapper for `astore` that runs the async implementation."""
        return run_coroutine(
            self.astore(
                config,
                content,
                memory_type=memory_type,
                category=category,
                metadata=metadata,
                **kwargs,
            )
        )

    @abstractmethod
    async def asearch(
        self,
        config: dict[str, Any],
        query: str,
        memory_type: MemoryType | None = None,
        category: str | None = None,
        limit: int = 10,
        score_threshold: float | None = None,
        filters: dict[str, Any] | None = None,
        retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
        distance_metric: DistanceMetric = DistanceMetric.COSINE,
        max_tokens: int = 4000,
        **kwargs,
    ) -> list[MemorySearchResult]:
        """
        Search memories by content similarity.

        Args:
            config: Runtime configuration (typically includes user/agent identifiers)
            query: Search query
            memory_type: Filter by memory type
            category: Filter by category
            limit: Maximum results
            score_threshold: Minimum similarity score
            filters: Additional filters
            **kwargs: Store-specific parameters

        Returns:
            List of matching memories
        """
        raise NotImplementedError

    def search(
        self,
        config: dict[str, Any],
        query: str,
        memory_type: MemoryType | None = None,
        category: str | None = None,
        limit: int = 10,
        score_threshold: float | None = None,
        filters: dict[str, Any] | None = None,
        retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
        distance_metric: DistanceMetric = DistanceMetric.COSINE,
        max_tokens: int = 4000,
        **kwargs,
    ) -> list[MemorySearchResult]:
        """Synchronous wrapper for `asearch` that runs the async implementation."""
        return run_coroutine(
            self.asearch(
                config,
                query,
                memory_type=memory_type,
                category=category,
                limit=limit,
                score_threshold=score_threshold,
                filters=filters,
                retrieval_strategy=retrieval_strategy,
                distance_metric=distance_metric,
                max_tokens=max_tokens,
                **kwargs,
            )
        )

    @abstractmethod
    async def aget(
        self,
        config: dict[str, Any],
        memory_id: str,
        **kwargs,
    ) -> MemorySearchResult | None:
        """Get a specific memory by ID."""
        raise NotImplementedError

    @abstractmethod
    async def aget_all(
        self,
        config: dict[str, Any],
        limit: int = 100,
        **kwargs,
    ) -> list[MemorySearchResult]:
        """Get a specific memory by user_id."""
        raise NotImplementedError

    def get(self, config: dict[str, Any], memory_id: str, **kwargs) -> MemorySearchResult | None:
        """Synchronous wrapper for `aget` that runs the async implementation."""
        return run_coroutine(self.aget(config, memory_id, **kwargs))

    def get_all(
        self,
        config: dict[str, Any],
        limit: int = 100,
        **kwargs,
    ) -> list[MemorySearchResult]:
        """Synchronous wrapper for `aget` that runs the async implementation."""
        return run_coroutine(self.aget_all(config, limit=limit, **kwargs))

    @abstractmethod
    async def aupdate(
        self,
        config: dict[str, Any],
        memory_id: str,
        content: str | Message,
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> Any:
        """
        Update an existing memory.

        Args:
            memory_id: ID of memory to update
            content: New content (optional)
            metadata: New/additional metadata (optional)
            **kwargs: Store-specific parameters
        """
        raise NotImplementedError

    def update(
        self,
        config: dict[str, Any],
        memory_id: str,
        content: str | Message,
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> Any:
        """Synchronous wrapper for `aupdate` that runs the async implementation."""
        return run_coroutine(self.aupdate(config, memory_id, content, metadata=metadata, **kwargs))

    @abstractmethod
    async def adelete(
        self,
        config: dict[str, Any],
        memory_id: str,
        **kwargs,
    ) -> Any:
        """Delete a memory by ID."""
        raise NotImplementedError

    def delete(self, config: dict[str, Any], memory_id: str, **kwargs) -> None:
        """Synchronous wrapper for `adelete` that runs the async implementation."""
        return run_coroutine(self.adelete(config, memory_id, **kwargs))

    @abstractmethod
    async def aforget_memory(
        self,
        config: dict[str, Any],
        **kwargs,
    ) -> Any:
        """Delete a memory by for a user or agent."""
        raise NotImplementedError

    def forget_memory(
        self,
        config: dict[str, Any],
        **kwargs,
    ) -> Any:
        """Delete a memory by for a user or agent."""
        return run_coroutine(self.aforget_memory(config, **kwargs))

    # --- Cleanup and Resource Management ---

    async def arelease(self) -> None:
        """Clean up any resources used by the store (override in subclasses if needed)."""
        raise NotImplementedError

    def release(self) -> None:
        """Clean up any resources used by the store (override in subclasses if needed)."""
        return run_coroutine(self.arelease())
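The CRUD surface above is easiest to see with a tiny concrete store. The sketch below implements a toy in-memory subset (astore/asearch/aget/adelete plus two sync wrappers); it is illustrative only: `InMemoryStore` is hypothetical, naive substring matching stands in for vector similarity, `asyncio.run` stands in for PyAgenity's `run_coroutine`, and plain dicts stand in for MemorySearchResult. A real subclass would also implement aget_all, aupdate, aforget_memory, and arelease.

```python
import asyncio
from itertools import count
from typing import Any


class InMemoryStore:
    """Toy subset of the BaseStore surface, for illustration only."""

    def __init__(self) -> None:
        self._memories: dict[str, dict[str, Any]] = {}
        self._ids = count(1)

    async def astore(self, config: dict[str, Any], content: str, **kwargs) -> str:
        memory_id = str(next(self._ids))
        self._memories[memory_id] = {
            "content": content,
            "user_id": config.get("user_id"),
        }
        return memory_id

    async def asearch(
        self, config: dict[str, Any], query: str, limit: int = 10, **kwargs
    ) -> list[dict[str, Any]]:
        # Naive case-insensitive substring match stands in for vector similarity.
        hits = [
            {"id": mid, **rec}
            for mid, rec in self._memories.items()
            if query.lower() in rec["content"].lower()
            and rec["user_id"] == config.get("user_id")
        ]
        return hits[:limit]

    async def aget(self, config, memory_id: str, **kwargs):
        return self._memories.get(memory_id)

    async def adelete(self, config, memory_id: str, **kwargs) -> None:
        self._memories.pop(memory_id, None)

    # Sync wrappers mirror the base-class pattern.
    def store(self, config, content, **kwargs) -> str:
        return asyncio.run(self.astore(config, content, **kwargs))

    def search(self, config, query, **kwargs):
        return asyncio.run(self.asearch(config, query, **kwargs))


cfg = {"user_id": "u1"}
s = InMemoryStore()
mid = s.store(cfg, "Alice prefers dark mode")
assert s.search(cfg, "dark mode")[0]["id"] == mid
assert s.search({"user_id": "u2"}, "dark mode") == []  # scoped to the user
```

Passing identifiers through `config` keeps every operation user- and agent-scoped, which is the pattern the real interface follows.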
Functions
adelete abstractmethod async
adelete(config, memory_id, **kwargs)

Delete a memory by ID.

Source code in pyagenity/store/base_store.py
@abstractmethod
async def adelete(
    self,
    config: dict[str, Any],
    memory_id: str,
    **kwargs,
) -> Any:
    """Delete a memory by ID."""
    raise NotImplementedError
aforget_memory abstractmethod async
aforget_memory(config, **kwargs)

Delete memories for a user or agent.

Source code in pyagenity/store/base_store.py
@abstractmethod
async def aforget_memory(
    self,
    config: dict[str, Any],
    **kwargs,
) -> Any:
    """Delete a memory by for a user or agent."""
    raise NotImplementedError
aget abstractmethod async
aget(config, memory_id, **kwargs)

Get a specific memory by ID.

Source code in pyagenity/store/base_store.py
@abstractmethod
async def aget(
    self,
    config: dict[str, Any],
    memory_id: str,
    **kwargs,
) -> MemorySearchResult | None:
    """Get a specific memory by ID."""
    raise NotImplementedError
aget_all abstractmethod async
aget_all(config, limit=100, **kwargs)

Get all memories for a user.

Source code in pyagenity/store/base_store.py
@abstractmethod
async def aget_all(
    self,
    config: dict[str, Any],
    limit: int = 100,
    **kwargs,
) -> list[MemorySearchResult]:
    """Get a specific memory by user_id."""
    raise NotImplementedError
arelease async
arelease()

Clean up any resources used by the store (override in subclasses if needed).

Source code in pyagenity/store/base_store.py
async def arelease(self) -> None:
    """Clean up any resources used by the store (override in subclasses if needed)."""
    raise NotImplementedError
asearch abstractmethod async
asearch(config, query, memory_type=None, category=None, limit=10, score_threshold=None, filters=None, retrieval_strategy=RetrievalStrategy.SIMILARITY, distance_metric=DistanceMetric.COSINE, max_tokens=4000, **kwargs)

Search memories by content similarity.

Parameters:

Name Type Description Default
config dict[str, Any]

Store/runtime configuration (e.g. user and thread identifiers)

required
query str

Search query

required
memory_type MemoryType | None

Filter by memory type

None
category str | None

Filter by category

None
limit int

Maximum results

10
score_threshold float | None

Minimum similarity score

None
filters dict[str, Any] | None

Additional filters

None
**kwargs

Store-specific parameters

{}

Returns:

Type Description
list[MemorySearchResult]

List of matching memories

Source code in pyagenity/store/base_store.py
@abstractmethod
async def asearch(
    self,
    config: dict[str, Any],
    query: str,
    memory_type: MemoryType | None = None,
    category: str | None = None,
    limit: int = 10,
    score_threshold: float | None = None,
    filters: dict[str, Any] | None = None,
    retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
    distance_metric: DistanceMetric = DistanceMetric.COSINE,
    max_tokens: int = 4000,
    **kwargs,
) -> list[MemorySearchResult]:
    """
    Search memories by content similarity.

    Args:
        config: Store/runtime configuration (e.g. user and thread identifiers)
        query: Search query
        memory_type: Filter by memory type
        category: Filter by category
        limit: Maximum results
        score_threshold: Minimum similarity score
        filters: Additional filters
        **kwargs: Store-specific parameters

    Returns:
        List of matching memories
    """
    raise NotImplementedError
asetup async
asetup()

Asynchronous setup method for the store.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/store/base_store.py
async def asetup(self) -> Any:
    """
    Asynchronous setup method for the store.

    Returns:
        Any: Implementation-defined setup result.
    """
    raise NotImplementedError
astore abstractmethod async
astore(config, content, memory_type=MemoryType.EPISODIC, category='general', metadata=None, **kwargs)

Add a new memory.

Parameters:

Name Type Description Default
config dict[str, Any]

Store/runtime configuration (e.g. user and thread identifiers)

required
content str | Message

The memory content

required
memory_type MemoryType

Type of memory (episodic, semantic, etc.)

EPISODIC
category str

Memory category for organization

'general'
metadata dict[str, Any] | None

Additional metadata

None
**kwargs

Store-specific parameters

{}

Returns:

Type Description
str

Memory ID

Source code in pyagenity/store/base_store.py
@abstractmethod
async def astore(
    self,
    config: dict[str, Any],
    content: str | Message,
    memory_type: MemoryType = MemoryType.EPISODIC,
    category: str = "general",
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> str:
    """
    Add a new memory.

    Args:
        config: Store/runtime configuration (e.g. user and thread identifiers)
        content: The memory content
        memory_type: Type of memory (episodic, semantic, etc.)
        category: Memory category for organization
        metadata: Additional metadata
        **kwargs: Store-specific parameters

    Returns:
        Memory ID
    """
    raise NotImplementedError
aupdate abstractmethod async
aupdate(config, memory_id, content, metadata=None, **kwargs)

Update an existing memory.

Parameters:

Name Type Description Default
memory_id str

ID of memory to update

required
content str | Message

New content

required
metadata dict[str, Any] | None

New/additional metadata (optional)

None
**kwargs

Store-specific parameters

{}
Source code in pyagenity/store/base_store.py
@abstractmethod
async def aupdate(
    self,
    config: dict[str, Any],
    memory_id: str,
    content: str | Message,
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> Any:
    """
    Update an existing memory.

    Args:
        memory_id: ID of memory to update
        content: New content
        metadata: New/additional metadata (optional)
        **kwargs: Store-specific parameters
    """
    raise NotImplementedError
delete
delete(config, memory_id, **kwargs)

Synchronous wrapper for adelete that runs the async implementation.

Source code in pyagenity/store/base_store.py
def delete(self, config: dict[str, Any], memory_id: str, **kwargs) -> None:
    """Synchronous wrapper for `adelete` that runs the async implementation."""
    return run_coroutine(self.adelete(config, memory_id, **kwargs))
forget_memory
forget_memory(config, **kwargs)

Delete all memories for a user or agent.

Source code in pyagenity/store/base_store.py
def forget_memory(
    self,
    config: dict[str, Any],
    **kwargs,
) -> Any:
    """Delete all memories for a user or agent."""
    return run_coroutine(self.aforget_memory(config, **kwargs))
get
get(config, memory_id, **kwargs)

Synchronous wrapper for aget that runs the async implementation.

Source code in pyagenity/store/base_store.py
def get(self, config: dict[str, Any], memory_id: str, **kwargs) -> MemorySearchResult | None:
    """Synchronous wrapper for `aget` that runs the async implementation."""
    return run_coroutine(self.aget(config, memory_id, **kwargs))
get_all
get_all(config, limit=100, **kwargs)

Synchronous wrapper for aget_all that runs the async implementation.

Source code in pyagenity/store/base_store.py
def get_all(
    self,
    config: dict[str, Any],
    limit: int = 100,
    **kwargs,
) -> list[MemorySearchResult]:
    """Synchronous wrapper for `aget_all` that runs the async implementation."""
    return run_coroutine(self.aget_all(config, limit=limit, **kwargs))
release
release()

Clean up any resources used by the store (override in subclasses if needed).

Source code in pyagenity/store/base_store.py
def release(self) -> None:
    """Clean up any resources used by the store (override in subclasses if needed)."""
    return run_coroutine(self.arelease())
search
search(config, query, memory_type=None, category=None, limit=10, score_threshold=None, filters=None, retrieval_strategy=RetrievalStrategy.SIMILARITY, distance_metric=DistanceMetric.COSINE, max_tokens=4000, **kwargs)

Synchronous wrapper for asearch that runs the async implementation.

Source code in pyagenity/store/base_store.py
def search(
    self,
    config: dict[str, Any],
    query: str,
    memory_type: MemoryType | None = None,
    category: str | None = None,
    limit: int = 10,
    score_threshold: float | None = None,
    filters: dict[str, Any] | None = None,
    retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
    distance_metric: DistanceMetric = DistanceMetric.COSINE,
    max_tokens: int = 4000,
    **kwargs,
) -> list[MemorySearchResult]:
    """Synchronous wrapper for `asearch` that runs the async implementation."""
    return run_coroutine(
        self.asearch(
            config,
            query,
            memory_type=memory_type,
            category=category,
            limit=limit,
            score_threshold=score_threshold,
            filters=filters,
            retrieval_strategy=retrieval_strategy,
            distance_metric=distance_metric,
            max_tokens=max_tokens,
            **kwargs,
        )
    )
setup
setup()

Synchronous setup method for the store.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/store/base_store.py
def setup(self) -> Any:
    """
    Synchronous setup method for the store.

    Returns:
        Any: Implementation-defined setup result.
    """
    return run_coroutine(self.asetup())
store
store(config, content, memory_type=MemoryType.EPISODIC, category='general', metadata=None, **kwargs)

Synchronous wrapper for astore that runs the async implementation.

Source code in pyagenity/store/base_store.py
def store(
    self,
    config: dict[str, Any],
    content: str | Message,
    memory_type: MemoryType = MemoryType.EPISODIC,
    category: str = "general",
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> str:
    """Synchronous wrapper for `astore` that runs the async implementation."""
    return run_coroutine(
        self.astore(
            config,
            content,
            memory_type=memory_type,
            category=category,
            metadata=metadata,
            **kwargs,
        )
    )
update
update(config, memory_id, content, metadata=None, **kwargs)

Synchronous wrapper for aupdate that runs the async implementation.

Source code in pyagenity/store/base_store.py
def update(
    self,
    config: dict[str, Any],
    memory_id: str,
    content: str | Message,
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> Any:
    """Synchronous wrapper for `aupdate` that runs the async implementation."""
    return run_coroutine(self.aupdate(config, memory_id, content, metadata=metadata, **kwargs))
DistanceMetric

Bases: Enum

Supported distance metrics for vector similarity.

Attributes:

Name Type Description
COSINE
DOT_PRODUCT
EUCLIDEAN
MANHATTAN
Source code in pyagenity/store/store_schema.py
class DistanceMetric(Enum):
    """Supported distance metrics for vector similarity."""

    COSINE = "cosine"
    EUCLIDEAN = "euclidean"
    DOT_PRODUCT = "dot_product"
    MANHATTAN = "manhattan"
Attributes
COSINE class-attribute instance-attribute
COSINE = 'cosine'
DOT_PRODUCT class-attribute instance-attribute
DOT_PRODUCT = 'dot_product'
EUCLIDEAN class-attribute instance-attribute
EUCLIDEAN = 'euclidean'
MANHATTAN class-attribute instance-attribute
MANHATTAN = 'manhattan'
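To clarify what each DistanceMetric member denotes, here are the reference formulas over plain float lists. Vector backends compute these natively; this sketch is only to pin down the definitions.

```python
# Reference formulas for the four DistanceMetric values.
import math


def cosine(a: list[float], b: list[float]) -> float:
    # Cosine similarity: dot product over the product of norms.
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


def dot_product(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def euclidean(a: list[float], b: list[float]) -> float:
    # Straight-line (L2) distance.
    return math.dist(a, b)


def manhattan(a: list[float], b: list[float]) -> float:
    # Sum of absolute coordinate differences (L1).
    return sum(abs(x - y) for x, y in zip(a, b))


a, b = [1.0, 0.0], [0.0, 1.0]  # orthogonal unit vectors
print(cosine(a, b), euclidean(a, b), manhattan(a, b))
```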
MemoryRecord

Bases: BaseModel

Comprehensive memory record for storage (Pydantic model).

Methods:

Name Description
from_message
validate_vector

Attributes:

Name Type Description
category str
content str
id str
memory_type MemoryType
metadata dict[str, Any]
thread_id str | None
timestamp datetime | None
user_id str | None
vector list[float] | None
Source code in pyagenity/store/store_schema.py
class MemoryRecord(BaseModel):
    """Comprehensive memory record for storage (Pydantic model)."""

    id: str = Field(default_factory=lambda: str(uuid4()))
    content: str
    user_id: str | None = None
    thread_id: str | None = None
    memory_type: MemoryType = Field(default=MemoryType.EPISODIC)
    metadata: dict[str, Any] = Field(default_factory=dict)
    category: str = Field(default="general")
    vector: list[float] | None = None
    timestamp: datetime | None = Field(default_factory=datetime.now)

    @field_validator("vector")
    @classmethod
    def validate_vector(cls, v):
        if v is not None and (
            not isinstance(v, list) or any(not isinstance(x, (int | float)) for x in v)
        ):
            raise ValueError("vector must be list[float] or None")
        return v

    @classmethod
    def from_message(
        cls,
        message: Message,
        user_id: str | None = None,
        thread_id: str | None = None,
        vector: list[float] | None = None,
        additional_metadata: dict[str, Any] | None = None,
    ) -> "MemoryRecord":
        content = message.text()
        metadata = {
            "role": message.role,
            "message_id": str(message.message_id),
            "timestamp": message.timestamp.isoformat() if message.timestamp else None,
            "has_tool_calls": bool(message.tools_calls),
            "has_reasoning": bool(message.reasoning),
            "token_usage": message.usages.model_dump() if message.usages else None,
            **(additional_metadata or {}),
        }
        return cls(
            content=content,
            user_id=user_id,
            thread_id=thread_id,
            memory_type=MemoryType.EPISODIC,
            metadata=metadata,
            vector=vector,
        )
Attributes
category class-attribute instance-attribute
category = Field(default='general')
content instance-attribute
content
id class-attribute instance-attribute
id = Field(default_factory=lambda: str(uuid4()))
memory_type class-attribute instance-attribute
memory_type = Field(default=EPISODIC)
metadata class-attribute instance-attribute
metadata = Field(default_factory=dict)
thread_id class-attribute instance-attribute
thread_id = None
timestamp class-attribute instance-attribute
timestamp = Field(default_factory=now)
user_id class-attribute instance-attribute
user_id = None
vector class-attribute instance-attribute
vector = None
Functions
from_message classmethod
from_message(message, user_id=None, thread_id=None, vector=None, additional_metadata=None)
Source code in pyagenity/store/store_schema.py
@classmethod
def from_message(
    cls,
    message: Message,
    user_id: str | None = None,
    thread_id: str | None = None,
    vector: list[float] | None = None,
    additional_metadata: dict[str, Any] | None = None,
) -> "MemoryRecord":
    content = message.text()
    metadata = {
        "role": message.role,
        "message_id": str(message.message_id),
        "timestamp": message.timestamp.isoformat() if message.timestamp else None,
        "has_tool_calls": bool(message.tools_calls),
        "has_reasoning": bool(message.reasoning),
        "token_usage": message.usages.model_dump() if message.usages else None,
        **(additional_metadata or {}),
    }
    return cls(
        content=content,
        user_id=user_id,
        thread_id=thread_id,
        memory_type=MemoryType.EPISODIC,
        metadata=metadata,
        vector=vector,
    )
validate_vector classmethod
validate_vector(v)
Source code in pyagenity/store/store_schema.py
@field_validator("vector")
@classmethod
def validate_vector(cls, v):
    if v is not None and (
        not isinstance(v, list) or any(not isinstance(x, (int | float)) for x in v)
    ):
        raise ValueError("vector must be list[float] or None")
    return v
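The validator's check can be exercised standalone. This reproduces its logic without pydantic: a vector must be None or a list whose elements are all numbers; anything else raises ValueError.

```python
# Standalone reproduction of the validate_vector check.
def validate_vector(v):
    if v is not None and (
        not isinstance(v, list) or any(not isinstance(x, (int, float)) for x in v)
    ):
        raise ValueError("vector must be list[float] or None")
    return v


validate_vector(None)            # accepted: vector is optional
validate_vector([0.1, 0.2, 3])   # accepted: ints pass alongside floats
try:
    validate_vector([0.1, "x"])  # rejected: non-numeric element
except ValueError as e:
    print(e)
```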
MemorySearchResult

Bases: BaseModel

Result from a memory search operation (Pydantic model).

Methods:

Name Description
validate_vector

Attributes:

Name Type Description
content str
id str
memory_type MemoryType
metadata dict[str, Any]
score float
thread_id str | None
timestamp datetime | None
user_id str | None
vector list[float] | None
Source code in pyagenity/store/store_schema.py
class MemorySearchResult(BaseModel):
    """Result from a memory search operation (Pydantic model)."""

    id: str = Field(default_factory=lambda: str(uuid4()))
    content: str = Field(default="", description="Primary textual content of the memory")
    score: float = Field(default=0.0, ge=0.0, description="Similarity / relevance score")
    memory_type: MemoryType = Field(default=MemoryType.EPISODIC)
    metadata: dict[str, Any] = Field(default_factory=dict)
    vector: list[float] | None = Field(default=None)
    user_id: str | None = None
    thread_id: str | None = None
    timestamp: datetime | None = Field(default_factory=datetime.now)

    @field_validator("vector")
    @classmethod
    def validate_vector(cls, v):
        if v is not None and (
            not isinstance(v, list) or any(not isinstance(x, (int | float)) for x in v)
        ):
            raise ValueError("vector must be list[float] or None")
        return v
Attributes
content class-attribute instance-attribute
content = Field(default='', description='Primary textual content of the memory')
id class-attribute instance-attribute
id = Field(default_factory=lambda: str(uuid4()))
memory_type class-attribute instance-attribute
memory_type = Field(default=EPISODIC)
metadata class-attribute instance-attribute
metadata = Field(default_factory=dict)
score class-attribute instance-attribute
score = Field(default=0.0, ge=0.0, description='Similarity / relevance score')
thread_id class-attribute instance-attribute
thread_id = None
timestamp class-attribute instance-attribute
timestamp = Field(default_factory=now)
user_id class-attribute instance-attribute
user_id = None
vector class-attribute instance-attribute
vector = Field(default=None)
Functions
validate_vector classmethod
validate_vector(v)
Source code in pyagenity/store/store_schema.py
@field_validator("vector")
@classmethod
def validate_vector(cls, v):
    if v is not None and (
        not isinstance(v, list) or any(not isinstance(x, (int | float)) for x in v)
    ):
        raise ValueError("vector must be list[float] or None")
    return v
MemoryType

Bases: Enum

Types of memories that can be stored.

Attributes:

Name Type Description
CUSTOM
DECLARATIVE
ENTITY
EPISODIC
PROCEDURAL
RELATIONSHIP
SEMANTIC
Source code in pyagenity/store/store_schema.py
class MemoryType(Enum):
    """Types of memories that can be stored."""

    EPISODIC = "episodic"  # Conversation memories
    SEMANTIC = "semantic"  # Facts and knowledge
    PROCEDURAL = "procedural"  # How-to knowledge
    ENTITY = "entity"  # Entity-based memories
    RELATIONSHIP = "relationship"  # Entity relationships
    CUSTOM = "custom"  # Custom memory types
    DECLARATIVE = "declarative"  # Explicit facts and events
Attributes
CUSTOM class-attribute instance-attribute
CUSTOM = 'custom'
DECLARATIVE class-attribute instance-attribute
DECLARATIVE = 'declarative'
ENTITY class-attribute instance-attribute
ENTITY = 'entity'
EPISODIC class-attribute instance-attribute
EPISODIC = 'episodic'
PROCEDURAL class-attribute instance-attribute
PROCEDURAL = 'procedural'
RELATIONSHIP class-attribute instance-attribute
RELATIONSHIP = 'relationship'
SEMANTIC class-attribute instance-attribute
SEMANTIC = 'semantic'
OpenAIEmbedding

Bases: BaseEmbedding

Methods:

Name Description
__init__
aembed
aembed_batch
embed

Synchronous wrapper for aembed that runs the async implementation.

embed_batch

Synchronous wrapper for aembed_batch that runs the async implementation.

Attributes:

Name Type Description
api_key
client
dimension int
model
Source code in pyagenity/store/embedding/openai_embedding.py
class OpenAIEmbedding(BaseEmbedding):
    def __init__(
        self,
        model: str = "text-embedding-3-small",
        OPENAI_API_KEY: str | None = None,
    ) -> None:
        if not HAS_OPENAI:
            raise ImportError(
                "The 'openai' package is required for OpenAIEmbedding. "
                "Please install it via 'pip install openai'."
            )
        self.model = model
        if OPENAI_API_KEY:
            self.api_key = OPENAI_API_KEY
        elif "OPENAI_API_KEY" in os.environ:
            self.api_key = os.environ["OPENAI_API_KEY"]
        else:
            raise ValueError(
                "OpenAI API key must be provided via parameter or OPENAI_API_KEY env var"
            )

        self.client = AsyncOpenAI(
            api_key=self.api_key,
        )

    async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
        try:
            response = await self.client.embeddings.create(
                input=texts,
                model=self.model,
            )
            return [data.embedding for data in response.data]
        except OpenAIError as e:
            raise RuntimeError(f"OpenAI API error: {e}") from e

    async def aembed(self, text: str) -> list[float]:
        try:
            response = await self.client.embeddings.create(
                input=text,
                model=self.model,
            )
            return response.data[0].embedding if response.data else []
        except OpenAIError as e:
            raise RuntimeError(f"OpenAI API error: {e}") from e

    @property
    def dimension(self) -> int:
        model_dimensions = {
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 3072,
            "text-embedding-3-xl": 1536,
            "text-embedding-4-base": 8192,
            "text-embedding-4-large": 8192,
        }
        if self.model in model_dimensions:
            return model_dimensions[self.model]
        raise ValueError(f"Unknown model '{self.model}'. Cannot determine dimension.")
Attributes
api_key instance-attribute
api_key = OPENAI_API_KEY
client instance-attribute
client = AsyncOpenAI(api_key=api_key)
dimension property
dimension
model instance-attribute
model = model
Functions
__init__
__init__(model='text-embedding-3-small', OPENAI_API_KEY=None)
Source code in pyagenity/store/embedding/openai_embedding.py
def __init__(
    self,
    model: str = "text-embedding-3-small",
    OPENAI_API_KEY: str | None = None,
) -> None:
    if not HAS_OPENAI:
        raise ImportError(
            "The 'openai' package is required for OpenAIEmbedding. "
            "Please install it via 'pip install openai'."
        )
    self.model = model
    if OPENAI_API_KEY:
        self.api_key = OPENAI_API_KEY
    elif "OPENAI_API_KEY" in os.environ:
        self.api_key = os.environ["OPENAI_API_KEY"]
    else:
        raise ValueError(
            "OpenAI API key must be provided via parameter or OPENAI_API_KEY env var"
        )

    self.client = AsyncOpenAI(
        api_key=self.api_key,
    )
aembed async
aembed(text)
Source code in pyagenity/store/embedding/openai_embedding.py
async def aembed(self, text: str) -> list[float]:
    try:
        response = await self.client.embeddings.create(
            input=text,
            model=self.model,
        )
        return response.data[0].embedding if response.data else []
    except OpenAIError as e:
        raise RuntimeError(f"OpenAI API error: {e}") from e
aembed_batch async
aembed_batch(texts)
Source code in pyagenity/store/embedding/openai_embedding.py
async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
    try:
        response = await self.client.embeddings.create(
            input=texts,
            model=self.model,
        )
        return [data.embedding for data in response.data]
    except OpenAIError as e:
        raise RuntimeError(f"OpenAI API error: {e}") from e
embed
embed(text)

Synchronous wrapper for aembed that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed(self, text: str) -> list[float]:
    """Synchronous wrapper for `aembed` that runs the async implementation."""
    return run_coroutine(self.aembed(text))
embed_batch
embed_batch(texts)

Synchronous wrapper for aembed_batch that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed_batch(self, texts: list[str]) -> list[list[float]]:
    """Synchronous wrapper for `aembed_batch` that runs the async implementation."""
    return run_coroutine(self.aembed_batch(texts))
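The sync wrappers throughout this module delegate to a run_coroutine helper. The exact pyagenity implementation may differ; the sketch below shows the usual pattern such a helper follows: drive the coroutine to completion when no event loop is running, and refuse (or hand off to a worker) when called from inside one.

```python
# Hedged sketch of a run_coroutine-style helper; the real pyagenity
# implementation may differ in how it handles an already-running loop.
import asyncio
from collections.abc import Coroutine
from typing import Any, TypeVar

T = TypeVar("T")


def run_coroutine(coro: Coroutine[Any, Any, T]) -> T:
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running: safe to drive the coroutine ourselves.
        return asyncio.run(coro)
    # Inside a running loop, asyncio.run would raise; a real helper must
    # schedule on the existing loop (e.g. via a worker thread) instead.
    raise RuntimeError("run_coroutine called from a running event loop")


async def aembed(text: str) -> list[float]:
    return [float(len(text))]  # toy async implementation


def embed(text: str) -> list[float]:
    """Synchronous wrapper, mirroring the BaseEmbedding.embed pattern."""
    return run_coroutine(aembed(text))


print(embed("hello"))  # → [5.0]
```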

Functions

Modules

base_store

Simplified Async-First Base Store for PyAgenity Framework

This module provides a clean, modern interface for memory stores with:

- Async-first design for better performance
- Core CRUD operations only
- Message-specific convenience methods
- Extensible for different backends (vector stores, managed services, etc.)

Classes:

Name Description
BaseStore

Simplified async-first base class for memory stores in PyAgenity.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
BaseStore

Bases: ABC

Simplified async-first base class for memory stores in PyAgenity.

This class provides a clean interface that supports:

- Vector stores (Qdrant, Pinecone, Chroma, etc.)
- Managed memory services (mem0, Zep, etc.)
- Graph databases (Neo4j, etc.)

Key Design Principles:

- Async-first for better performance
- Core CRUD operations only
- User and agent-centric operations
- Extensible filtering and metadata

Methods:

Name Description
adelete

Delete a memory by ID.

aforget_memory

Delete all memories for a user or agent.

aget

Get a specific memory by ID.

aget_all

Get all memories for the given config.

arelease

Clean up any resources used by the store (override in subclasses if needed).

asearch

Search memories by content similarity.

asetup

Asynchronous setup method for the store.

astore

Add a new memory.

aupdate

Update an existing memory.

delete

Synchronous wrapper for adelete that runs the async implementation.

forget_memory

Delete all memories for a user or agent.

get

Synchronous wrapper for aget that runs the async implementation.

get_all

Synchronous wrapper for aget_all that runs the async implementation.

release

Clean up any resources used by the store (override in subclasses if needed).

search

Synchronous wrapper for asearch that runs the async implementation.

setup

Synchronous setup method for the store.

store

Synchronous wrapper for astore that runs the async implementation.

update

Synchronous wrapper for aupdate that runs the async implementation.

Source code in pyagenity/store/base_store.py
class BaseStore(ABC):
    """
    Simplified async-first base class for memory stores in PyAgenity.

    This class provides a clean interface that supports:
    - Vector stores (Qdrant, Pinecone, Chroma, etc.)
    - Managed memory services (mem0, Zep, etc.)
    - Graph databases (Neo4j, etc.)

    Key Design Principles:
    - Async-first for better performance
    - Core CRUD operations only
    - User and agent-centric operations
    - Extensible filtering and metadata
    """

    def setup(self) -> Any:
        """
        Synchronous setup method for the store.

        Returns:
            Any: Implementation-defined setup result.
        """
        return run_coroutine(self.asetup())

    async def asetup(self) -> Any:
        """
        Asynchronous setup method for the store.

        Returns:
            Any: Implementation-defined setup result.
        """
        raise NotImplementedError

    # --- Core Memory Operations ---

    @abstractmethod
    async def astore(
        self,
        config: dict[str, Any],
        content: str | Message,
        memory_type: MemoryType = MemoryType.EPISODIC,
        category: str = "general",
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> str:
        """
        Add a new memory.

        Args:
            config: Store/runtime configuration (e.g. user and thread identifiers)
            content: The memory content
            memory_type: Type of memory (episodic, semantic, etc.)
            category: Memory category for organization
            metadata: Additional metadata
            **kwargs: Store-specific parameters

        Returns:
            Memory ID
        """
        raise NotImplementedError

    # --- Sync wrappers ---
    def store(
        self,
        config: dict[str, Any],
        content: str | Message,
        memory_type: MemoryType = MemoryType.EPISODIC,
        category: str = "general",
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> str:
        """Synchronous wrapper for `astore` that runs the async implementation."""
        return run_coroutine(
            self.astore(
                config,
                content,
                memory_type=memory_type,
                category=category,
                metadata=metadata,
                **kwargs,
            )
        )

    @abstractmethod
    async def asearch(
        self,
        config: dict[str, Any],
        query: str,
        memory_type: MemoryType | None = None,
        category: str | None = None,
        limit: int = 10,
        score_threshold: float | None = None,
        filters: dict[str, Any] | None = None,
        retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
        distance_metric: DistanceMetric = DistanceMetric.COSINE,
        max_tokens: int = 4000,
        **kwargs,
    ) -> list[MemorySearchResult]:
        """
        Search memories by content similarity.

        Args:
            config: Store/runtime configuration (e.g. user and thread identifiers)
            query: Search query
            memory_type: Filter by memory type
            category: Filter by category
            limit: Maximum results
            score_threshold: Minimum similarity score
            filters: Additional filters
            **kwargs: Store-specific parameters

        Returns:
            List of matching memories
        """
        raise NotImplementedError

    def search(
        self,
        config: dict[str, Any],
        query: str,
        memory_type: MemoryType | None = None,
        category: str | None = None,
        limit: int = 10,
        score_threshold: float | None = None,
        filters: dict[str, Any] | None = None,
        retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
        distance_metric: DistanceMetric = DistanceMetric.COSINE,
        max_tokens: int = 4000,
        **kwargs,
    ) -> list[MemorySearchResult]:
        """Synchronous wrapper for `asearch` that runs the async implementation."""
        return run_coroutine(
            self.asearch(
                config,
                query,
                memory_type=memory_type,
                category=category,
                limit=limit,
                score_threshold=score_threshold,
                filters=filters,
                retrieval_strategy=retrieval_strategy,
                distance_metric=distance_metric,
                max_tokens=max_tokens,
                **kwargs,
            )
        )

    @abstractmethod
    async def aget(
        self,
        config: dict[str, Any],
        memory_id: str,
        **kwargs,
    ) -> MemorySearchResult | None:
        """Get a specific memory by ID."""
        raise NotImplementedError

    @abstractmethod
    async def aget_all(
        self,
        config: dict[str, Any],
        limit: int = 100,
        **kwargs,
    ) -> list[MemorySearchResult]:
        """Get a specific memory by user_id."""
        raise NotImplementedError

    def get(self, config: dict[str, Any], memory_id: str, **kwargs) -> MemorySearchResult | None:
        """Synchronous wrapper for `aget` that runs the async implementation."""
        return run_coroutine(self.aget(config, memory_id, **kwargs))

    def get_all(
        self,
        config: dict[str, Any],
        limit: int = 100,
        **kwargs,
    ) -> list[MemorySearchResult]:
        """Synchronous wrapper for `aget` that runs the async implementation."""
        return run_coroutine(self.aget_all(config, limit=limit, **kwargs))

    @abstractmethod
    async def aupdate(
        self,
        config: dict[str, Any],
        memory_id: str,
        content: str | Message,
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> Any:
        """
        Update an existing memory.

        Args:
            memory_id: ID of memory to update
            content: New content
            metadata: New/additional metadata (optional)
            **kwargs: Store-specific parameters
        """
        raise NotImplementedError

    def update(
        self,
        config: dict[str, Any],
        memory_id: str,
        content: str | Message,
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> Any:
        """Synchronous wrapper for `aupdate` that runs the async implementation."""
        return run_coroutine(self.aupdate(config, memory_id, content, metadata=metadata, **kwargs))

    @abstractmethod
    async def adelete(
        self,
        config: dict[str, Any],
        memory_id: str,
        **kwargs,
    ) -> Any:
        """Delete a memory by ID."""
        raise NotImplementedError

    def delete(self, config: dict[str, Any], memory_id: str, **kwargs) -> None:
        """Synchronous wrapper for `adelete` that runs the async implementation."""
        return run_coroutine(self.adelete(config, memory_id, **kwargs))

    @abstractmethod
    async def aforget_memory(
        self,
        config: dict[str, Any],
        **kwargs,
    ) -> Any:
        """Delete a memory by for a user or agent."""
        raise NotImplementedError

    def forget_memory(
        self,
        config: dict[str, Any],
        **kwargs,
    ) -> Any:
        """Delete a memory by for a user or agent."""
        return run_coroutine(self.aforget_memory(config, **kwargs))

    # --- Cleanup and Resource Management ---

    async def arelease(self) -> None:
        """Clean up any resources used by the store (override in subclasses if needed)."""
        raise NotImplementedError

    def release(self) -> None:
        """Clean up any resources used by the store (override in subclasses if needed)."""
        return run_coroutine(self.arelease())
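Each synchronous method above delegates to its async counterpart through a `run_coroutine` helper. A minimal sketch of that pairing, where `run_coroutine` is a simplified stand-in for the framework's helper (it assumes no event loop is already running in the calling thread):

```python
import asyncio


def run_coroutine(coro):
    """Simplified stand-in for PyAgenity's run_coroutine helper.

    Assumes no event loop is already running in the calling thread.
    """
    return asyncio.run(coro)


class EchoStore:
    """Toy store illustrating the sync/async wrapper pairing above."""

    async def aget(self, config: dict, memory_id: str) -> dict:
        # Async implementation standing in for a real store lookup.
        return {"id": memory_id, "user": config.get("user_id")}

    def get(self, config: dict, memory_id: str) -> dict:
        # Synchronous wrapper mirroring BaseStore's get/aget pattern.
        return run_coroutine(self.aget(config, memory_id))


result = EchoStore().get({"user_id": "u1"}, "mem-42")
```

Subclasses implement only the async `a*` methods; the base class provides all the sync wrappers.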
Functions
adelete abstractmethod async
adelete(config, memory_id, **kwargs)

Delete a memory by ID.

Source code in pyagenity/store/base_store.py
@abstractmethod
async def adelete(
    self,
    config: dict[str, Any],
    memory_id: str,
    **kwargs,
) -> Any:
    """Delete a memory by ID."""
    raise NotImplementedError
aforget_memory abstractmethod async
aforget_memory(config, **kwargs)

Delete all memories for a user or agent.

Source code in pyagenity/store/base_store.py
@abstractmethod
async def aforget_memory(
    self,
    config: dict[str, Any],
    **kwargs,
) -> Any:
    """Delete a memory by for a user or agent."""
    raise NotImplementedError
aget abstractmethod async
aget(config, memory_id, **kwargs)

Get a specific memory by ID.

Source code in pyagenity/store/base_store.py
@abstractmethod
async def aget(
    self,
    config: dict[str, Any],
    memory_id: str,
    **kwargs,
) -> MemorySearchResult | None:
    """Get a specific memory by ID."""
    raise NotImplementedError
aget_all abstractmethod async
aget_all(config, limit=100, **kwargs)

Get all memories for the given config (user/agent scope).

Source code in pyagenity/store/base_store.py
@abstractmethod
async def aget_all(
    self,
    config: dict[str, Any],
    limit: int = 100,
    **kwargs,
) -> list[MemorySearchResult]:
    """Get a specific memory by user_id."""
    raise NotImplementedError
arelease async
arelease()

Clean up any resources used by the store (override in subclasses if needed).

Source code in pyagenity/store/base_store.py
async def arelease(self) -> None:
    """Clean up any resources used by the store (override in subclasses if needed)."""
    raise NotImplementedError
asearch abstractmethod async
asearch(config, query, memory_type=None, category=None, limit=10, score_threshold=None, filters=None, retrieval_strategy=RetrievalStrategy.SIMILARITY, distance_metric=DistanceMetric.COSINE, max_tokens=4000, **kwargs)

Search memories by content similarity.

Parameters:

Name Type Description Default
query str

Search query

required
config dict[str, Any]

Scoping configuration (user_id, thread_id, app_id)

required
memory_type MemoryType | None

Filter by memory type

None
category str | None

Filter by category

None
limit int

Maximum results

10
score_threshold float | None

Minimum similarity score

None
filters dict[str, Any] | None

Additional filters

None
**kwargs

Store-specific parameters

{}

Returns:

Type Description
list[MemorySearchResult]

List of matching memories

Source code in pyagenity/store/base_store.py
@abstractmethod
async def asearch(
    self,
    config: dict[str, Any],
    query: str,
    memory_type: MemoryType | None = None,
    category: str | None = None,
    limit: int = 10,
    score_threshold: float | None = None,
    filters: dict[str, Any] | None = None,
    retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
    distance_metric: DistanceMetric = DistanceMetric.COSINE,
    max_tokens: int = 4000,
    **kwargs,
) -> list[MemorySearchResult]:
    """
    Search memories by content similarity.

    Args:
        config: Scoping configuration (e.g. user_id, thread_id, app_id)
        query: Search query
        memory_type: Filter by memory type
        category: Filter by category
        limit: Maximum results
        score_threshold: Minimum similarity score
        filters: Additional filters
        **kwargs: Store-specific parameters

    Returns:
        List of matching memories
    """
    raise NotImplementedError
asetup async
asetup()

Asynchronous setup method for the store.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/store/base_store.py
async def asetup(self) -> Any:
    """
    Asynchronous setup method for the store.

    Returns:
        Any: Implementation-defined setup result.
    """
    raise NotImplementedError
astore abstractmethod async
astore(config, content, memory_type=MemoryType.EPISODIC, category='general', metadata=None, **kwargs)

Add a new memory.

Parameters:

Name Type Description Default
content str | Message

The memory content

required
config dict[str, Any]

Scoping configuration (user_id, thread_id, app_id)

required
memory_type MemoryType

Type of memory (episodic, semantic, etc.)

EPISODIC
category str

Memory category for organization

'general'
metadata dict[str, Any] | None

Additional metadata

None
**kwargs

Store-specific parameters

{}

Returns:

Type Description
str

Memory ID

Source code in pyagenity/store/base_store.py
@abstractmethod
async def astore(
    self,
    config: dict[str, Any],
    content: str | Message,
    memory_type: MemoryType = MemoryType.EPISODIC,
    category: str = "general",
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> str:
    """
    Add a new memory.

    Args:
        config: Scoping configuration (e.g. user_id, thread_id, app_id)
        content: The memory content
        memory_type: Type of memory (episodic, semantic, etc.)
        category: Memory category for organization
        metadata: Additional metadata
        **kwargs: Store-specific parameters

    Returns:
        Memory ID
    """
    raise NotImplementedError
aupdate abstractmethod async
aupdate(config, memory_id, content, metadata=None, **kwargs)

Update an existing memory.

Parameters:

Name Type Description Default
memory_id str

ID of memory to update

required
content str | Message

New content

required
metadata dict[str, Any] | None

New/additional metadata (optional)

None
**kwargs

Store-specific parameters

{}
Source code in pyagenity/store/base_store.py
@abstractmethod
async def aupdate(
    self,
    config: dict[str, Any],
    memory_id: str,
    content: str | Message,
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> Any:
    """
    Update an existing memory.

    Args:
        memory_id: ID of memory to update
        content: New content
        metadata: New/additional metadata (optional)
        **kwargs: Store-specific parameters
    """
    raise NotImplementedError
delete
delete(config, memory_id, **kwargs)

Synchronous wrapper for adelete that runs the async implementation.

Source code in pyagenity/store/base_store.py
def delete(self, config: dict[str, Any], memory_id: str, **kwargs) -> None:
    """Synchronous wrapper for `adelete` that runs the async implementation."""
    return run_coroutine(self.adelete(config, memory_id, **kwargs))
forget_memory
forget_memory(config, **kwargs)

Delete all memories for a user or agent.

Source code in pyagenity/store/base_store.py
def forget_memory(
    self,
    config: dict[str, Any],
    **kwargs,
) -> Any:
    """Delete a memory by for a user or agent."""
    return run_coroutine(self.aforget_memory(config, **kwargs))
get
get(config, memory_id, **kwargs)

Synchronous wrapper for aget that runs the async implementation.

Source code in pyagenity/store/base_store.py
def get(self, config: dict[str, Any], memory_id: str, **kwargs) -> MemorySearchResult | None:
    """Synchronous wrapper for `aget` that runs the async implementation."""
    return run_coroutine(self.aget(config, memory_id, **kwargs))
get_all
get_all(config, limit=100, **kwargs)

Synchronous wrapper for aget_all that runs the async implementation.

Source code in pyagenity/store/base_store.py
def get_all(
    self,
    config: dict[str, Any],
    limit: int = 100,
    **kwargs,
) -> list[MemorySearchResult]:
    """Synchronous wrapper for `aget` that runs the async implementation."""
    return run_coroutine(self.aget_all(config, limit=limit, **kwargs))
release
release()

Clean up any resources used by the store (override in subclasses if needed).

Source code in pyagenity/store/base_store.py
def release(self) -> None:
    """Clean up any resources used by the store (override in subclasses if needed)."""
    return run_coroutine(self.arelease())
search
search(config, query, memory_type=None, category=None, limit=10, score_threshold=None, filters=None, retrieval_strategy=RetrievalStrategy.SIMILARITY, distance_metric=DistanceMetric.COSINE, max_tokens=4000, **kwargs)

Synchronous wrapper for asearch that runs the async implementation.

Source code in pyagenity/store/base_store.py
def search(
    self,
    config: dict[str, Any],
    query: str,
    memory_type: MemoryType | None = None,
    category: str | None = None,
    limit: int = 10,
    score_threshold: float | None = None,
    filters: dict[str, Any] | None = None,
    retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
    distance_metric: DistanceMetric = DistanceMetric.COSINE,
    max_tokens: int = 4000,
    **kwargs,
) -> list[MemorySearchResult]:
    """Synchronous wrapper for `asearch` that runs the async implementation."""
    return run_coroutine(
        self.asearch(
            config,
            query,
            memory_type=memory_type,
            category=category,
            limit=limit,
            score_threshold=score_threshold,
            filters=filters,
            retrieval_strategy=retrieval_strategy,
            distance_metric=distance_metric,
            max_tokens=max_tokens,
            **kwargs,
        )
    )
setup
setup()

Synchronous setup method for the store.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/store/base_store.py
def setup(self) -> Any:
    """
    Synchronous setup method for the store.

    Returns:
        Any: Implementation-defined setup result.
    """
    return run_coroutine(self.asetup())
store
store(config, content, memory_type=MemoryType.EPISODIC, category='general', metadata=None, **kwargs)

Synchronous wrapper for astore that runs the async implementation.

Source code in pyagenity/store/base_store.py
def store(
    self,
    config: dict[str, Any],
    content: str | Message,
    memory_type: MemoryType = MemoryType.EPISODIC,
    category: str = "general",
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> str:
    """Synchronous wrapper for `astore` that runs the async implementation."""
    return run_coroutine(
        self.astore(
            config,
            content,
            memory_type=memory_type,
            category=category,
            metadata=metadata,
            **kwargs,
        )
    )
update
update(config, memory_id, content, metadata=None, **kwargs)

Synchronous wrapper for aupdate that runs the async implementation.

Source code in pyagenity/store/base_store.py
def update(
    self,
    config: dict[str, Any],
    memory_id: str,
    content: str | Message,
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> Any:
    """Synchronous wrapper for `aupdate` that runs the async implementation."""
    return run_coroutine(self.aupdate(config, memory_id, content, metadata=metadata, **kwargs))
Functions
embedding

Modules:

Name Description
base_embedding
openai_embedding

Classes:

Name Description
BaseEmbedding
OpenAIEmbedding
Attributes
__all__ module-attribute
__all__ = ['BaseEmbedding', 'OpenAIEmbedding']
Classes
BaseEmbedding

Bases: ABC

Methods:

Name Description
aembed

Generate embedding for a single text.

aembed_batch

Generate embeddings for a list of texts.

embed

Synchronous wrapper for aembed that runs the async implementation.

embed_batch

Synchronous wrapper for aembed_batch that runs the async implementation.

Attributes:

Name Type Description
dimension int

Embedding vector dimension produced by this embedder.

Source code in pyagenity/store/embedding/base_embedding.py
class BaseEmbedding(ABC):
    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Synchronous wrapper for `aembed_batch` that runs the async implementation."""
        return run_coroutine(self.aembed_batch(texts))

    @abstractmethod
    async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for a list of texts."""
        # pragma: no cover

    def embed(self, text: str) -> list[float]:
        """Synchronous wrapper for `aembed` that runs the async implementation."""
        return run_coroutine(self.aembed(text))

    @abstractmethod
    async def aembed(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        raise NotImplementedError

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Synchronous wrapper for that runs the async implementation."""
        raise NotImplementedError
Attributes
dimension abstractmethod property
dimension

Embedding vector dimension produced by this embedder.

Functions
aembed abstractmethod async
aembed(text)

Generate embedding for a single text.

Source code in pyagenity/store/embedding/base_embedding.py
@abstractmethod
async def aembed(self, text: str) -> list[float]:
    """Generate embedding for a single text."""
    raise NotImplementedError
aembed_batch abstractmethod async
aembed_batch(texts)

Generate embeddings for a list of texts.

Source code in pyagenity/store/embedding/base_embedding.py
@abstractmethod
async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
    """Generate embeddings for a list of texts."""
embed
embed(text)

Synchronous wrapper for aembed that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed(self, text: str) -> list[float]:
    """Synchronous wrapper for `aembed` that runs the async implementation."""
    return run_coroutine(self.aembed(text))
embed_batch
embed_batch(texts)

Synchronous wrapper for aembed_batch that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed_batch(self, texts: list[str]) -> list[list[float]]:
    """Synchronous wrapper for `aembed_batch` that runs the async implementation."""
    return run_coroutine(self.aembed_batch(texts))
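A concrete subclass only needs to provide `aembed`, `aembed_batch`, and `dimension`; the sync `embed`/`embed_batch` wrappers come from the base class. A minimal toy implementation (a deterministic hash-based embedder, purely illustrative and not part of PyAgenity):

```python
import asyncio
import hashlib
from abc import ABC, abstractmethod


class BaseEmbedding(ABC):
    """Simplified mirror of the BaseEmbedding interface shown above."""

    @abstractmethod
    async def aembed(self, text: str) -> list[float]: ...

    @abstractmethod
    async def aembed_batch(self, texts: list[str]) -> list[list[float]]: ...

    @property
    @abstractmethod
    def dimension(self) -> int: ...


class HashEmbedding(BaseEmbedding):
    """Toy deterministic embedder: hashes text into a fixed-size vector."""

    def __init__(self, dim: int = 8) -> None:
        self._dim = dim

    async def aembed(self, text: str) -> list[float]:
        digest = hashlib.sha256(text.encode()).digest()
        # Map digest bytes into [0, 1) floats, cycling through the digest.
        return [digest[i % len(digest)] / 256 for i in range(self._dim)]

    async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
        return [await self.aembed(t) for t in texts]

    @property
    def dimension(self) -> int:
        return self._dim


vec = asyncio.run(HashEmbedding(dim=4).aembed("hello"))
```

Real implementations (such as OpenAIEmbedding below) call out to an external embedding API instead of hashing.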
OpenAIEmbedding

Bases: BaseEmbedding

Methods:

Name Description
__init__
aembed
aembed_batch
embed

Synchronous wrapper for aembed that runs the async implementation.

embed_batch

Synchronous wrapper for aembed_batch that runs the async implementation.

Attributes:

Name Type Description
api_key
client
dimension int
model
Source code in pyagenity/store/embedding/openai_embedding.py
class OpenAIEmbedding(BaseEmbedding):
    def __init__(
        self,
        model: str = "text-embedding-3-small",
        OPENAI_API_KEY: str | None = None,
    ) -> None:
        if not HAS_OPENAI:
            raise ImportError(
                "The 'openai' package is required for OpenAIEmbedding. "
                "Please install it via 'pip install openai'."
            )
        self.model = model
        if OPENAI_API_KEY:
            self.api_key = OPENAI_API_KEY
        elif "OPENAI_API_KEY" in os.environ:
            self.api_key = os.environ["OPENAI_API_KEY"]
        else:
            raise ValueError(
                "OpenAI API key must be provided via parameter or OPENAI_API_KEY env var"
            )

        self.client = AsyncOpenAI(
            api_key=self.api_key,
        )

    async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
        try:
            response = await self.client.embeddings.create(
                input=texts,
                model=self.model,
            )
            return [data.embedding for data in response.data]
        except OpenAIError as e:
            raise RuntimeError(f"OpenAI API error: {e}") from e

    async def aembed(self, text: str) -> list[float]:
        try:
            response = await self.client.embeddings.create(
                input=text,
                model=self.model,
            )
            return response.data[0].embedding if response.data else []
        except OpenAIError as e:
            raise RuntimeError(f"OpenAI API error: {e}") from e

    @property
    def dimension(self) -> int:
        model_dimensions = {
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 1536,
            "text-embedding-3-xl": 1536,
            "text-embedding-4-base": 8192,
            "text-embedding-4-large": 8192,
        }
        if self.model in model_dimensions:
            return model_dimensions[self.model]
        raise ValueError(f"Unknown model '{self.model}'. Cannot determine dimension.")
Attributes
api_key instance-attribute
api_key = OPENAI_API_KEY
client instance-attribute
client = AsyncOpenAI(api_key=api_key)
dimension property
dimension
model instance-attribute
model = model
Functions
__init__
__init__(model='text-embedding-3-small', OPENAI_API_KEY=None)
Source code in pyagenity/store/embedding/openai_embedding.py
def __init__(
    self,
    model: str = "text-embedding-3-small",
    OPENAI_API_KEY: str | None = None,
) -> None:
    if not HAS_OPENAI:
        raise ImportError(
            "The 'openai' package is required for OpenAIEmbedding. "
            "Please install it via 'pip install openai'."
        )
    self.model = model
    if OPENAI_API_KEY:
        self.api_key = OPENAI_API_KEY
    elif "OPENAI_API_KEY" in os.environ:
        self.api_key = os.environ["OPENAI_API_KEY"]
    else:
        raise ValueError(
            "OpenAI API key must be provided via parameter or OPENAI_API_KEY env var"
        )

    self.client = AsyncOpenAI(
        api_key=self.api_key,
    )
aembed async
aembed(text)
Source code in pyagenity/store/embedding/openai_embedding.py
async def aembed(self, text: str) -> list[float]:
    try:
        response = await self.client.embeddings.create(
            input=text,
            model=self.model,
        )
        return response.data[0].embedding if response.data else []
    except OpenAIError as e:
        raise RuntimeError(f"OpenAI API error: {e}") from e
aembed_batch async
aembed_batch(texts)
Source code in pyagenity/store/embedding/openai_embedding.py
async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
    try:
        response = await self.client.embeddings.create(
            input=texts,
            model=self.model,
        )
        return [data.embedding for data in response.data]
    except OpenAIError as e:
        raise RuntimeError(f"OpenAI API error: {e}") from e
embed
embed(text)

Synchronous wrapper for aembed that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed(self, text: str) -> list[float]:
    """Synchronous wrapper for `aembed` that runs the async implementation."""
    return run_coroutine(self.aembed(text))
embed_batch
embed_batch(texts)

Synchronous wrapper for aembed_batch that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed_batch(self, texts: list[str]) -> list[list[float]]:
    """Synchronous wrapper for `aembed_batch` that runs the async implementation."""
    return run_coroutine(self.aembed_batch(texts))
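The `dimension` property resolves the vector size from a static model table and raises `ValueError` for unknown models. That lookup logic can be sketched standalone (abridged to a single model whose size is certain; `dimension_for` and `MODEL_DIMENSIONS` are illustrative names, not part of the library):

```python
# Abridged sketch of the dimension lookup used by the property above.
MODEL_DIMENSIONS = {
    "text-embedding-3-small": 1536,
}


def dimension_for(model: str) -> int:
    """Resolve the embedding dimension, raising for unknown models."""
    if model in MODEL_DIMENSIONS:
        return MODEL_DIMENSIONS[model]
    raise ValueError(f"Unknown model '{model}'. Cannot determine dimension.")


small = dimension_for("text-embedding-3-small")
try:
    dimension_for("not-a-model")
    unknown_raises = False
except ValueError:
    unknown_raises = True
```

Raising for unknown models is deliberate: a silently guessed dimension would corrupt a vector index sized for a different model.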
Modules
base_embedding

Classes:

Name Description
BaseEmbedding
Classes
BaseEmbedding

Bases: ABC

Methods:

Name Description
aembed

Generate embedding for a single text.

aembed_batch

Generate embeddings for a list of texts.

embed

Synchronous wrapper for aembed that runs the async implementation.

embed_batch

Synchronous wrapper for aembed_batch that runs the async implementation.

Attributes:

Name Type Description
dimension int

Embedding vector dimension produced by this embedder.

Source code in pyagenity/store/embedding/base_embedding.py
class BaseEmbedding(ABC):
    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Synchronous wrapper for `aembed_batch` that runs the async implementation."""
        return run_coroutine(self.aembed_batch(texts))

    @abstractmethod
    async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for a list of texts."""
        # pragma: no cover

    def embed(self, text: str) -> list[float]:
        """Synchronous wrapper for `aembed` that runs the async implementation."""
        return run_coroutine(self.aembed(text))

    @abstractmethod
    async def aembed(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        raise NotImplementedError

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Synchronous wrapper for that runs the async implementation."""
        raise NotImplementedError
Attributes
dimension abstractmethod property
dimension

Embedding vector dimension produced by this embedder.

Functions
aembed abstractmethod async
aembed(text)

Generate embedding for a single text.

Source code in pyagenity/store/embedding/base_embedding.py
@abstractmethod
async def aembed(self, text: str) -> list[float]:
    """Generate embedding for a single text."""
    raise NotImplementedError
aembed_batch abstractmethod async
aembed_batch(texts)

Generate embeddings for a list of texts.

Source code in pyagenity/store/embedding/base_embedding.py
@abstractmethod
async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
    """Generate embeddings for a list of texts."""
embed
embed(text)

Synchronous wrapper for aembed that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed(self, text: str) -> list[float]:
    """Synchronous wrapper for `aembed` that runs the async implementation."""
    return run_coroutine(self.aembed(text))
embed_batch
embed_batch(texts)

Synchronous wrapper for aembed_batch that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed_batch(self, texts: list[str]) -> list[list[float]]:
    """Synchronous wrapper for `aembed_batch` that runs the async implementation."""
    return run_coroutine(self.aembed_batch(texts))
Functions
openai_embedding

Classes:

Name Description
OpenAIEmbedding

Attributes:

Name Type Description
HAS_OPENAI
Attributes
HAS_OPENAI module-attribute
HAS_OPENAI = True
Classes
OpenAIEmbedding

Bases: BaseEmbedding

Methods:

Name Description
__init__
aembed
aembed_batch
embed

Synchronous wrapper for aembed that runs the async implementation.

embed_batch

Synchronous wrapper for aembed_batch that runs the async implementation.

Attributes:

Name Type Description
api_key
client
dimension int
model
Source code in pyagenity/store/embedding/openai_embedding.py
class OpenAIEmbedding(BaseEmbedding):
    def __init__(
        self,
        model: str = "text-embedding-3-small",
        OPENAI_API_KEY: str | None = None,
    ) -> None:
        if not HAS_OPENAI:
            raise ImportError(
                "The 'openai' package is required for OpenAIEmbedding. "
                "Please install it via 'pip install openai'."
            )
        self.model = model
        if OPENAI_API_KEY:
            self.api_key = OPENAI_API_KEY
        elif "OPENAI_API_KEY" in os.environ:
            self.api_key = os.environ["OPENAI_API_KEY"]
        else:
            raise ValueError(
                "OpenAI API key must be provided via parameter or OPENAI_API_KEY env var"
            )

        self.client = AsyncOpenAI(
            api_key=self.api_key,
        )

    async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
        try:
            response = await self.client.embeddings.create(
                input=texts,
                model=self.model,
            )
            return [data.embedding for data in response.data]
        except OpenAIError as e:
            raise RuntimeError(f"OpenAI API error: {e}") from e

    async def aembed(self, text: str) -> list[float]:
        try:
            response = await self.client.embeddings.create(
                input=text,
                model=self.model,
            )
            return response.data[0].embedding if response.data else []
        except OpenAIError as e:
            raise RuntimeError(f"OpenAI API error: {e}") from e

    @property
    def dimension(self) -> int:
        model_dimensions = {
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 1536,
            "text-embedding-3-xl": 1536,
            "text-embedding-4-base": 8192,
            "text-embedding-4-large": 8192,
        }
        if self.model in model_dimensions:
            return model_dimensions[self.model]
        raise ValueError(f"Unknown model '{self.model}'. Cannot determine dimension.")
Attributes
api_key instance-attribute
api_key = OPENAI_API_KEY
client instance-attribute
client = AsyncOpenAI(api_key=api_key)
dimension property
dimension
model instance-attribute
model = model
Functions
__init__
__init__(model='text-embedding-3-small', OPENAI_API_KEY=None)
Source code in pyagenity/store/embedding/openai_embedding.py
def __init__(
    self,
    model: str = "text-embedding-3-small",
    OPENAI_API_KEY: str | None = None,
) -> None:
    if not HAS_OPENAI:
        raise ImportError(
            "The 'openai' package is required for OpenAIEmbedding. "
            "Please install it via 'pip install openai'."
        )
    self.model = model
    if OPENAI_API_KEY:
        self.api_key = OPENAI_API_KEY
    elif "OPENAI_API_KEY" in os.environ:
        self.api_key = os.environ["OPENAI_API_KEY"]
    else:
        raise ValueError(
            "OpenAI API key must be provided via parameter or OPENAI_API_KEY env var"
        )

    self.client = AsyncOpenAI(
        api_key=self.api_key,
    )
aembed async
aembed(text)
Source code in pyagenity/store/embedding/openai_embedding.py
async def aembed(self, text: str) -> list[float]:
    try:
        response = await self.client.embeddings.create(
            input=text,
            model=self.model,
        )
        return response.data[0].embedding if response.data else []
    except OpenAIError as e:
        raise RuntimeError(f"OpenAI API error: {e}") from e
aembed_batch async
aembed_batch(texts)
Source code in pyagenity/store/embedding/openai_embedding.py
async def aembed_batch(self, texts: list[str]) -> list[list[float]]:
    try:
        response = await self.client.embeddings.create(
            input=texts,
            model=self.model,
        )
        return [data.embedding for data in response.data]
    except OpenAIError as e:
        raise RuntimeError(f"OpenAI API error: {e}") from e
embed
embed(text)

Synchronous wrapper for aembed that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed(self, text: str) -> list[float]:
    """Synchronous wrapper for `aembed` that runs the async implementation."""
    return run_coroutine(self.aembed(text))
embed_batch
embed_batch(texts)

Synchronous wrapper for aembed_batch that runs the async implementation.

Source code in pyagenity/store/embedding/base_embedding.py
def embed_batch(self, texts: list[str]) -> list[list[float]]:
    """Synchronous wrapper for `aembed_batch` that runs the async implementation."""
    return run_coroutine(self.aembed_batch(texts))
mem0_store

Mem0 Long-Term Memory Store

Async-first implementation of BaseStore that uses the mem0 library as a managed long-term memory layer. In PyAgenity, graph state serves as short-term (ephemeral, per run/session) memory, while a store implementation provides long-term, durable memory. This module wires up Mem0 so that:

  • astore / asearch / etc. map to Mem0's add, search, get_all, update, delete.
  • We maintain a generated UUID (framework memory id) separate from the Mem0 internal id.
  • Metadata is enriched to retain memory type, category, timestamps and app scoping.
  • The public async methods satisfy the BaseStore contract (astore, abatch_store, asearch, aget, aupdate, adelete, aforget_memory and arelease).
Design notes:

Mem0 (>= 0.2.x / 2025 spec) still exposes synchronous Python APIs. We off-load blocking calls to a thread executor to keep the interface awaitable. Where Mem0 does not support an operation directly (e.g. fetching by a custom memory id), we fall back to scanning get_all for the user. For batch insertion we parallelise add operations with gather while bounding concurrency (a simple semaphore) to avoid thread explosion.
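The bounded-concurrency approach described above can be sketched with asyncio.gather and a semaphore. This is an illustrative pattern, not Mem0 code; `add_one` is a hypothetical stand-in for an off-loaded blocking add:

```python
import asyncio


async def bounded_gather(items, worker, limit: int = 5):
    """Run worker(item) for every item, with at most `limit` running concurrently."""
    sem = asyncio.Semaphore(limit)

    async def guarded(item):
        async with sem:  # acquire a slot before starting the work
            return await worker(item)

    # gather preserves input order in its results
    return await asyncio.gather(*(guarded(i) for i in items))


async def add_one(text: str) -> str:
    # Hypothetical stand-in for a blocking Mem0 add off-loaded to a thread.
    await asyncio.sleep(0)
    return f"stored:{text}"


results = asyncio.run(bounded_gather(["a", "b", "c"], add_one, limit=2))
print(results)  # ['stored:a', 'stored:b', 'stored:c']
```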

The store interprets the config mapping passed to every method as: {"user_id": str, "thread_id": str, "app_id": str | None}. Both user_id and thread_id are required (a ValueError is raised if either is missing); app_id falls back to the store-level default. thread_id is stored into metadata under agent_id for backward compatibility with earlier implementations where agent_id served a similar role.

Prerequisite: install mem0.

pip install mem0ai
Optional vector DB / embedder / llm configuration should be supplied through Mem0's native configuration structure (see upstream docs - memory configuration, vector store configuration). You can also use helper factory function create_mem0_store_with_qdrant for quick Qdrant backing.
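For reference, the per-call config mapping described above has the following shape (the values here are illustrative):

```python
# Per-call scoping config accepted by Mem0Store methods (illustrative values).
config = {
    "user_id": "user-123",     # required: a missing user_id raises ValueError
    "thread_id": "thread-42",  # required: scopes memories to one conversation/run
    "app_id": None,            # optional: falls back to the store's app_id
}

# The store rejects calls that omit either required key.
required = ("user_id", "thread_id")
assert all(config.get(k) for k in required)
```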

Classes:

Name Description
Mem0Store

Mem0 implementation of long-term memory.

Functions:

Name Description
create_mem0_store

Factory for a basic Mem0 long-term store.

create_mem0_store_with_qdrant

Factory producing a Mem0Store configured for Qdrant backing.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
Mem0Store

Bases: BaseStore

Mem0 implementation of long-term memory.

Primary responsibilities:

  • Persist memories (episodic by default) across graph invocations
  • Retrieve semantically similar memories to augment state
  • Provide CRUD lifecycle aligned with the BaseStore async API

Unlike in-memory state, these memories survive process restarts as they are managed by Mem0's configured vector / persistence layer.

Methods:

Name Description
__init__
adelete
aforget_memory
aget
aget_all
arelease
asearch
asetup

Asynchronous setup method for the store.

astore
aupdate
delete

Synchronous wrapper for adelete that runs the async implementation.

forget_memory

Delete all memories for a user or agent.

generate_framework_id
get

Synchronous wrapper for aget that runs the async implementation.

get_all

Synchronous wrapper for aget_all that runs the async implementation.

release

Clean up any resources used by the store (override in subclasses if needed).

search

Synchronous wrapper for asearch that runs the async implementation.

setup

Synchronous setup method for the store.

store

Synchronous wrapper for astore that runs the async implementation.

update

Synchronous wrapper for aupdate that runs the async implementation.

Attributes:

Name Type Description
app_id
config
Source code in pyagenity/store/mem0_store.py
class Mem0Store(BaseStore):
    """Mem0 implementation of long-term memory.

    Primary responsibilities:
    * Persist memories (episodic by default) across graph invocations
    * Retrieve semantically similar memories to augment state
    * Provide CRUD lifecycle aligned with ``BaseStore`` async API

    Unlike in-memory state, these memories survive process restarts as they are
    managed by Mem0's configured vector / persistence layer.
    """

    def __init__(
        self,
        config: MemoryConfig | dict,
        app_id: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.config = config
        self.app_id = app_id or "pyagenity_app"
        self._client = None  # Lazy initialization

        logger.info(
            "Initialized Mem0Store (long-term) app=%s",
            self.app_id,
        )

    async def _get_client(self) -> AsyncMemory:
        """Lazy initialization of AsyncMemory client."""
        if self._client is None:
            try:
                # Prefer explicit config via Memory.from_config when supplied; fallback to defaults
                if isinstance(self.config, dict):
                    self._client = await AsyncMemory.from_config(self.config)
                elif isinstance(self.config, MemoryConfig):
                    self._client = AsyncMemory(config=self.config)
                else:
                    self._client = AsyncMemory()
            except Exception as e:  # pragma: no cover - defensive
                logger.error(f"Failed to initialize Mem0 client: {e}")
                raise
        return self._client

    # ---------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------

    def _extract_ids(self, config: dict[str, Any]) -> tuple[str, str | None, str | None]:
        """Extract user_id, thread_id, app_id from per-call config with fallbacks."""
        user_id = config.get("user_id")
        thread_id = config.get("thread_id")
        app_id = config.get("app_id") or self.app_id

        # if user id and thread id are not provided, we cannot proceed
        if not user_id:
            raise ValueError("user_id must be provided in config")

        if not thread_id:
            raise ValueError("thread_id must be provided in config")

        return user_id, thread_id, app_id

    def _create_result(
        self,
        raw: dict[str, Any],
        user_id: str,
    ) -> MemorySearchResult:
        # check user id belongs to the user
        if raw.get("user_id") != user_id:
            raise ValueError("Memory user_id does not match the requested user_id")

        metadata = raw.get("metadata", {}) or {}
        # Ensure memory_type enum mapping
        memory_type_val = metadata.get("memory_type", MemoryType.EPISODIC.value)
        try:
            memory_type = MemoryType(memory_type_val)
        except ValueError:
            memory_type = MemoryType.EPISODIC

        return MemorySearchResult(
            id=metadata.get("memory_id", str(raw.get("id", uuid4()))),
            content=raw.get("memory") or raw.get("data", ""),
            score=float(raw.get("score", 0.0) or 0.0),
            memory_type=memory_type,
            metadata=metadata,
            user_id=user_id,
            thread_id=metadata.get("run_id"),
        )

    def _iter_results(self, response: Any) -> Iterable[dict[str, Any]]:
        if isinstance(response, list):
            for item in response:
                if isinstance(item, dict):
                    yield item
        elif isinstance(response, dict) and "results" in response:
            for item in response["results"]:
                if isinstance(item, dict):
                    yield item
        else:  # pragma: no cover
            logger.debug("Unexpected Mem0 response type: %s", type(response))

    async def generate_framework_id(self) -> str:
        generated_id = InjectQ.get_instance().try_get("generated_id", str(uuid4()))
        if isinstance(generated_id, Awaitable):
            generated_id = await generated_id
        return generated_id

    # ------------------------------------------------------------------
    # BaseStore required async operations
    # ------------------------------------------------------------------

    async def astore(
        self,
        config: dict[str, Any],
        content: str | Message,
        memory_type: MemoryType = MemoryType.EPISODIC,
        category: str = "general",
        metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        text = content.text() if isinstance(content, Message) else str(content)
        if not text.strip():
            raise ValueError("Content cannot be empty")

        user_id, thread_id, app_id = self._extract_ids(config)

        mem_meta = {
            "memory_type": memory_type.value,
            "category": category,
            "created_at": datetime.now().isoformat(),
            **(metadata or {}),
        }

        client = await self._get_client()
        result = await client.add(  # type: ignore
            messages=[{"role": "user", "content": text}],
            user_id=user_id,
            agent_id=app_id,
            run_id=thread_id,
            metadata=mem_meta,
        )

        logger.debug("Stored memory for user=%s thread=%s id=%s", user_id, thread_id, result)

        return result

    async def asearch(
        self,
        config: dict[str, Any],
        query: str,
        memory_type: MemoryType | None = None,
        category: str | None = None,
        limit: int = 10,
        score_threshold: float | None = None,
        filters: dict[str, Any] | None = None,
        retrieval_strategy=None,  # Unused for Mem0; kept for signature parity
        distance_metric=None,  # Unused
        max_tokens: int = 4000,
        **kwargs: Any,
    ) -> list[MemorySearchResult]:
        user_id, thread_id, app_id = self._extract_ids(config)

        client = await self._get_client()
        result = await client.search(  # type: ignore
            query=query,
            user_id=user_id,
            agent_id=app_id,
            limit=limit,
            filters=filters,
            threshold=score_threshold,
        )

        if "original_results" not in result:
            logger.warning("Mem0 search response missing 'original_results': %s", result)
            return []

        if "relations" in result:
            logger.warning(
                "Mem0 search response contains 'relations', which is not supported yet: %s",
                result,
            )

        out: list[MemorySearchResult] = [
            self._create_result(raw, user_id) for raw in result["original_results"]
        ]

        logger.debug(
            "Searched memories for user=%s thread=%s query=%s found=%d",
            user_id,
            thread_id,
            query,
            len(out),
        )
        return out

    async def aget(
        self,
        config: dict[str, Any],
        memory_id: str,
        **kwargs: Any,
    ) -> MemorySearchResult | None:
        user_id, _, _ = self._extract_ids(config)
        # If we stored mapping use that user id instead (authoritative)

        client = await self._get_client()
        result = await client.get(  # type: ignore
            memory_id=memory_id,
        )

        return self._create_result(result, user_id) if result else None

    async def aget_all(
        self,
        config: dict[str, Any],
        limit: int = 100,
        **kwargs: Any,
    ) -> list[MemorySearchResult]:
        user_id, thread_id, app_id = self._extract_ids(config)

        client = await self._get_client()
        result = await client.get_all(  # type: ignore
            user_id=user_id,
            agent_id=app_id,
            limit=limit,
        )

        if "results" not in result:
            logger.warning("Mem0 get_all response missing 'results': %s", result)
            return []

        if "relations" in result:
            logger.warning(
                "Mem0 get_all response contains 'relations', which is not supported yet: %s",
                result,
            )

        out: list[MemorySearchResult] = [
            self._create_result(raw, user_id) for raw in result["results"]
        ]

        logger.debug(
            "Fetched all memories for user=%s thread=%s count=%d",
            user_id,
            thread_id,
            len(out),
        )
        return out

    async def aupdate(
        self,
        config: dict[str, Any],
        memory_id: str,
        content: str | Message,
        metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        existing = await self.aget(config, memory_id)
        if not existing:
            raise ValueError(f"Memory {memory_id} not found")

        # user_id obtained for potential permission checks (not used by Mem0 update directly)

        new_text = content.text() if isinstance(content, Message) else str(content)
        updated_meta = {**(existing.metadata or {}), **(metadata or {})}
        updated_meta["updated_at"] = datetime.now().isoformat()

        client = await self._get_client()
        res = await client.update(  # type: ignore
            memory_id=existing.id,
            data=new_text,
        )

        logger.debug("Updated memory %s via recreate", memory_id)
        return res

    async def adelete(
        self,
        config: dict[str, Any],
        memory_id: str,
        **kwargs: Any,
    ) -> Any:
        user_id, _, _ = self._extract_ids(config)
        existing = await self.aget(config, memory_id)
        if not existing:
            logger.warning("Memory %s not found for deletion", memory_id)
            return {
                "deleted": False,
                "reason": "not_found",
            }

        if existing.user_id != user_id:
            raise ValueError("Cannot delete memory belonging to a different user")

        client = await self._get_client()
        res = await client.delete(  # type: ignore
            memory_id=existing.id,
        )

        logger.debug("Deleted memory %s for user %s", memory_id, user_id)
        return res

    async def aforget_memory(
        self,
        config: dict[str, Any],
        **kwargs: Any,
    ) -> Any:
        # Delete all memories for a user
        user_id, _, _ = self._extract_ids(config)
        client = await self._get_client()
        res = await client.delete_all(user_id=user_id)  # type: ignore
        logger.debug("Forgot all memories for user %s", user_id)
        return res

    async def arelease(self) -> None:
        logger.info("Mem0Store released resources")
Attributes
app_id instance-attribute
app_id = app_id or 'pyagenity_app'
config instance-attribute
config = config
Functions
__init__
__init__(config, app_id=None, **kwargs)
Source code in pyagenity/store/mem0_store.py
def __init__(
    self,
    config: MemoryConfig | dict,
    app_id: str | None = None,
    **kwargs: Any,
) -> None:
    super().__init__()
    self.config = config
    self.app_id = app_id or "pyagenity_app"
    self._client = None  # Lazy initialization

    logger.info(
        "Initialized Mem0Store (long-term) app=%s",
        self.app_id,
    )
adelete async
adelete(config, memory_id, **kwargs)
Source code in pyagenity/store/mem0_store.py
async def adelete(
    self,
    config: dict[str, Any],
    memory_id: str,
    **kwargs: Any,
) -> Any:
    user_id, _, _ = self._extract_ids(config)
    existing = await self.aget(config, memory_id)
    if not existing:
        logger.warning("Memory %s not found for deletion", memory_id)
        return {
            "deleted": False,
            "reason": "not_found",
        }

    if existing.user_id != user_id:
        raise ValueError("Cannot delete memory belonging to a different user")

    client = await self._get_client()
    res = await client.delete(  # type: ignore
        memory_id=existing.id,
    )

    logger.debug("Deleted memory %s for user %s", memory_id, user_id)
    return res
aforget_memory async
aforget_memory(config, **kwargs)
Source code in pyagenity/store/mem0_store.py
async def aforget_memory(
    self,
    config: dict[str, Any],
    **kwargs: Any,
) -> Any:
    # Delete all memories for a user
    user_id, _, _ = self._extract_ids(config)
    client = await self._get_client()
    res = await client.delete_all(user_id=user_id)  # type: ignore
    logger.debug("Forgot all memories for user %s", user_id)
    return res
aget async
aget(config, memory_id, **kwargs)
Source code in pyagenity/store/mem0_store.py
async def aget(
    self,
    config: dict[str, Any],
    memory_id: str,
    **kwargs: Any,
) -> MemorySearchResult | None:
    user_id, _, _ = self._extract_ids(config)
    # If we stored mapping use that user id instead (authoritative)

    client = await self._get_client()
    result = await client.get(  # type: ignore
        memory_id=memory_id,
    )

    return self._create_result(result, user_id) if result else None
aget_all async
aget_all(config, limit=100, **kwargs)
Source code in pyagenity/store/mem0_store.py
async def aget_all(
    self,
    config: dict[str, Any],
    limit: int = 100,
    **kwargs: Any,
) -> list[MemorySearchResult]:
    user_id, thread_id, app_id = self._extract_ids(config)

    client = await self._get_client()
    result = await client.get_all(  # type: ignore
        user_id=user_id,
        agent_id=app_id,
        limit=limit,
    )

    if "results" not in result:
        logger.warning("Mem0 get_all response missing 'results': %s", result)
        return []

    if "relations" in result:
        logger.warning(
            "Mem0 get_all response contains 'relations', which is not supported yet: %s",
            result,
        )

    out: list[MemorySearchResult] = [
        self._create_result(raw, user_id) for raw in result["results"]
    ]

    logger.debug(
        "Fetched all memories for user=%s thread=%s count=%d",
        user_id,
        thread_id,
        len(out),
    )
    return out
arelease async
arelease()
Source code in pyagenity/store/mem0_store.py
async def arelease(self) -> None:
    logger.info("Mem0Store released resources")
asearch async
asearch(config, query, memory_type=None, category=None, limit=10, score_threshold=None, filters=None, retrieval_strategy=None, distance_metric=None, max_tokens=4000, **kwargs)
Source code in pyagenity/store/mem0_store.py
async def asearch(
    self,
    config: dict[str, Any],
    query: str,
    memory_type: MemoryType | None = None,
    category: str | None = None,
    limit: int = 10,
    score_threshold: float | None = None,
    filters: dict[str, Any] | None = None,
    retrieval_strategy=None,  # Unused for Mem0; kept for signature parity
    distance_metric=None,  # Unused
    max_tokens: int = 4000,
    **kwargs: Any,
) -> list[MemorySearchResult]:
    user_id, thread_id, app_id = self._extract_ids(config)

    client = await self._get_client()
    result = await client.search(  # type: ignore
        query=query,
        user_id=user_id,
        agent_id=app_id,
        limit=limit,
        filters=filters,
        threshold=score_threshold,
    )

    if "original_results" not in result:
        logger.warning("Mem0 search response missing 'original_results': %s", result)
        return []

    if "relations" in result:
        logger.warning(
            "Mem0 search response contains 'relations', which is not supported yet: %s",
            result,
        )

    out: list[MemorySearchResult] = [
        self._create_result(raw, user_id) for raw in result["original_results"]
    ]

    logger.debug(
        "Searched memories for user=%s thread=%s query=%s found=%d",
        user_id,
        thread_id,
        query,
        len(out),
    )
    return out
asetup async
asetup()

Asynchronous setup method for the store.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/store/base_store.py
async def asetup(self) -> Any:
    """
    Asynchronous setup method for the store.

    Returns:
        Any: Implementation-defined setup result.
    """
    raise NotImplementedError
astore async
astore(config, content, memory_type=MemoryType.EPISODIC, category='general', metadata=None, **kwargs)
Source code in pyagenity/store/mem0_store.py
async def astore(
    self,
    config: dict[str, Any],
    content: str | Message,
    memory_type: MemoryType = MemoryType.EPISODIC,
    category: str = "general",
    metadata: dict[str, Any] | None = None,
    **kwargs: Any,
) -> Any:
    text = content.text() if isinstance(content, Message) else str(content)
    if not text.strip():
        raise ValueError("Content cannot be empty")

    user_id, thread_id, app_id = self._extract_ids(config)

    mem_meta = {
        "memory_type": memory_type.value,
        "category": category,
        "created_at": datetime.now().isoformat(),
        **(metadata or {}),
    }

    client = await self._get_client()
    result = await client.add(  # type: ignore
        messages=[{"role": "user", "content": text}],
        user_id=user_id,
        agent_id=app_id,
        run_id=thread_id,
        metadata=mem_meta,
    )

    logger.debug("Stored memory for user=%s thread=%s id=%s", user_id, thread_id, result)

    return result
aupdate async
aupdate(config, memory_id, content, metadata=None, **kwargs)
Source code in pyagenity/store/mem0_store.py
async def aupdate(
    self,
    config: dict[str, Any],
    memory_id: str,
    content: str | Message,
    metadata: dict[str, Any] | None = None,
    **kwargs: Any,
) -> Any:
    existing = await self.aget(config, memory_id)
    if not existing:
        raise ValueError(f"Memory {memory_id} not found")

    # user_id obtained for potential permission checks (not used by Mem0 update directly)

    new_text = content.text() if isinstance(content, Message) else str(content)
    updated_meta = {**(existing.metadata or {}), **(metadata or {})}
    updated_meta["updated_at"] = datetime.now().isoformat()

    client = await self._get_client()
    res = await client.update(  # type: ignore
        memory_id=existing.id,
        data=new_text,
    )

    logger.debug("Updated memory %s via recreate", memory_id)
    return res
delete
delete(config, memory_id, **kwargs)

Synchronous wrapper for adelete that runs the async implementation.

Source code in pyagenity/store/base_store.py
def delete(self, config: dict[str, Any], memory_id: str, **kwargs) -> None:
    """Synchronous wrapper for `adelete` that runs the async implementation."""
    return run_coroutine(self.adelete(config, memory_id, **kwargs))
forget_memory
forget_memory(config, **kwargs)

Delete all memories for a user or agent.

Source code in pyagenity/store/base_store.py
def forget_memory(
    self,
    config: dict[str, Any],
    **kwargs,
) -> Any:
    """Delete all memories for a user or agent."""
    return run_coroutine(self.aforget_memory(config, **kwargs))
generate_framework_id async
generate_framework_id()
Source code in pyagenity/store/mem0_store.py
async def generate_framework_id(self) -> str:
    generated_id = InjectQ.get_instance().try_get("generated_id", str(uuid4()))
    if isinstance(generated_id, Awaitable):
        generated_id = await generated_id
    return generated_id
get
get(config, memory_id, **kwargs)

Synchronous wrapper for aget that runs the async implementation.

Source code in pyagenity/store/base_store.py
def get(self, config: dict[str, Any], memory_id: str, **kwargs) -> MemorySearchResult | None:
    """Synchronous wrapper for `aget` that runs the async implementation."""
    return run_coroutine(self.aget(config, memory_id, **kwargs))
get_all
get_all(config, limit=100, **kwargs)

Synchronous wrapper for aget_all that runs the async implementation.

Source code in pyagenity/store/base_store.py
def get_all(
    self,
    config: dict[str, Any],
    limit: int = 100,
    **kwargs,
) -> list[MemorySearchResult]:
    """Synchronous wrapper for `aget_all` that runs the async implementation."""
    return run_coroutine(self.aget_all(config, limit=limit, **kwargs))
release
release()

Clean up any resources used by the store (override in subclasses if needed).

Source code in pyagenity/store/base_store.py
def release(self) -> None:
    """Clean up any resources used by the store (override in subclasses if needed)."""
    return run_coroutine(self.arelease())
search
search(config, query, memory_type=None, category=None, limit=10, score_threshold=None, filters=None, retrieval_strategy=RetrievalStrategy.SIMILARITY, distance_metric=DistanceMetric.COSINE, max_tokens=4000, **kwargs)

Synchronous wrapper for asearch that runs the async implementation.

Source code in pyagenity/store/base_store.py
def search(
    self,
    config: dict[str, Any],
    query: str,
    memory_type: MemoryType | None = None,
    category: str | None = None,
    limit: int = 10,
    score_threshold: float | None = None,
    filters: dict[str, Any] | None = None,
    retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
    distance_metric: DistanceMetric = DistanceMetric.COSINE,
    max_tokens: int = 4000,
    **kwargs,
) -> list[MemorySearchResult]:
    """Synchronous wrapper for `asearch` that runs the async implementation."""
    return run_coroutine(
        self.asearch(
            config,
            query,
            memory_type=memory_type,
            category=category,
            limit=limit,
            score_threshold=score_threshold,
            filters=filters,
            retrieval_strategy=retrieval_strategy,
            distance_metric=distance_metric,
            max_tokens=max_tokens,
            **kwargs,
        )
    )
setup
setup()

Synchronous setup method for the store.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/store/base_store.py
def setup(self) -> Any:
    """
    Synchronous setup method for the store.

    Returns:
        Any: Implementation-defined setup result.
    """
    return run_coroutine(self.asetup())
store
store(config, content, memory_type=MemoryType.EPISODIC, category='general', metadata=None, **kwargs)

Synchronous wrapper for astore that runs the async implementation.

Source code in pyagenity/store/base_store.py
def store(
    self,
    config: dict[str, Any],
    content: str | Message,
    memory_type: MemoryType = MemoryType.EPISODIC,
    category: str = "general",
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> str:
    """Synchronous wrapper for `astore` that runs the async implementation."""
    return run_coroutine(
        self.astore(
            config,
            content,
            memory_type=memory_type,
            category=category,
            metadata=metadata,
            **kwargs,
        )
    )
update
update(config, memory_id, content, metadata=None, **kwargs)

Synchronous wrapper for aupdate that runs the async implementation.

Source code in pyagenity/store/base_store.py
def update(
    self,
    config: dict[str, Any],
    memory_id: str,
    content: str | Message,
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> Any:
    """Synchronous wrapper for `aupdate` that runs the async implementation."""
    return run_coroutine(self.aupdate(config, memory_id, content, metadata=metadata, **kwargs))
Functions
create_mem0_store
create_mem0_store(config, user_id='default_user', thread_id=None, app_id='pyagenity_app')

Factory for a basic Mem0 long-term store.

Source code in pyagenity/store/mem0_store.py
def create_mem0_store(
    config: dict[str, Any],
    user_id: str = "default_user",
    thread_id: str | None = None,
    app_id: str = "pyagenity_app",
) -> Mem0Store:
    """Factory for a basic Mem0 long-term store."""
    return Mem0Store(
        config=config,
        default_user_id=user_id,
        default_thread_id=thread_id,
        app_id=app_id,
    )
create_mem0_store_with_qdrant
create_mem0_store_with_qdrant(qdrant_url, qdrant_api_key=None, collection_name='pyagenity_memories', embedding_model='text-embedding-ada-002', llm_model='gpt-4o-mini', app_id='pyagenity_app', **kwargs)

Factory producing a Mem0Store configured for Qdrant backing.

Source code in pyagenity/store/mem0_store.py
def create_mem0_store_with_qdrant(
    qdrant_url: str,
    qdrant_api_key: str | None = None,
    collection_name: str = "pyagenity_memories",
    embedding_model: str = "text-embedding-ada-002",
    llm_model: str = "gpt-4o-mini",
    app_id: str = "pyagenity_app",
    **kwargs: Any,
) -> Mem0Store:
    """Factory producing a Mem0Store configured for Qdrant backing."""
    config = {
        "vector_store": {
            "provider": "qdrant",
            "config": {
                "collection_name": collection_name,
                "url": qdrant_url,
                "api_key": qdrant_api_key,
                **kwargs.get("vector_store_config", {}),
            },
        },
        "embedder": {
            "provider": kwargs.get("embedder_provider", "openai"),
            "config": {"model": embedding_model, **kwargs.get("embedder_config", {})},
        },
        "llm": {
            "provider": kwargs.get("llm_provider", "openai"),
            "config": {"model": llm_model, **kwargs.get("llm_config", {})},
        },
    }
    return create_mem0_store(
        config=config,
        app_id=app_id,
    )
qdrant_store

Qdrant Vector Store Implementation for PyAgenity Framework

This module provides a modern, async-first implementation of BaseStore using Qdrant as the backend vector database. Supports both local and cloud Qdrant deployments with configurable embedding services.

Classes:

Name Description
QdrantStore

Modern async-first Qdrant-based vector store implementation.

Functions:

Name Description
create_cloud_qdrant_store

Create a cloud Qdrant store.

create_local_qdrant_store

Create a local Qdrant store.

create_remote_qdrant_store

Create a remote Qdrant store.

Attributes:

Name Type Description
logger
msg
Attributes
logger module-attribute
logger = getLogger(__name__)
msg module-attribute
msg = "Qdrant client not installed. Install with: pip install 'pyagenity[qdrant]'"
Classes
QdrantStore

Bases: BaseStore

Modern async-first Qdrant-based vector store implementation.

Features: - Async-only operations for better performance - Local and cloud Qdrant deployment support - Configurable embedding services - Efficient vector similarity search with multiple strategies - Collection management with automatic creation - Rich metadata filtering capabilities - User and agent-scoped operations

Example
# Local Qdrant with OpenAI embeddings
store = QdrantStore(path="./qdrant_data", embedding_service=OpenAIEmbeddingService())

# Remote Qdrant
store = QdrantStore(host="localhost", port=6333, embedding_service=OpenAIEmbeddingService())

# Cloud Qdrant
store = QdrantStore(
    url="https://xyz.qdrant.io",
    api_key="your-api-key",
    embedding_service=OpenAIEmbeddingService(),
)

Methods:

Name Description
__init__

Initialize Qdrant vector store.

adelete

Delete a memory by ID.

aforget_memory

Delete all memories for a user or agent.

aget

Get a specific memory by ID.

aget_all

Get all memories for a user.

arelease

Clean up resources.

asearch

Search memories by content similarity.

asetup

Set up the store and ensure default collection exists.

astore

Store a new memory.

aupdate

Update an existing memory.

delete

Synchronous wrapper for adelete that runs the async implementation.

forget_memory

Synchronous wrapper for aforget_memory that runs the async implementation.

get

Synchronous wrapper for aget that runs the async implementation.

get_all

Synchronous wrapper for aget_all that runs the async implementation.

release

Clean up any resources used by the store (override in subclasses if needed).

search

Synchronous wrapper for asearch that runs the async implementation.

setup

Synchronous setup method for the store.

store

Synchronous wrapper for astore that runs the async implementation.

update

Synchronous wrapper for aupdate that runs the async implementation.

Attributes:

Name Type Description
client
default_collection
embedding
Source code in pyagenity/store/qdrant_store.py
class QdrantStore(BaseStore):
    """
    Modern async-first Qdrant-based vector store implementation.

    Features:
    - Async-only operations for better performance
    - Local and cloud Qdrant deployment support
    - Configurable embedding services
    - Efficient vector similarity search with multiple strategies
    - Collection management with automatic creation
    - Rich metadata filtering capabilities
    - User and agent-scoped operations

    Example:
        ```python
        # Local Qdrant with OpenAI embeddings
        store = QdrantStore(path="./qdrant_data", embedding_service=OpenAIEmbeddingService())

        # Remote Qdrant
        store = QdrantStore(host="localhost", port=6333, embedding_service=OpenAIEmbeddingService())

        # Cloud Qdrant
        store = QdrantStore(
            url="https://xyz.qdrant.io",
            api_key="your-api-key",
            embedding_service=OpenAIEmbeddingService(),
        )
        ```
    """

    def __init__(
        self,
        embedding: BaseEmbedding,
        path: str | None = None,
        host: str | None = None,
        port: int | None = None,
        url: str | None = None,
        api_key: str | None = None,
        default_collection: str = "pyagenity_memories",
        distance_metric: DistanceMetric = DistanceMetric.COSINE,
        **kwargs: Any,
    ):
        """
        Initialize Qdrant vector store.

        Args:
            embedding: Service for generating embeddings
            path: Path for local Qdrant (file-based storage)
            host: Host for remote Qdrant server
            port: Port for remote Qdrant server
            url: URL for Qdrant cloud
            api_key: API key for Qdrant cloud
            default_collection: Default collection name
            distance_metric: Default distance metric
            **kwargs: Additional client parameters
        """
        self.embedding = embedding

        # Initialize async client
        if path:
            self.client = AsyncQdrantClient(path=path, **kwargs)
        elif url:
            self.client = AsyncQdrantClient(url=url, api_key=api_key, **kwargs)
        else:
            host = host or "localhost"
            port = port or 6333
            self.client = AsyncQdrantClient(host=host, port=port, api_key=api_key, **kwargs)

        # Cache for collection existence checks
        self._collection_cache = set()
        self._setup_lock = asyncio.Lock()

        self.default_collection = default_collection
        self._default_distance_metric = distance_metric

        logger.info(f"Initialized QdrantStore with config: path={path}, host={host}, url={url}")

    async def asetup(self) -> Any:
        """Set up the store and ensure default collection exists."""
        async with self._setup_lock:
            await self._ensure_collection_exists(self.default_collection)
        return True

    def _distance_metric_to_qdrant(self, metric: DistanceMetric) -> Distance:
        """Convert framework distance metric to Qdrant distance."""
        mapping = {
            DistanceMetric.COSINE: Distance.COSINE,
            DistanceMetric.EUCLIDEAN: Distance.EUCLID,
            DistanceMetric.DOT_PRODUCT: Distance.DOT,
            DistanceMetric.MANHATTAN: Distance.MANHATTAN,
        }
        return mapping.get(metric, Distance.COSINE)

    def _extract_config_values(self, config: dict[str, Any]) -> tuple[str | None, str | None, str]:
        """Extract user_id, thread_id, and collection from config."""
        user_id = config.get("user_id")
        thread_id = config.get("thread_id")
        collection = config.get("collection", self.default_collection)
        return user_id, thread_id, collection

    def _point_to_search_result(self, point) -> MemorySearchResult:
        """Convert Qdrant point to MemorySearchResult."""
        payload = getattr(point, "payload", {}) or {}

        # Extract content
        content = payload.get("content", "")

        # Convert memory_type string back to enum
        memory_type_str = payload.get("memory_type", "episodic")
        try:
            memory_type = MemoryType(memory_type_str)
        except ValueError:
            memory_type = MemoryType.EPISODIC

        # Parse timestamp
        timestamp_str = payload.get("timestamp")
        timestamp = None
        if timestamp_str:
            try:
                timestamp = datetime.fromisoformat(timestamp_str)
            except (ValueError, TypeError):
                timestamp = None

        return MemorySearchResult(
            id=str(point.id),
            content=content,
            score=float(getattr(point, "score", 1.0) or 0.0),
            memory_type=memory_type,
            metadata=payload,
            vector=getattr(point, "vector", None),
            user_id=payload.get("user_id"),
            thread_id=payload.get("thread_id")
            or payload.get("agent_id"),  # Support both thread_id and agent_id
            timestamp=timestamp,
        )

    def _build_qdrant_filter(
        self,
        user_id: str | None = None,
        thread_id: str | None = None,
        memory_type: MemoryType | None = None,
        category: str | None = None,
        filters: dict[str, Any] | None = None,
    ) -> Filter | None:
        """Build Qdrant filter from parameters."""
        conditions = []

        # Add user/agent filters
        if user_id:
            conditions.append(
                FieldCondition(
                    key="user_id",
                    match=MatchValue(value=user_id),
                ),
            )
        if thread_id:
            conditions.append(
                FieldCondition(
                    key="thread_id",
                    match=MatchValue(value=thread_id),
                ),
            )
        if memory_type:
            conditions.append(
                FieldCondition(
                    key="memory_type",
                    match=MatchValue(value=memory_type.value),
                )
            )
        if category:
            conditions.append(FieldCondition(key="category", match=MatchValue(value=category)))

        # Add custom filters
        if filters:
            for key, value in filters.items():
                if isinstance(value, str | int | bool):
                    conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))

        return Filter(must=conditions) if conditions else None

    async def _ensure_collection_exists(self, collection: str) -> None:
        """Ensure collection exists, create if not."""
        if collection in self._collection_cache:
            return

        try:
            # Check if collection exists
            collections = await self.client.get_collections()
            existing_names = {col.name for col in collections.collections}

            if collection not in existing_names:
                # Create collection with vector configuration
                await self.client.create_collection(
                    collection_name=collection,
                    vectors_config=VectorParams(
                        size=self.embedding.dimension,
                        distance=self._distance_metric_to_qdrant(
                            self._default_distance_metric,
                        ),
                    ),
                )
                logger.info(f"Created collection: {collection}")

            self._collection_cache.add(collection)
        except Exception as e:
            logger.error(f"Error ensuring collection {collection} exists: {e}")
            raise

    def _prepare_content(self, content: str | Message) -> str:
        """Extract text content from string or Message."""
        if isinstance(content, Message):
            return content.text()
        return content

    def _create_memory_record(
        self,
        content: str | Message,
        user_id: str | None = None,
        thread_id: str | None = None,
        memory_type: MemoryType = MemoryType.EPISODIC,
        category: str = "general",
        metadata: dict[str, Any] | None = None,
    ) -> MemoryRecord:
        """Create a memory record from parameters."""
        text_content = self._prepare_content(content)

        if isinstance(content, Message):
            return MemoryRecord.from_message(
                content,
                user_id=user_id,
                thread_id=thread_id,
                additional_metadata=metadata,
            )

        return MemoryRecord(
            content=text_content,
            user_id=user_id,
            thread_id=thread_id,
            memory_type=memory_type,
            metadata=metadata or {},
            category=category,
        )

    # --- BaseStore abstract method implementations ---

    async def astore(
        self,
        config: dict[str, Any],
        content: str | Message,
        memory_type: MemoryType = MemoryType.EPISODIC,
        category: str = "general",
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> str:
        """Store a new memory."""
        user_id, thread_id, collection = self._extract_config_values(config)

        # Ensure collection exists
        await self._ensure_collection_exists(collection)

        # Create memory record
        record = self._create_memory_record(
            content=content,
            user_id=user_id,
            thread_id=thread_id,
            memory_type=memory_type,
            category=category,
            metadata=metadata,
        )

        # Generate embedding
        text_content = self._prepare_content(content)
        vector = await self.embedding.aembed(text_content)
        if not vector or len(vector) != self.embedding.dimension:
            raise ValueError("Embedding service returned invalid vector")

        # Prepare payload
        payload = {
            "content": record.content,
            "user_id": record.user_id,
            "thread_id": record.thread_id,
            "memory_type": record.memory_type.value,
            "category": record.category,
            "timestamp": record.timestamp.isoformat() if record.timestamp else None,
            **record.metadata,
        }

        # Create point
        point = PointStruct(
            id=record.id,
            vector=vector,
            payload=payload,
        )

        # Store in Qdrant
        await self.client.upsert(
            collection_name=collection,
            points=[point],
        )

        logger.debug(f"Stored memory {record.id} in collection {collection}")
        return record.id

    async def asearch(
        self,
        config: dict[str, Any],
        query: str,
        memory_type: MemoryType | None = None,
        category: str | None = None,
        limit: int = 10,
        score_threshold: float | None = None,
        filters: dict[str, Any] | None = None,
        retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
        distance_metric: DistanceMetric = DistanceMetric.COSINE,
        max_tokens: int = 4000,
        **kwargs,
    ) -> list[MemorySearchResult]:
        """Search memories by content similarity."""
        user_id, thread_id, collection = self._extract_config_values(config)

        # Ensure collection exists
        await self._ensure_collection_exists(collection)

        # Generate query embedding
        query_vector = await self.embedding.aembed(query)
        if not query_vector or len(query_vector) != self.embedding.dimension:
            raise ValueError("Embedding service returned invalid vector")

        # Build filter
        search_filter = self._build_qdrant_filter(
            user_id=user_id,
            thread_id=thread_id,
            memory_type=memory_type,
            category=category,
            filters=filters,
        )

        # Perform search
        search_result = await self.client.search(
            collection_name=collection,
            query_vector=query_vector,
            query_filter=search_filter,
            limit=limit,
            score_threshold=score_threshold,
        )

        # Convert to search results
        results = [self._point_to_search_result(point) for point in search_result]

        logger.debug(f"Found {len(results)} memories for query in collection {collection}")
        return results

    async def aget(
        self,
        config: dict[str, Any],
        memory_id: str,
        **kwargs,
    ) -> MemorySearchResult | None:
        """Get a specific memory by ID."""
        user_id, thread_id, collection = self._extract_config_values(config)

        try:
            # Ensure collection exists
            await self._ensure_collection_exists(collection)

            # Get point by ID
            points = await self.client.retrieve(
                collection_name=collection,
                ids=[memory_id],
            )

            if not points:
                return None

            point = points[0]
            result = self._point_to_search_result(point)

            # Verify user/agent access if specified
            if user_id and result.user_id != user_id:
                return None
            if thread_id and result.thread_id != thread_id:
                return None

            return result

        except Exception as e:
            logger.error(f"Error retrieving memory {memory_id}: {e}")
            return None

    async def aget_all(
        self,
        config: dict[str, Any],
        limit: int = 100,
        **kwargs,
    ) -> list[MemorySearchResult]:
        """Get all memories for a user."""
        user_id, _, collection = self._extract_config_values(config)

        # Ensure collection exists
        await self._ensure_collection_exists(collection)

        # Build filter
        search_filter = self._build_qdrant_filter(
            user_id=user_id,
        )

        # Scroll through matching points; a filter-only listing needs no
        # query vector (calling search with an empty vector would fail)
        points, _ = await self.client.scroll(
            collection_name=collection,
            scroll_filter=search_filter,
            limit=limit,
        )

        # Convert to search results
        results = [self._point_to_search_result(point) for point in points]

        logger.debug(f"Found {len(results)} memories for user in collection {collection}")
        return results

    async def aupdate(
        self,
        config: dict[str, Any],
        memory_id: str,
        content: str | Message,
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Update an existing memory."""
        user_id, thread_id, collection = self._extract_config_values(config)

        # Get existing memory
        existing = await self.aget(config, memory_id)
        if not existing:
            raise ValueError(f"Memory {memory_id} not found")

        # Verify user/agent access if specified
        if user_id and existing.user_id != user_id:
            raise PermissionError("User does not have permission to update this memory")
        if thread_id and existing.thread_id != thread_id:
            raise PermissionError("Thread does not have permission to update this memory")

        # Prepare new content
        text_content = self._prepare_content(content)
        new_vector = await self.embedding.aembed(text_content)
        if not new_vector or len(new_vector) != self.embedding.dimension:
            raise ValueError("Embedding service returned invalid vector")

        # Update payload
        updated_metadata = {**existing.metadata}
        if metadata:
            updated_metadata.update(metadata)

        updated_payload = {
            "content": text_content,
            "user_id": existing.user_id,
            "thread_id": existing.thread_id,
            "memory_type": existing.memory_type.value,
            "category": updated_metadata.get("category", "general"),
            "timestamp": datetime.now().isoformat(),
            **updated_metadata,
        }

        # Create updated point
        point = PointStruct(
            id=memory_id,
            vector=new_vector,
            payload=updated_payload,
        )

        # Update in Qdrant
        await self.client.upsert(
            collection_name=collection,
            points=[point],
        )

        logger.debug(f"Updated memory {memory_id} in collection {collection}")

    async def adelete(
        self,
        config: dict[str, Any],
        memory_id: str,
        **kwargs,
    ) -> None:
        """Delete a memory by ID."""
        user_id, thread_id, collection = self._extract_config_values(config)

        # Verify memory exists and user has access
        existing = await self.aget(config, memory_id)
        if not existing:
            raise ValueError(f"Memory {memory_id} not found")

        # verify user/agent access if specified
        if user_id and existing.user_id != user_id:
            raise PermissionError("User does not have permission to delete this memory")
        if thread_id and existing.thread_id != thread_id:
            raise PermissionError("Thread does not have permission to delete this memory")

        # Delete from Qdrant
        await self.client.delete(
            collection_name=collection,
            points_selector=models.PointIdsList(points=[memory_id]),
        )

        logger.debug(f"Deleted memory {memory_id} from collection {collection}")

    async def aforget_memory(
        self,
        config: dict[str, Any],
        **kwargs,
    ) -> None:
        """Delete all memories for a user or agent."""
        user_id, agent_id, collection = self._extract_config_values(config)

        # Build filter for memories to delete
        delete_filter = self._build_qdrant_filter(user_id=user_id, thread_id=agent_id)

        if delete_filter:
            # Delete matching memories
            await self.client.delete(
                collection_name=collection,
                points_selector=models.FilterSelector(filter=delete_filter),
            )

            logger.info(
                f"Deleted all memories for user_id={user_id}, agent_id={agent_id} "
                f"in collection {collection}"
            )
        else:
            logger.warning("No user_id or agent_id specified for memory deletion")

    async def arelease(self) -> None:
        """Clean up resources."""
        if hasattr(self.client, "close"):
            await self.client.close()
        logger.info("QdrantStore resources released")
Attributes
client instance-attribute
client = AsyncQdrantClient(path=path, **kwargs)
default_collection instance-attribute
default_collection = default_collection
embedding instance-attribute
embedding = embedding
Functions
__init__
__init__(embedding, path=None, host=None, port=None, url=None, api_key=None, default_collection='pyagenity_memories', distance_metric=DistanceMetric.COSINE, **kwargs)

Initialize Qdrant vector store.

Parameters:

Name Type Description Default
embedding BaseEmbedding

Service for generating embeddings

required
path str | None

Path for local Qdrant (file-based storage)

None
host str | None

Host for remote Qdrant server

None
port int | None

Port for remote Qdrant server

None
url str | None

URL for Qdrant cloud

None
api_key str | None

API key for Qdrant cloud

None
default_collection str

Default collection name

'pyagenity_memories'
distance_metric DistanceMetric

Default distance metric

COSINE
**kwargs Any

Additional client parameters

{}
Source code in pyagenity/store/qdrant_store.py
def __init__(
    self,
    embedding: BaseEmbedding,
    path: str | None = None,
    host: str | None = None,
    port: int | None = None,
    url: str | None = None,
    api_key: str | None = None,
    default_collection: str = "pyagenity_memories",
    distance_metric: DistanceMetric = DistanceMetric.COSINE,
    **kwargs: Any,
):
    """
    Initialize Qdrant vector store.

    Args:
        embedding: Service for generating embeddings
        path: Path for local Qdrant (file-based storage)
        host: Host for remote Qdrant server
        port: Port for remote Qdrant server
        url: URL for Qdrant cloud
        api_key: API key for Qdrant cloud
        default_collection: Default collection name
        distance_metric: Default distance metric
        **kwargs: Additional client parameters
    """
    self.embedding = embedding

    # Initialize async client
    if path:
        self.client = AsyncQdrantClient(path=path, **kwargs)
    elif url:
        self.client = AsyncQdrantClient(url=url, api_key=api_key, **kwargs)
    else:
        host = host or "localhost"
        port = port or 6333
        self.client = AsyncQdrantClient(host=host, port=port, api_key=api_key, **kwargs)

    # Cache for collection existence checks
    self._collection_cache = set()
    self._setup_lock = asyncio.Lock()

    self.default_collection = default_collection
    self._default_distance_metric = distance_metric

    logger.info(f"Initialized QdrantStore with config: path={path}, host={host}, url={url}")
adelete async
adelete(config, memory_id, **kwargs)

Delete a memory by ID.

Source code in pyagenity/store/qdrant_store.py
async def adelete(
    self,
    config: dict[str, Any],
    memory_id: str,
    **kwargs,
) -> None:
    """Delete a memory by ID."""
    user_id, thread_id, collection = self._extract_config_values(config)

    # Verify memory exists and user has access
    existing = await self.aget(config, memory_id)
    if not existing:
        raise ValueError(f"Memory {memory_id} not found")

    # verify user/agent access if specified
    if user_id and existing.user_id != user_id:
        raise PermissionError("User does not have permission to delete this memory")
    if thread_id and existing.thread_id != thread_id:
        raise PermissionError("Thread does not have permission to delete this memory")

    # Delete from Qdrant
    await self.client.delete(
        collection_name=collection,
        points_selector=models.PointIdsList(points=[memory_id]),
    )

    logger.debug(f"Deleted memory {memory_id} from collection {collection}")
aforget_memory async
aforget_memory(config, **kwargs)

Delete all memories for a user or agent.

Source code in pyagenity/store/qdrant_store.py
async def aforget_memory(
    self,
    config: dict[str, Any],
    **kwargs,
) -> None:
    """Delete all memories for a user or agent."""
    user_id, agent_id, collection = self._extract_config_values(config)

    # Build filter for memories to delete
    delete_filter = self._build_qdrant_filter(user_id=user_id, thread_id=agent_id)

    if delete_filter:
        # Delete matching memories
        await self.client.delete(
            collection_name=collection,
            points_selector=models.FilterSelector(filter=delete_filter),
        )

        logger.info(
            f"Deleted all memories for user_id={user_id}, agent_id={agent_id} "
            f"in collection {collection}"
        )
    else:
        logger.warning("No user_id or agent_id specified for memory deletion")
aget async
aget(config, memory_id, **kwargs)

Get a specific memory by ID.

Source code in pyagenity/store/qdrant_store.py
async def aget(
    self,
    config: dict[str, Any],
    memory_id: str,
    **kwargs,
) -> MemorySearchResult | None:
    """Get a specific memory by ID."""
    user_id, thread_id, collection = self._extract_config_values(config)

    try:
        # Ensure collection exists
        await self._ensure_collection_exists(collection)

        # Get point by ID
        points = await self.client.retrieve(
            collection_name=collection,
            ids=[memory_id],
        )

        if not points:
            return None

        point = points[0]
        result = self._point_to_search_result(point)

        # Verify user/agent access if specified
        if user_id and result.user_id != user_id:
            return None
        if thread_id and result.thread_id != thread_id:
            return None

        return result

    except Exception as e:
        logger.error(f"Error retrieving memory {memory_id}: {e}")
        return None
aget_all async
aget_all(config, limit=100, **kwargs)

Get all memories for a user.

Source code in pyagenity/store/qdrant_store.py
async def aget_all(
    self,
    config: dict[str, Any],
    limit: int = 100,
    **kwargs,
) -> list[MemorySearchResult]:
    """Get all memories for a user."""
    user_id, _, collection = self._extract_config_values(config)

    # Ensure collection exists
    await self._ensure_collection_exists(collection)

    # Build filter
    search_filter = self._build_qdrant_filter(
        user_id=user_id,
    )

    # Scroll through matching points; a filter-only listing needs no
    # query vector (calling search with an empty vector would fail)
    points, _ = await self.client.scroll(
        collection_name=collection,
        scroll_filter=search_filter,
        limit=limit,
    )

    # Convert to search results
    results = [self._point_to_search_result(point) for point in points]

    logger.debug(f"Found {len(results)} memories for user in collection {collection}")
    return results
arelease async
arelease()

Clean up resources.

Source code in pyagenity/store/qdrant_store.py
async def arelease(self) -> None:
    """Clean up resources."""
    if hasattr(self.client, "close"):
        await self.client.close()
    logger.info("QdrantStore resources released")
asearch async
asearch(config, query, memory_type=None, category=None, limit=10, score_threshold=None, filters=None, retrieval_strategy=RetrievalStrategy.SIMILARITY, distance_metric=DistanceMetric.COSINE, max_tokens=4000, **kwargs)

Search memories by content similarity.

Source code in pyagenity/store/qdrant_store.py
async def asearch(
    self,
    config: dict[str, Any],
    query: str,
    memory_type: MemoryType | None = None,
    category: str | None = None,
    limit: int = 10,
    score_threshold: float | None = None,
    filters: dict[str, Any] | None = None,
    retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
    distance_metric: DistanceMetric = DistanceMetric.COSINE,
    max_tokens: int = 4000,
    **kwargs,
) -> list[MemorySearchResult]:
    """Search memories by content similarity."""
    user_id, thread_id, collection = self._extract_config_values(config)

    # Ensure collection exists
    await self._ensure_collection_exists(collection)

    # Generate query embedding
    query_vector = await self.embedding.aembed(query)
    if not query_vector or len(query_vector) != self.embedding.dimension:
        raise ValueError("Embedding service returned invalid vector")

    # Build filter
    search_filter = self._build_qdrant_filter(
        user_id=user_id,
        thread_id=thread_id,
        memory_type=memory_type,
        category=category,
        filters=filters,
    )

    # Perform search
    search_result = await self.client.search(
        collection_name=collection,
        query_vector=query_vector,
        query_filter=search_filter,
        limit=limit,
        score_threshold=score_threshold,
    )

    # Convert to search results
    results = [self._point_to_search_result(point) for point in search_result]

    logger.debug(f"Found {len(results)} memories for query in collection {collection}")
    return results
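The `score_threshold` and `limit` parameters bound what comes back: Qdrant drops hits below the threshold and caps the result count. A minimal sketch of that filtering over hypothetical pre-computed scores:

```python
# Hypothetical similarity scores for three stored memories (higher is better).
scored = {"m1": 1.0, "m2": 0.0, "m3": 0.99}
score_threshold = 0.5
limit = 10

# Keep ids at or above the threshold, best-first, capped at `limit` --
# roughly the contract asearch delegates to Qdrant server-side.
hits = sorted(
    (mem_id for mem_id, score in scored.items() if score >= score_threshold),
    key=lambda mem_id: scored[mem_id],
    reverse=True,
)[:limit]
```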
asetup async
asetup()

Set up the store and ensure default collection exists.

Source code in pyagenity/store/qdrant_store.py
async def asetup(self) -> Any:
    """Set up the store and ensure default collection exists."""
    async with self._setup_lock:
        await self._ensure_collection_exists(self.default_collection)
    return True
astore async
astore(config, content, memory_type=MemoryType.EPISODIC, category='general', metadata=None, **kwargs)

Store a new memory.

Source code in pyagenity/store/qdrant_store.py
async def astore(
    self,
    config: dict[str, Any],
    content: str | Message,
    memory_type: MemoryType = MemoryType.EPISODIC,
    category: str = "general",
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> str:
    """Store a new memory."""
    user_id, thread_id, collection = self._extract_config_values(config)

    # Ensure collection exists
    await self._ensure_collection_exists(collection)

    # Create memory record
    record = self._create_memory_record(
        content=content,
        user_id=user_id,
        thread_id=thread_id,
        memory_type=memory_type,
        category=category,
        metadata=metadata,
    )

    # Generate embedding
    text_content = self._prepare_content(content)
    vector = await self.embedding.aembed(text_content)
    if not vector or len(vector) != self.embedding.dimension:
        raise ValueError("Embedding service returned invalid vector")

    # Prepare payload
    payload = {
        "content": record.content,
        "user_id": record.user_id,
        "thread_id": record.thread_id,
        "memory_type": record.memory_type.value,
        "category": record.category,
        "timestamp": record.timestamp.isoformat() if record.timestamp else None,
        **record.metadata,
    }

    # Create point
    point = PointStruct(
        id=record.id,
        vector=vector,
        payload=payload,
    )

    # Store in Qdrant
    await self.client.upsert(
        collection_name=collection,
        points=[point],
    )

    logger.debug(f"Stored memory {record.id} in collection {collection}")
    return record.id
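Note how `astore` flattens `record.metadata` into the top level of the Qdrant payload alongside the fixed fields. A small sketch with hypothetical record values (because the metadata dict is splatted last, a metadata key named like a fixed field would overwrite it, so reserved names are best avoided):

```python
from datetime import datetime, timezone

# Hypothetical record fields, mirroring the payload built in astore.
record = {
    "content": "User prefers dark mode",
    "user_id": "u1",
    "thread_id": "t1",
    "memory_type": "episodic",
    "category": "preferences",
    "timestamp": datetime(2024, 1, 1, tzinfo=timezone.utc),
    "metadata": {"source": "settings_page"},
}

payload = {
    "content": record["content"],
    "user_id": record["user_id"],
    "thread_id": record["thread_id"],
    "memory_type": record["memory_type"],
    "category": record["category"],
    "timestamp": record["timestamp"].isoformat() if record["timestamp"] else None,
    **record["metadata"],  # metadata keys are flattened to the top level
}
```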
aupdate async
aupdate(config, memory_id, content, metadata=None, **kwargs)

Update an existing memory.

Source code in pyagenity/store/qdrant_store.py
async def aupdate(
    self,
    config: dict[str, Any],
    memory_id: str,
    content: str | Message,
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> None:
    """Update an existing memory."""
    user_id, thread_id, collection = self._extract_config_values(config)

    # Get existing memory
    existing = await self.aget(config, memory_id)
    if not existing:
        raise ValueError(f"Memory {memory_id} not found")

    # Verify user/agent access if specified
    if user_id and existing.user_id != user_id:
        raise PermissionError("User does not have permission to update this memory")
    if thread_id and existing.thread_id != thread_id:
        raise PermissionError("Thread does not have permission to update this memory")

    # Prepare new content
    text_content = self._prepare_content(content)
    new_vector = await self.embedding.aembed(text_content)
    if not new_vector or len(new_vector) != self.embedding.dimension:
        raise ValueError("Embedding service returned invalid vector")

    # Update payload
    updated_metadata = {**existing.metadata}
    if metadata:
        updated_metadata.update(metadata)

    updated_payload = {
        "content": text_content,
        "user_id": existing.user_id,
        "thread_id": existing.thread_id,
        "memory_type": existing.memory_type.value,
        "category": updated_metadata.get("category", "general"),
        "timestamp": datetime.now().isoformat(),
        **updated_metadata,
    }

    # Create updated point
    point = PointStruct(
        id=memory_id,
        vector=new_vector,
        payload=updated_payload,
    )

    # Update in Qdrant
    await self.client.upsert(
        collection_name=collection,
        points=[point],
    )

    logger.debug(f"Updated memory {memory_id} in collection {collection}")
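The metadata merge in `aupdate` follows the usual dict-update rule: keys from the new `metadata` argument win over the existing record's keys, and untouched keys survive. In isolation:

```python
# Existing record metadata and an update, as aupdate merges them.
existing_metadata = {"source": "chat", "confidence": 0.4}
updates = {"confidence": 0.9, "reviewed": True}

updated_metadata = {**existing_metadata}
if updates:
    updated_metadata.update(updates)  # new keys win on conflict
```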
delete
delete(config, memory_id, **kwargs)

Synchronous wrapper for adelete that runs the async implementation.

Source code in pyagenity/store/base_store.py
def delete(self, config: dict[str, Any], memory_id: str, **kwargs) -> None:
    """Synchronous wrapper for `adelete` that runs the async implementation."""
    return run_coroutine(self.adelete(config, memory_id, **kwargs))
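All of the sync wrappers below share one pattern: delegate to the async method through `run_coroutine`. A minimal sketch of that bridge, under the assumption that no event loop is already running (the real helper in `pyagenity.utils` may handle live loops differently):

```python
import asyncio
from typing import Any, Coroutine

def run_coroutine(coro: Coroutine[Any, Any, Any]) -> Any:
    # Simplified stand-in: assumes plain sync code with no running loop.
    return asyncio.run(coro)

async def adelete_demo(memory_id: str) -> str:
    # Hypothetical stand-in for an async store method.
    return f"deleted {memory_id}"

result = run_coroutine(adelete_demo("m1"))
```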
forget_memory
forget_memory(config, **kwargs)

Delete memories for a user or agent.

Source code in pyagenity/store/base_store.py
def forget_memory(
    self,
    config: dict[str, Any],
    **kwargs,
) -> Any:
    """Delete memories for a user or agent."""
    return run_coroutine(self.aforget_memory(config, **kwargs))
get
get(config, memory_id, **kwargs)

Synchronous wrapper for aget that runs the async implementation.

Source code in pyagenity/store/base_store.py
def get(self, config: dict[str, Any], memory_id: str, **kwargs) -> MemorySearchResult | None:
    """Synchronous wrapper for `aget` that runs the async implementation."""
    return run_coroutine(self.aget(config, memory_id, **kwargs))
get_all
get_all(config, limit=100, **kwargs)

Synchronous wrapper for aget_all that runs the async implementation.

Source code in pyagenity/store/base_store.py
def get_all(
    self,
    config: dict[str, Any],
    limit: int = 100,
    **kwargs,
) -> list[MemorySearchResult]:
    """Synchronous wrapper for `aget_all` that runs the async implementation."""
    return run_coroutine(self.aget_all(config, limit=limit, **kwargs))
release
release()

Clean up any resources used by the store (override in subclasses if needed).

Source code in pyagenity/store/base_store.py
def release(self) -> None:
    """Clean up any resources used by the store (override in subclasses if needed)."""
    return run_coroutine(self.arelease())
search
search(config, query, memory_type=None, category=None, limit=10, score_threshold=None, filters=None, retrieval_strategy=RetrievalStrategy.SIMILARITY, distance_metric=DistanceMetric.COSINE, max_tokens=4000, **kwargs)

Synchronous wrapper for asearch that runs the async implementation.

Source code in pyagenity/store/base_store.py
def search(
    self,
    config: dict[str, Any],
    query: str,
    memory_type: MemoryType | None = None,
    category: str | None = None,
    limit: int = 10,
    score_threshold: float | None = None,
    filters: dict[str, Any] | None = None,
    retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMILARITY,
    distance_metric: DistanceMetric = DistanceMetric.COSINE,
    max_tokens: int = 4000,
    **kwargs,
) -> list[MemorySearchResult]:
    """Synchronous wrapper for `asearch` that runs the async implementation."""
    return run_coroutine(
        self.asearch(
            config,
            query,
            memory_type=memory_type,
            category=category,
            limit=limit,
            score_threshold=score_threshold,
            filters=filters,
            retrieval_strategy=retrieval_strategy,
            distance_metric=distance_metric,
            max_tokens=max_tokens,
            **kwargs,
        )
    )
setup
setup()

Synchronous setup method for the store.

Returns:

Name Type Description
Any Any

Implementation-defined setup result.

Source code in pyagenity/store/base_store.py
def setup(self) -> Any:
    """
    Synchronous setup method for the store.

    Returns:
        Any: Implementation-defined setup result.
    """
    return run_coroutine(self.asetup())
store
store(config, content, memory_type=MemoryType.EPISODIC, category='general', metadata=None, **kwargs)

Synchronous wrapper for astore that runs the async implementation.

Source code in pyagenity/store/base_store.py
def store(
    self,
    config: dict[str, Any],
    content: str | Message,
    memory_type: MemoryType = MemoryType.EPISODIC,
    category: str = "general",
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> str:
    """Synchronous wrapper for `astore` that runs the async implementation."""
    return run_coroutine(
        self.astore(
            config,
            content,
            memory_type=memory_type,
            category=category,
            metadata=metadata,
            **kwargs,
        )
    )
update
update(config, memory_id, content, metadata=None, **kwargs)

Synchronous wrapper for aupdate that runs the async implementation.

Source code in pyagenity/store/base_store.py
def update(
    self,
    config: dict[str, Any],
    memory_id: str,
    content: str | Message,
    metadata: dict[str, Any] | None = None,
    **kwargs,
) -> Any:
    """Synchronous wrapper for `aupdate` that runs the async implementation."""
    return run_coroutine(self.aupdate(config, memory_id, content, metadata=metadata, **kwargs))
Functions
create_cloud_qdrant_store
create_cloud_qdrant_store(url, api_key, embedding, **kwargs)

Create a cloud Qdrant store.

Source code in pyagenity/store/qdrant_store.py
def create_cloud_qdrant_store(
    url: str,
    api_key: str,
    embedding: BaseEmbedding,
    **kwargs,
) -> QdrantStore:
    """Create a cloud Qdrant store."""
    return QdrantStore(
        embedding=embedding,
        url=url,
        api_key=api_key,
        **kwargs,
    )
create_local_qdrant_store
create_local_qdrant_store(path, embedding, **kwargs)

Create a local Qdrant store.

Source code in pyagenity/store/qdrant_store.py
def create_local_qdrant_store(
    path: str,
    embedding: BaseEmbedding,
    **kwargs,
) -> QdrantStore:
    """Create a local Qdrant store."""
    return QdrantStore(
        embedding=embedding,
        path=path,
        **kwargs,
    )
create_remote_qdrant_store
create_remote_qdrant_store(host, port, embedding, **kwargs)

Create a remote Qdrant store.

Source code in pyagenity/store/qdrant_store.py
def create_remote_qdrant_store(
    host: str,
    port: int,
    embedding: BaseEmbedding,
    **kwargs,
) -> QdrantStore:
    """Create a remote Qdrant store."""
    return QdrantStore(
        embedding=embedding,
        host=host,
        port=port,
        **kwargs,
    )
store_schema

Classes:

Name Description
DistanceMetric

Supported distance metrics for vector similarity.

MemoryRecord

Comprehensive memory record for storage (Pydantic model).

MemorySearchResult

Result from a memory search operation (Pydantic model).

MemoryType

Types of memories that can be stored.

RetrievalStrategy

Memory retrieval strategies.

Classes
DistanceMetric

Bases: Enum

Supported distance metrics for vector similarity.

Attributes:

Name Type Description
COSINE
DOT_PRODUCT
EUCLIDEAN
MANHATTAN
Source code in pyagenity/store/store_schema.py
class DistanceMetric(Enum):
    """Supported distance metrics for vector similarity."""

    COSINE = "cosine"
    EUCLIDEAN = "euclidean"
    DOT_PRODUCT = "dot_product"
    MANHATTAN = "manhattan"
Attributes
COSINE class-attribute instance-attribute
COSINE = 'cosine'
DOT_PRODUCT class-attribute instance-attribute
DOT_PRODUCT = 'dot_product'
EUCLIDEAN class-attribute instance-attribute
EUCLIDEAN = 'euclidean'
MANHATTAN class-attribute instance-attribute
MANHATTAN = 'manhattan'
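Qdrant computes these metrics server-side; the plain-Python versions below are only to show how the four options differ for a pair of vectors:

```python
import math

def euclidean(a, b):
    # Straight-line (L2) distance; smaller means more similar.
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def manhattan(a, b):
    # Sum of absolute coordinate differences (L1).
    return sum(abs(x - y) for x, y in zip(a, b))

def dot_product(a, b):
    # Unnormalized similarity; larger means more similar.
    return sum(x * y for x, y in zip(a, b))

def cosine(a, b):
    # Dot product normalized by magnitudes; 1.0 means same direction.
    return dot_product(a, b) / (math.hypot(*a) * math.hypot(*b))
```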
MemoryRecord

Bases: BaseModel

Comprehensive memory record for storage (Pydantic model).

Methods:

Name Description
from_message
validate_vector

Attributes:

Name Type Description
category str
content str
id str
memory_type MemoryType
metadata dict[str, Any]
thread_id str | None
timestamp datetime | None
user_id str | None
vector list[float] | None
Source code in pyagenity/store/store_schema.py
class MemoryRecord(BaseModel):
    """Comprehensive memory record for storage (Pydantic model)."""

    id: str = Field(default_factory=lambda: str(uuid4()))
    content: str
    user_id: str | None = None
    thread_id: str | None = None
    memory_type: MemoryType = Field(default=MemoryType.EPISODIC)
    metadata: dict[str, Any] = Field(default_factory=dict)
    category: str = Field(default="general")
    vector: list[float] | None = None
    timestamp: datetime | None = Field(default_factory=datetime.now)

    @field_validator("vector")
    @classmethod
    def validate_vector(cls, v):
        if v is not None and (
            not isinstance(v, list) or any(not isinstance(x, (int | float)) for x in v)
        ):
            raise ValueError("vector must be list[float] or None")
        return v

    @classmethod
    def from_message(
        cls,
        message: Message,
        user_id: str | None = None,
        thread_id: str | None = None,
        vector: list[float] | None = None,
        additional_metadata: dict[str, Any] | None = None,
    ) -> "MemoryRecord":
        content = message.text()
        metadata = {
            "role": message.role,
            "message_id": str(message.message_id),
            "timestamp": message.timestamp.isoformat() if message.timestamp else None,
            "has_tool_calls": bool(message.tools_calls),
            "has_reasoning": bool(message.reasoning),
            "token_usage": message.usages.model_dump() if message.usages else None,
            **(additional_metadata or {}),
        }
        return cls(
            content=content,
            user_id=user_id,
            thread_id=thread_id,
            memory_type=MemoryType.EPISODIC,
            metadata=metadata,
            vector=vector,
        )
Attributes
category class-attribute instance-attribute
category = Field(default='general')
content instance-attribute
content
id class-attribute instance-attribute
id = Field(default_factory=lambda: str(uuid4()))
memory_type class-attribute instance-attribute
memory_type = Field(default=EPISODIC)
metadata class-attribute instance-attribute
metadata = Field(default_factory=dict)
thread_id class-attribute instance-attribute
thread_id = None
timestamp class-attribute instance-attribute
timestamp = Field(default_factory=now)
user_id class-attribute instance-attribute
user_id = None
vector class-attribute instance-attribute
vector = None
Functions
from_message classmethod
from_message(message, user_id=None, thread_id=None, vector=None, additional_metadata=None)
Source code in pyagenity/store/store_schema.py
@classmethod
def from_message(
    cls,
    message: Message,
    user_id: str | None = None,
    thread_id: str | None = None,
    vector: list[float] | None = None,
    additional_metadata: dict[str, Any] | None = None,
) -> "MemoryRecord":
    content = message.text()
    metadata = {
        "role": message.role,
        "message_id": str(message.message_id),
        "timestamp": message.timestamp.isoformat() if message.timestamp else None,
        "has_tool_calls": bool(message.tools_calls),
        "has_reasoning": bool(message.reasoning),
        "token_usage": message.usages.model_dump() if message.usages else None,
        **(additional_metadata or {}),
    }
    return cls(
        content=content,
        user_id=user_id,
        thread_id=thread_id,
        memory_type=MemoryType.EPISODIC,
        metadata=metadata,
        vector=vector,
    )
validate_vector classmethod
validate_vector(v)
Source code in pyagenity/store/store_schema.py
@field_validator("vector")
@classmethod
def validate_vector(cls, v):
    if v is not None and (
        not isinstance(v, list) or any(not isinstance(x, (int | float)) for x in v)
    ):
        raise ValueError("vector must be list[float] or None")
    return v
MemorySearchResult

Bases: BaseModel

Result from a memory search operation (Pydantic model).

Methods:

Name Description
validate_vector

Attributes:

Name Type Description
content str
id str
memory_type MemoryType
metadata dict[str, Any]
score float
thread_id str | None
timestamp datetime | None
user_id str | None
vector list[float] | None
Source code in pyagenity/store/store_schema.py
class MemorySearchResult(BaseModel):
    """Result from a memory search operation (Pydantic model)."""

    id: str = Field(default_factory=lambda: str(uuid4()))
    content: str = Field(default="", description="Primary textual content of the memory")
    score: float = Field(default=0.0, ge=0.0, description="Similarity / relevance score")
    memory_type: MemoryType = Field(default=MemoryType.EPISODIC)
    metadata: dict[str, Any] = Field(default_factory=dict)
    vector: list[float] | None = Field(default=None)
    user_id: str | None = None
    thread_id: str | None = None
    timestamp: datetime | None = Field(default_factory=datetime.now)

    @field_validator("vector")
    @classmethod
    def validate_vector(cls, v):
        if v is not None and (
            not isinstance(v, list) or any(not isinstance(x, (int | float)) for x in v)
        ):
            raise ValueError("vector must be list[float] or None")
        return v
Attributes
content class-attribute instance-attribute
content = Field(default='', description='Primary textual content of the memory')
id class-attribute instance-attribute
id = Field(default_factory=lambda: str(uuid4()))
memory_type class-attribute instance-attribute
memory_type = Field(default=EPISODIC)
metadata class-attribute instance-attribute
metadata = Field(default_factory=dict)
score class-attribute instance-attribute
score = Field(default=0.0, ge=0.0, description='Similarity / relevance score')
thread_id class-attribute instance-attribute
thread_id = None
timestamp class-attribute instance-attribute
timestamp = Field(default_factory=now)
user_id class-attribute instance-attribute
user_id = None
vector class-attribute instance-attribute
vector = Field(default=None)
Functions
validate_vector classmethod
validate_vector(v)
Source code in pyagenity/store/store_schema.py
@field_validator("vector")
@classmethod
def validate_vector(cls, v):
    if v is not None and (
        not isinstance(v, list) or any(not isinstance(x, (int | float)) for x in v)
    ):
        raise ValueError("vector must be list[float] or None")
    return v
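The validator accepts `None` or a list of numbers and rejects anything else. A standalone sketch of the same check and how it behaves:

```python
def validate_vector(v):
    # Same rule as the Pydantic field validator above.
    if v is not None and (
        not isinstance(v, list) or any(not isinstance(x, (int, float)) for x in v)
    ):
        raise ValueError("vector must be list[float] or None")
    return v

validate_vector(None)        # allowed: vector is optional
validate_vector([0.1, 0.2])  # allowed: list of floats (ints pass too)
try:
    validate_vector("not-a-vector")
except ValueError as exc:
    error = str(exc)
```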
MemoryType

Bases: Enum

Types of memories that can be stored.

Attributes:

Name Type Description
CUSTOM
DECLARATIVE
ENTITY
EPISODIC
PROCEDURAL
RELATIONSHIP
SEMANTIC
Source code in pyagenity/store/store_schema.py
class MemoryType(Enum):
    """Types of memories that can be stored."""

    EPISODIC = "episodic"  # Conversation memories
    SEMANTIC = "semantic"  # Facts and knowledge
    PROCEDURAL = "procedural"  # How-to knowledge
    ENTITY = "entity"  # Entity-based memories
    RELATIONSHIP = "relationship"  # Entity relationships
    CUSTOM = "custom"  # Custom memory types
    DECLARATIVE = "declarative"  # Explicit facts and events
Attributes
CUSTOM class-attribute instance-attribute
CUSTOM = 'custom'
DECLARATIVE class-attribute instance-attribute
DECLARATIVE = 'declarative'
ENTITY class-attribute instance-attribute
ENTITY = 'entity'
EPISODIC class-attribute instance-attribute
EPISODIC = 'episodic'
PROCEDURAL class-attribute instance-attribute
PROCEDURAL = 'procedural'
RELATIONSHIP class-attribute instance-attribute
RELATIONSHIP = 'relationship'
SEMANTIC class-attribute instance-attribute
SEMANTIC = 'semantic'
RetrievalStrategy

Bases: Enum

Memory retrieval strategies.

Attributes:

Name Type Description
GRAPH_TRAVERSAL
HYBRID
RELEVANCE
SIMILARITY
TEMPORAL
Source code in pyagenity/store/store_schema.py
class RetrievalStrategy(Enum):
    """Memory retrieval strategies."""

    SIMILARITY = "similarity"  # Vector similarity search
    TEMPORAL = "temporal"  # Time-based retrieval
    RELEVANCE = "relevance"  # Relevance scoring
    HYBRID = "hybrid"  # Combined approaches
    GRAPH_TRAVERSAL = "graph_traversal"  # Knowledge graph navigation
Attributes
GRAPH_TRAVERSAL class-attribute instance-attribute
GRAPH_TRAVERSAL = 'graph_traversal'
HYBRID class-attribute instance-attribute
HYBRID = 'hybrid'
RELEVANCE class-attribute instance-attribute
RELEVANCE = 'relevance'
SIMILARITY class-attribute instance-attribute
SIMILARITY = 'similarity'
TEMPORAL class-attribute instance-attribute
TEMPORAL = 'temporal'

utils

Unified utility exports for PyAgenity agent graphs.

This module re-exports core utility symbols for agent graph construction, message handling, callback management, reducers, and constants. Import from this module for a stable, unified surface of agent utilities.

Main Exports
  • Message and content blocks (Message, TextBlock, ToolCallBlock, etc.)
  • Callback management (CallbackManager, register_before_invoke, etc.)
  • Command and callable utilities (Command, call_sync_or_async)
  • Reducers (add_messages, replace_messages, append_items, replace_value)
  • Constants (START, END, ExecutionState, etc.)
  • Converter (convert_messages)

Modules:

Name Description
background_task_manager

Background task manager for async operations in PyAgenity.

callable_utils

Utilities for calling sync or async functions in PyAgenity.

callbacks

Callback system for PyAgenity.

command

Command API for AgentGraph in PyAgenity.

constants

Constants and enums for PyAgenity agent graph execution and messaging.

converter

Message conversion utilities for PyAgenity agent graphs.

id_generator

ID Generator Module

logging

Centralized logging configuration for PyAgenity.

message

Message and content block primitives for agent graphs.

metrics

Lightweight metrics instrumentation utilities.

reducers

Reducer utilities for merging and replacing lists and values in agent state.

thread_info

Thread metadata and status tracking for agent graphs.

thread_name_generator

Thread name generation utilities for AI agent conversations.

Classes:

Name Description
AfterInvokeCallback

Abstract base class for after_invoke callbacks.

AnnotationBlock

Annotation content block for messages.

BeforeInvokeCallback

Abstract base class for before_invoke callbacks.

CallbackContext

Context information passed to callbacks.

CallbackManager

Manages registration and execution of callbacks for different invocation types.

Command

Command object that combines state updates with control flow.

DataBlock

Data content block for messages.

ErrorBlock

Error content block for messages.

ExecutionState

Graph execution states for agent workflows.

InvocationType

Types of invocations that can trigger callbacks.

Message

Represents a message in a conversation, including content, role, metadata, and token usage.

OnErrorCallback

Abstract base class for on_error callbacks.

ReasoningBlock

Reasoning content block for messages.

ResponseGranularity

Response granularity options for agent graph outputs.

StorageLevel

Message storage levels for agent state persistence.

TextBlock

Text content block for messages.

ThreadInfo

Metadata and status for a thread in agent execution.

TokenUsages

Tracks token usage statistics for a message or model response.

ToolCallBlock

Tool call content block for messages.

ToolResultBlock

Tool result content block for messages.

Functions:

Name Description
add_messages

Adds messages to the list, avoiding duplicates by message_id.

append_items

Appends items to a list, avoiding duplicates by item.id.

call_sync_or_async

Call a function that may be sync or async, returning its result.

convert_messages

Convert system prompts, agent state, and extra messages to a list of dicts.

register_after_invoke

Register an after_invoke callback on the global callback manager.

register_before_invoke

Register a before_invoke callback on the global callback manager.

register_on_error

Register an on_error callback on the global callback manager.

replace_messages

Replaces the entire message list with a new one.

replace_value

Replaces a value with another.

run_coroutine

Run an async coroutine from a sync context safely.

Attributes:

Name Type Description
ContentBlock
END Literal['__end__']
START Literal['__start__']
default_callback_manager

Attributes

ContentBlock module-attribute
ContentBlock = Annotated[Union[TextBlock, ImageBlock, AudioBlock, VideoBlock, DocumentBlock, DataBlock, ToolCallBlock, ToolResultBlock, ReasoningBlock, AnnotationBlock, ErrorBlock], Field(discriminator='type')]
END module-attribute
END = '__end__'
START module-attribute
START = '__start__'
__all__ module-attribute
__all__ = ['END', 'START', 'AfterInvokeCallback', 'AnnotationBlock', 'BeforeInvokeCallback', 'CallbackContext', 'CallbackManager', 'Command', 'ContentBlock', 'DataBlock', 'ErrorBlock', 'ExecutionState', 'InvocationType', 'Message', 'OnErrorCallback', 'ReasoningBlock', 'ResponseGranularity', 'StorageLevel', 'TextBlock', 'ThreadInfo', 'TokenUsages', 'ToolCallBlock', 'ToolResultBlock', 'add_messages', 'append_items', 'call_sync_or_async', 'convert_messages', 'default_callback_manager', 'register_after_invoke', 'register_before_invoke', 'register_on_error', 'replace_messages', 'replace_value', 'run_coroutine']
default_callback_manager module-attribute
default_callback_manager = CallbackManager()

Classes

AfterInvokeCallback

Bases: ABC

Abstract base class for after_invoke callbacks.

Called after the AI model, tool, or MCP function is invoked. Allows for output validation and modification.

Methods:

Name Description
__call__

Execute the after_invoke callback.

Source code in pyagenity/utils/callbacks.py
class AfterInvokeCallback[T, R](ABC):
    """Abstract base class for after_invoke callbacks.

    Called after the AI model, tool, or MCP function is invoked.
    Allows for output validation and modification.
    """

    @abstractmethod
    async def __call__(self, context: CallbackContext, input_data: T, output_data: Any) -> Any | R:
        """Execute the after_invoke callback.

        Args:
            context: Context information about the invocation
            input_data: The original input data that was sent
            output_data: The output data returned from the invocation

        Returns:
            Modified output data (can be same type or different type)

        Raises:
            Exception: If validation fails or modification cannot be performed
        """
        ...
Functions
__call__ abstractmethod async
__call__(context, input_data, output_data)

Execute the after_invoke callback.

Parameters:

Name Type Description Default
context CallbackContext

Context information about the invocation

required
input_data T

The original input data that was sent

required
output_data Any

The output data returned from the invocation

required

Returns:

Type Description
Any | R

Modified output data (can be same type or different type)

Raises:

Type Description
Exception

If validation fails or modification cannot be performed

Source code in pyagenity/utils/callbacks.py
@abstractmethod
async def __call__(self, context: CallbackContext, input_data: T, output_data: Any) -> Any | R:
    """Execute the after_invoke callback.

    Args:
        context: Context information about the invocation
        input_data: The original input data that was sent
        output_data: The output data returned from the invocation

    Returns:
        Modified output data (can be same type or different type)

    Raises:
        Exception: If validation fails or modification cannot be performed
    """
    ...
AnnotationBlock

Bases: BaseModel

Annotation content block for messages.

Attributes:

Name Type Description
type Literal['annotation']

Block type discriminator.

kind Literal['citation', 'note']

Kind of annotation.

refs list[AnnotationRef]

List of annotation references.

spans list[tuple[int, int]] | None

Spans covered by the annotation.

Source code in pyagenity/utils/message.py
class AnnotationBlock(BaseModel):
    """
    Annotation content block for messages.

    Attributes:
        type (Literal["annotation"]): Block type discriminator.
        kind (Literal["citation", "note"]): Kind of annotation.
        refs (list[AnnotationRef]): List of annotation references.
        spans (list[tuple[int, int]] | None): Spans covered by the annotation.
    """

    type: Literal["annotation"] = "annotation"
    kind: Literal["citation", "note"] = "citation"
    refs: list[AnnotationRef] = Field(default_factory=list)
    spans: list[tuple[int, int]] | None = None
Attributes
kind class-attribute instance-attribute
kind = 'citation'
refs class-attribute instance-attribute
refs = Field(default_factory=list)
spans class-attribute instance-attribute
spans = None
type class-attribute instance-attribute
type = 'annotation'
BeforeInvokeCallback

Bases: ABC

Abstract base class for before_invoke callbacks.

Called before the AI model, tool, or MCP function is invoked. Allows for input validation and modification.

Methods:

Name Description
__call__

Execute the before_invoke callback.

Source code in pyagenity/utils/callbacks.py
class BeforeInvokeCallback[T, R](ABC):
    """Abstract base class for before_invoke callbacks.

    Called before the AI model, tool, or MCP function is invoked.
    Allows for input validation and modification.
    """

    @abstractmethod
    async def __call__(self, context: CallbackContext, input_data: T) -> T | R:
        """Execute the before_invoke callback.

        Args:
            context: Context information about the invocation
            input_data: The input data about to be sent to the invocation

        Returns:
            Modified input data (can be same type or different type)

        Raises:
            Exception: If validation fails or modification cannot be performed
        """
        ...
Functions
__call__ abstractmethod async
__call__(context, input_data)

Execute the before_invoke callback.

Parameters:

Name Type Description Default
context CallbackContext

Context information about the invocation

required
input_data T

The input data about to be sent to the invocation

required

Returns:

Type Description
T | R

Modified input data (can be same type or different type)

Raises:

Type Description
Exception

If validation fails or modification cannot be performed

Source code in pyagenity/utils/callbacks.py
@abstractmethod
async def __call__(self, context: CallbackContext, input_data: T) -> T | R:
    """Execute the before_invoke callback.

    Args:
        context: Context information about the invocation
        input_data: The input data about to be sent to the invocation

    Returns:
        Modified input data (can be same type or different type)

    Raises:
        Exception: If validation fails or modification cannot be performed
    """
    ...
CallbackContext dataclass

Context information passed to callbacks.

Methods:

Name Description
__init__

Attributes:

Name Type Description
function_name str | None
invocation_type InvocationType
metadata dict[str, Any] | None
node_name str
Source code in pyagenity/utils/callbacks.py
@dataclass
class CallbackContext:
    """Context information passed to callbacks."""

    invocation_type: InvocationType
    node_name: str
    function_name: str | None = None
    metadata: dict[str, Any] | None = None
Attributes
function_name class-attribute instance-attribute
function_name = None
invocation_type instance-attribute
invocation_type
metadata class-attribute instance-attribute
metadata = None
node_name instance-attribute
node_name
Functions
__init__
__init__(invocation_type, node_name, function_name=None, metadata=None)
CallbackManager

Manages registration and execution of callbacks for different invocation types.

Handles before_invoke, after_invoke, and on_error callbacks for AI, TOOL, and MCP invocations.

Methods:

Name Description
__init__

Initialize the CallbackManager with empty callback registries.

clear_callbacks

Clear callbacks for a specific invocation type or all types.

execute_after_invoke

Execute all after_invoke callbacks for the given context.

execute_before_invoke

Execute all before_invoke callbacks for the given context.

execute_on_error

Execute all on_error callbacks for the given context.

get_callback_counts

Get count of registered callbacks by type for debugging.

register_after_invoke

Register an after_invoke callback for a specific invocation type.

register_before_invoke

Register a before_invoke callback for a specific invocation type.

register_on_error

Register an on_error callback for a specific invocation type.

Source code in pyagenity/utils/callbacks.py
class CallbackManager:
    """
    Manages registration and execution of callbacks for different invocation types.

    Handles before_invoke, after_invoke, and on_error callbacks for AI, TOOL, and MCP invocations.
    """

    def __init__(self):
        """
        Initialize the CallbackManager with empty callback registries.
        """
        self._before_callbacks: dict[InvocationType, list[BeforeInvokeCallbackType]] = {
            InvocationType.AI: [],
            InvocationType.TOOL: [],
            InvocationType.MCP: [],
        }
        self._after_callbacks: dict[InvocationType, list[AfterInvokeCallbackType]] = {
            InvocationType.AI: [],
            InvocationType.TOOL: [],
            InvocationType.MCP: [],
        }
        self._error_callbacks: dict[InvocationType, list[OnErrorCallbackType]] = {
            InvocationType.AI: [],
            InvocationType.TOOL: [],
            InvocationType.MCP: [],
        }

    def register_before_invoke(
        self, invocation_type: InvocationType, callback: BeforeInvokeCallbackType
    ) -> None:
        """
        Register a before_invoke callback for a specific invocation type.

        Args:
            invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
            callback (BeforeInvokeCallbackType): The callback to register.
        """
        self._before_callbacks[invocation_type].append(callback)

    def register_after_invoke(
        self, invocation_type: InvocationType, callback: AfterInvokeCallbackType
    ) -> None:
        """
        Register an after_invoke callback for a specific invocation type.

        Args:
            invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
            callback (AfterInvokeCallbackType): The callback to register.
        """
        self._after_callbacks[invocation_type].append(callback)

    def register_on_error(
        self, invocation_type: InvocationType, callback: OnErrorCallbackType
    ) -> None:
        """
        Register an on_error callback for a specific invocation type.

        Args:
            invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
            callback (OnErrorCallbackType): The callback to register.
        """
        self._error_callbacks[invocation_type].append(callback)

    async def execute_before_invoke(self, context: CallbackContext, input_data: Any) -> Any:
        """
        Execute all before_invoke callbacks for the given context.

        Args:
            context (CallbackContext): Context information about the invocation.
            input_data (Any): The input data to be validated or modified.

        Returns:
            Any: The modified input data after all callbacks.

        Raises:
            Exception: If any callback fails.
        """
        current_data = input_data

        for callback in self._before_callbacks[context.invocation_type]:
            try:
                if isinstance(callback, BeforeInvokeCallback):
                    current_data = await callback(context, current_data)
                elif callable(callback):
                    result = callback(context, current_data)
                    if hasattr(result, "__await__"):
                        current_data = await result
                    else:
                        current_data = result
            except Exception as e:
                await self.execute_on_error(context, input_data, e)
                raise

        return current_data

    async def execute_after_invoke(
        self, context: CallbackContext, input_data: Any, output_data: Any
    ) -> Any:
        """
        Execute all after_invoke callbacks for the given context.

        Args:
            context (CallbackContext): Context information about the invocation.
            input_data (Any): The original input data sent to the invocation.
            output_data (Any): The output data returned from the invocation.

        Returns:
            Any: The modified output data after all callbacks.

        Raises:
            Exception: If any callback fails.
        """
        current_output = output_data

        for callback in self._after_callbacks[context.invocation_type]:
            try:
                if isinstance(callback, AfterInvokeCallback):
                    current_output = await callback(context, input_data, current_output)
                elif callable(callback):
                    result = callback(context, input_data, current_output)
                    if hasattr(result, "__await__"):
                        current_output = await result
                    else:
                        current_output = result
            except Exception as e:
                await self.execute_on_error(context, input_data, e)
                raise

        return current_output

    async def execute_on_error(
        self, context: CallbackContext, input_data: Any, error: Exception
    ) -> Message | None:
        """
        Execute all on_error callbacks for the given context.

        Args:
            context (CallbackContext): Context information about the invocation.
            input_data (Any): The input data that caused the error.
            error (Exception): The exception that occurred.

        Returns:
            Message | None: Recovery value from callbacks, or None if not handled.
        """
        recovery_value = None

        for callback in self._error_callbacks[context.invocation_type]:
            try:
                result = None
                if isinstance(callback, OnErrorCallback):
                    result = await callback(context, input_data, error)
                elif callable(callback):
                    result = callback(context, input_data, error)
                    if hasattr(result, "__await__"):
                        result = await result  # type: ignore

                if isinstance(result, Message) or result is None:
                    recovery_value = result
            except Exception as exc:
                logger.exception("Error callback failed: %s", exc)
                continue

        return recovery_value

    def clear_callbacks(self, invocation_type: InvocationType | None = None) -> None:
        """
        Clear callbacks for a specific invocation type or all types.

        Args:
            invocation_type (InvocationType | None): The invocation type to clear, or None for all.
        """
        if invocation_type:
            self._before_callbacks[invocation_type].clear()
            self._after_callbacks[invocation_type].clear()
            self._error_callbacks[invocation_type].clear()
        else:
            for inv_type in InvocationType:
                self._before_callbacks[inv_type].clear()
                self._after_callbacks[inv_type].clear()
                self._error_callbacks[inv_type].clear()

    def get_callback_counts(self) -> dict[str, dict[str, int]]:
        """
        Get count of registered callbacks by type for debugging.

        Returns:
            dict[str, dict[str, int]]: Counts of callbacks for each invocation type.
        """
        return {
            inv_type.value: {
                "before_invoke": len(self._before_callbacks[inv_type]),
                "after_invoke": len(self._after_callbacks[inv_type]),
                "on_error": len(self._error_callbacks[inv_type]),
            }
            for inv_type in InvocationType
        }
Functions
__init__
__init__()

Initialize the CallbackManager with empty callback registries.

Source code in pyagenity/utils/callbacks.py
def __init__(self):
    """
    Initialize the CallbackManager with empty callback registries.
    """
    self._before_callbacks: dict[InvocationType, list[BeforeInvokeCallbackType]] = {
        InvocationType.AI: [],
        InvocationType.TOOL: [],
        InvocationType.MCP: [],
    }
    self._after_callbacks: dict[InvocationType, list[AfterInvokeCallbackType]] = {
        InvocationType.AI: [],
        InvocationType.TOOL: [],
        InvocationType.MCP: [],
    }
    self._error_callbacks: dict[InvocationType, list[OnErrorCallbackType]] = {
        InvocationType.AI: [],
        InvocationType.TOOL: [],
        InvocationType.MCP: [],
    }
clear_callbacks
clear_callbacks(invocation_type=None)

Clear callbacks for a specific invocation type or all types.

Parameters:

Name Type Description Default
invocation_type InvocationType | None

The invocation type to clear, or None for all.

None
Source code in pyagenity/utils/callbacks.py
def clear_callbacks(self, invocation_type: InvocationType | None = None) -> None:
    """
    Clear callbacks for a specific invocation type or all types.

    Args:
        invocation_type (InvocationType | None): The invocation type to clear, or None for all.
    """
    if invocation_type:
        self._before_callbacks[invocation_type].clear()
        self._after_callbacks[invocation_type].clear()
        self._error_callbacks[invocation_type].clear()
    else:
        for inv_type in InvocationType:
            self._before_callbacks[inv_type].clear()
            self._after_callbacks[inv_type].clear()
            self._error_callbacks[inv_type].clear()
execute_after_invoke async
execute_after_invoke(context, input_data, output_data)

Execute all after_invoke callbacks for the given context.

Parameters:

Name Type Description Default
context CallbackContext

Context information about the invocation.

required
input_data Any

The original input data sent to the invocation.

required
output_data Any

The output data returned from the invocation.

required

Returns:

Name Type Description
Any Any

The modified output data after all callbacks.

Raises:

Type Description
Exception

If any callback fails.

Source code in pyagenity/utils/callbacks.py
async def execute_after_invoke(
    self, context: CallbackContext, input_data: Any, output_data: Any
) -> Any:
    """
    Execute all after_invoke callbacks for the given context.

    Args:
        context (CallbackContext): Context information about the invocation.
        input_data (Any): The original input data sent to the invocation.
        output_data (Any): The output data returned from the invocation.

    Returns:
        Any: The modified output data after all callbacks.

    Raises:
        Exception: If any callback fails.
    """
    current_output = output_data

    for callback in self._after_callbacks[context.invocation_type]:
        try:
            if isinstance(callback, AfterInvokeCallback):
                current_output = await callback(context, input_data, current_output)
            elif callable(callback):
                result = callback(context, input_data, current_output)
                if hasattr(result, "__await__"):
                    current_output = await result
                else:
                    current_output = result
        except Exception as e:
            await self.execute_on_error(context, input_data, e)
            raise

    return current_output
execute_before_invoke async
execute_before_invoke(context, input_data)

Execute all before_invoke callbacks for the given context.

Parameters:

Name Type Description Default
context CallbackContext

Context information about the invocation.

required
input_data Any

The input data to be validated or modified.

required

Returns:

Name Type Description
Any Any

The modified input data after all callbacks.

Raises:

Type Description
Exception

If any callback fails.

Source code in pyagenity/utils/callbacks.py
async def execute_before_invoke(self, context: CallbackContext, input_data: Any) -> Any:
    """
    Execute all before_invoke callbacks for the given context.

    Args:
        context (CallbackContext): Context information about the invocation.
        input_data (Any): The input data to be validated or modified.

    Returns:
        Any: The modified input data after all callbacks.

    Raises:
        Exception: If any callback fails.
    """
    current_data = input_data

    for callback in self._before_callbacks[context.invocation_type]:
        try:
            if isinstance(callback, BeforeInvokeCallback):
                current_data = await callback(context, current_data)
            elif callable(callback):
                result = callback(context, current_data)
                if hasattr(result, "__await__"):
                    current_data = await result
                else:
                    current_data = result
        except Exception as e:
            await self.execute_on_error(context, input_data, e)
            raise

    return current_data
execute_on_error async
execute_on_error(context, input_data, error)

Execute all on_error callbacks for the given context.

Parameters:

Name Type Description Default
context CallbackContext

Context information about the invocation.

required
input_data Any

The input data that caused the error.

required
error Exception

The exception that occurred.

required

Returns:

Type Description
Message | None

Message | None: Recovery value from callbacks, or None if not handled.

Source code in pyagenity/utils/callbacks.py
async def execute_on_error(
    self, context: CallbackContext, input_data: Any, error: Exception
) -> Message | None:
    """
    Execute all on_error callbacks for the given context.

    Args:
        context (CallbackContext): Context information about the invocation.
        input_data (Any): The input data that caused the error.
        error (Exception): The exception that occurred.

    Returns:
        Message | None: Recovery value from callbacks, or None if not handled.
    """
    recovery_value = None

    for callback in self._error_callbacks[context.invocation_type]:
        try:
            result = None
            if isinstance(callback, OnErrorCallback):
                result = await callback(context, input_data, error)
            elif callable(callback):
                result = callback(context, input_data, error)
                if hasattr(result, "__await__"):
                    result = await result  # type: ignore

            if isinstance(result, Message) or result is None:
                recovery_value = result
        except Exception as exc:
            logger.exception("Error callback failed: %s", exc)
            continue

    return recovery_value
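The `hasattr(result, "__await__")` check seen throughout the execute_* methods is what lets the manager accept both plain callables and coroutine functions. Isolated as a standalone sketch (the helper name `maybe_await` is an assumption, not framework API), the dispatch looks like:

```python
import asyncio


async def maybe_await(result):
    # If the callback returned an awaitable (coroutine, Task, ...), await it;
    # otherwise use the value as-is. This mirrors the manager's dual dispatch.
    if hasattr(result, "__await__"):
        return await result
    return result


def sync_cb(x):
    return x + 1


async def async_cb(x):
    return x * 2


async def demo():
    a = await maybe_await(sync_cb(1))
    b = await maybe_await(async_cb(1))
    return a, b


pair = asyncio.run(demo())
print(pair)  # (2, 2)
```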
get_callback_counts
get_callback_counts()

Get count of registered callbacks by type for debugging.

Returns:

Type Description
dict[str, dict[str, int]]

dict[str, dict[str, int]]: Counts of callbacks for each invocation type.

Source code in pyagenity/utils/callbacks.py
def get_callback_counts(self) -> dict[str, dict[str, int]]:
    """
    Get count of registered callbacks by type for debugging.

    Returns:
        dict[str, dict[str, int]]: Counts of callbacks for each invocation type.
    """
    return {
        inv_type.value: {
            "before_invoke": len(self._before_callbacks[inv_type]),
            "after_invoke": len(self._after_callbacks[inv_type]),
            "on_error": len(self._error_callbacks[inv_type]),
        }
        for inv_type in InvocationType
    }
register_after_invoke
register_after_invoke(invocation_type, callback)

Register an after_invoke callback for a specific invocation type.

Parameters:

Name Type Description Default
invocation_type InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback AfterInvokeCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_after_invoke(
    self, invocation_type: InvocationType, callback: AfterInvokeCallbackType
) -> None:
    """
    Register an after_invoke callback for a specific invocation type.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (AfterInvokeCallbackType): The callback to register.
    """
    self._after_callbacks[invocation_type].append(callback)
register_before_invoke
register_before_invoke(invocation_type, callback)

Register a before_invoke callback for a specific invocation type.

Parameters:

Name Type Description Default
invocation_type InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback BeforeInvokeCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_before_invoke(
    self, invocation_type: InvocationType, callback: BeforeInvokeCallbackType
) -> None:
    """
    Register a before_invoke callback for a specific invocation type.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (BeforeInvokeCallbackType): The callback to register.
    """
    self._before_callbacks[invocation_type].append(callback)
register_on_error
register_on_error(invocation_type, callback)

Register an on_error callback for a specific invocation type.

Parameters:

Name Type Description Default
invocation_type InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback OnErrorCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_on_error(
    self, invocation_type: InvocationType, callback: OnErrorCallbackType
) -> None:
    """
    Register an on_error callback for a specific invocation type.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (OnErrorCallbackType): The callback to register.
    """
    self._error_callbacks[invocation_type].append(callback)
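Registered before_invoke callbacks form a pipeline: each receives the output of the previous one, so they compose in registration order. A minimal stand-in for that loop (the `run_chain` helper and the example callbacks are illustrative, not framework API):

```python
import asyncio


async def run_chain(callbacks, input_data):
    # Mirrors the loop in execute_before_invoke: thread each callback's
    # output into the next, awaiting awaitables along the way.
    current = input_data
    for cb in callbacks:
        result = cb(current)
        if hasattr(result, "__await__"):
            result = await result
        current = result
    return current


def add_user(d):
    return {**d, "user": "alice"}


async def add_trace(d):
    return {**d, "trace_id": "t-1"}


final = asyncio.run(run_chain([add_user, add_trace], {"q": "hi"}))
print(final)  # {'q': 'hi', 'user': 'alice', 'trace_id': 't-1'}
```

Registration order therefore matters: a callback that validates redacted input should be registered after the callback that redacts.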
Command

Command object that combines state updates with control flow.

Allows nodes to update agent state and direct graph execution to specific nodes or graphs. Similar to LangGraph's Command API.

Methods:

Name Description
__init__

Initialize a Command object.

__repr__

Return a string representation of the Command object.

Attributes:

Name Type Description
PARENT
goto
graph
state
update
Source code in pyagenity/utils/command.py
class Command[StateT: AgentState]:
    """
    Command object that combines state updates with control flow.

    Allows nodes to update agent state and direct graph execution to specific nodes or graphs.
    Similar to LangGraph's Command API.
    """

    PARENT = "PARENT"

    def __init__(
        self,
        update: Union["StateT", None, Message, str, "BaseConverter"] = None,
        goto: str | None = None,
        graph: str | None = None,
        state: StateT | None = None,
    ):
        """
        Initialize a Command object.

        Args:
            update (StateT | None | Message | str | BaseConverter): State update to apply.
            goto (str | None): Next node to execute (node name or END).
            graph (str | None): Which graph to navigate to (None for current, PARENT for parent).
            state (StateT | None): Optional agent state to attach.
        """
        self.update = update
        self.goto = goto
        self.graph = graph
        self.state = state

    def __repr__(self) -> str:
        """
        Return a string representation of the Command object.

        Returns:
            str: String representation of the Command.
        """
        return (
            f"Command(update={self.update}, goto={self.goto}, \n"
            f" graph={self.graph}, state={self.state})"
        )
Attributes
PARENT class-attribute instance-attribute
PARENT = 'PARENT'
goto instance-attribute
goto = goto
graph instance-attribute
graph = graph
state instance-attribute
state = state
update instance-attribute
update = update
Functions
__init__
__init__(update=None, goto=None, graph=None, state=None)

Initialize a Command object.

Parameters:

Name Type Description Default
update StateT | None | Message | str | BaseConverter

State update to apply.

None
goto str | None

Next node to execute (node name or END).

None
graph str | None

Which graph to navigate to (None for current, PARENT for parent).

None
state StateT | None

Optional agent state to attach.

None
Source code in pyagenity/utils/command.py
def __init__(
    self,
    update: Union["StateT", None, Message, str, "BaseConverter"] = None,
    goto: str | None = None,
    graph: str | None = None,
    state: StateT | None = None,
):
    """
    Initialize a Command object.

    Args:
        update (StateT | None | Message | str | BaseConverter): State update to apply.
        goto (str | None): Next node to execute (node name or END).
        graph (str | None): Which graph to navigate to (None for current, PARENT for parent).
        state (StateT | None): Optional agent state to attach.
    """
    self.update = update
    self.goto = goto
    self.graph = graph
    self.state = state
__repr__
__repr__()

Return a string representation of the Command object.

Returns:

Name Type Description
str str

String representation of the Command.

Source code in pyagenity/utils/command.py
def __repr__(self) -> str:
    """
    Return a string representation of the Command object.

    Returns:
        str: String representation of the Command.
    """
    return (
        f"Command(update={self.update}, goto={self.goto}, \n"
        f" graph={self.graph}, state={self.state})"
    )
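A node returns a Command to update state and redirect control flow in one object. Sketched with a simplified, non-generic copy of the class shown above (the node and the "review" node name are hypothetical):

```python
class Command:
    # Simplified, non-generic copy of the Command class listed above.
    PARENT = "PARENT"

    def __init__(self, update=None, goto=None, graph=None, state=None):
        self.update = update
        self.goto = goto
        self.graph = graph
        self.state = state


def review_router(state):
    # Hypothetical node: record a draft and jump to a "review" node
    # in the parent graph rather than continuing the current one.
    return Command(update="draft ready", goto="review", graph=Command.PARENT)


cmd = review_router(None)
print(cmd.goto, cmd.graph)  # review PARENT
```

Leaving `graph=None` keeps navigation within the current graph; `Command.PARENT` hands control back to the enclosing graph.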
DataBlock

Bases: BaseModel

Data content block for messages.

Attributes:

Name Type Description
type Literal['data']

Block type discriminator.

mime_type str

MIME type of the data.

data_base64 str | None

Base64-encoded data.

media MediaRef | None

Reference to associated media.

Source code in pyagenity/utils/message.py
class DataBlock(BaseModel):
    """
    Data content block for messages.

    Attributes:
        type (Literal["data"]): Block type discriminator.
        mime_type (str): MIME type of the data.
        data_base64 (str | None): Base64-encoded data.
        media (MediaRef | None): Reference to associated media.
    """

    type: Literal["data"] = "data"
    mime_type: str
    data_base64: str | None = None
    media: MediaRef | None = None
Attributes
data_base64 class-attribute instance-attribute
data_base64 = None
media class-attribute instance-attribute
media = None
mime_type instance-attribute
mime_type
type class-attribute instance-attribute
type = 'data'
ErrorBlock

Bases: BaseModel

Error content block for messages.

Attributes:

Name Type Description
type Literal['error']

Block type discriminator.

message str

Error message.

code str | None

Error code.

data dict[str, Any] | None

Additional error data.

Source code in pyagenity/utils/message.py
class ErrorBlock(BaseModel):
    """
    Error content block for messages.

    Attributes:
        type (Literal["error"]): Block type discriminator.
        message (str): Error message.
        code (str | None): Error code.
        data (dict[str, Any] | None): Additional error data.
    """

    type: Literal["error"] = "error"
    message: str
    code: str | None = None
    data: dict[str, Any] | None = None
Attributes
code class-attribute instance-attribute
code = None
data class-attribute instance-attribute
data = None
message instance-attribute
message
type class-attribute instance-attribute
type = 'error'
ExecutionState

Bases: StrEnum

Graph execution states for agent workflows.

Values

RUNNING: Execution is in progress.
PAUSED: Execution is paused.
COMPLETED: Execution completed successfully.
ERROR: Execution encountered an error.
INTERRUPTED: Execution was interrupted.
ABORTED: Execution was aborted.
IDLE: Execution is idle.

Attributes:

Name Type Description
ABORTED
COMPLETED
ERROR
IDLE
INTERRUPTED
PAUSED
RUNNING
Source code in pyagenity/utils/constants.py
class ExecutionState(StrEnum):
    """
    Graph execution states for agent workflows.

    Values:
        RUNNING: Execution is in progress.
        PAUSED: Execution is paused.
        COMPLETED: Execution completed successfully.
        ERROR: Execution encountered an error.
        INTERRUPTED: Execution was interrupted.
        ABORTED: Execution was aborted.
        IDLE: Execution is idle.
    """

    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    ERROR = "error"
    INTERRUPTED = "interrupted"
    ABORTED = "aborted"
    IDLE = "idle"
Attributes
ABORTED class-attribute instance-attribute
ABORTED = 'aborted'
COMPLETED class-attribute instance-attribute
COMPLETED = 'completed'
ERROR class-attribute instance-attribute
ERROR = 'error'
IDLE class-attribute instance-attribute
IDLE = 'idle'
INTERRUPTED class-attribute instance-attribute
INTERRUPTED = 'interrupted'
PAUSED class-attribute instance-attribute
PAUSED = 'paused'
RUNNING class-attribute instance-attribute
RUNNING = 'running'
InvocationType

Bases: Enum

Types of invocations that can trigger callbacks.

Attributes:

Name Type Description
AI
MCP
TOOL
Source code in pyagenity/utils/callbacks.py
class InvocationType(Enum):
    """Types of invocations that can trigger callbacks."""

    AI = "ai"
    TOOL = "tool"
    MCP = "mcp"
Attributes
AI class-attribute instance-attribute
AI = 'ai'
MCP class-attribute instance-attribute
MCP = 'mcp'
TOOL class-attribute instance-attribute
TOOL = 'tool'
Message

Bases: BaseModel

Represents a message in a conversation, including content, role, metadata, and token usage.

Attributes:

Name Type Description
message_id str | int

Unique identifier for the message.

role Literal['user', 'assistant', 'system', 'tool']

The role of the message sender.

content list[ContentBlock]

The message content blocks.

delta bool

Indicates if this is a delta/partial message.

tools_calls list[dict[str, Any]] | None

Tool call information, if any.

reasoning str | None

Reasoning or explanation, if any.

timestamp datetime | None

Timestamp of the message.

metadata dict[str, Any]

Additional metadata.

usages TokenUsages | None

Token usage statistics.

raw dict[str, Any] | None

Raw data, if any.

Example

msg = Message(message_id="abc123", role="user", content=[TextBlock(text="Hello!")])

Methods:

Name Description
attach_media

Append a media block to the content.

text

Best-effort text extraction from content blocks.

text_message

Create a Message instance from plain text.

tool_message

Create a tool message, optionally marking it as an error.

Source code in pyagenity/utils/message.py
class Message(BaseModel):
    """
    Represents a message in a conversation, including content, role, metadata, and token usage.

    Attributes:
        message_id (str | int): Unique identifier for the message.
        role (Literal["user", "assistant", "system", "tool"]): The role of the message sender.
        content (list[ContentBlock]): The message content blocks.
        delta (bool): Indicates if this is a delta/partial message.
        tools_calls (list[dict[str, Any]] | None): Tool call information, if any.
        reasoning (str | None): Reasoning or explanation, if any.
        timestamp (datetime | None): Timestamp of the message.
        metadata (dict[str, Any]): Additional metadata.
        usages (TokenUsages | None): Token usage statistics.
        raw (dict[str, Any] | None): Raw data, if any.

    Example:
        >>> msg = Message(message_id="abc123", role="user", content=[TextBlock(text="Hello!")])
    """

    message_id: str | int = Field(default_factory=lambda: generate_id(None))
    role: Literal["user", "assistant", "system", "tool"]
    content: list[ContentBlock]
    delta: bool = False  # Indicates if this is a delta/partial message
    tools_calls: list[dict[str, Any]] | None = None
    reasoning: str | None = None  # Remove it
    timestamp: datetime | None = Field(default_factory=datetime.now)
    metadata: dict[str, Any] = Field(default_factory=dict)
    usages: TokenUsages | None = None
    raw: dict[str, Any] | None = None

    @classmethod
    def text_message(
        cls,
        content: str,
        role: Literal["user", "assistant", "system", "tool"] = "user",
        message_id: str | None = None,
    ) -> "Message":
        """
        Create a Message instance from plain text.

        Args:
            content (str): The message content.
            role (Literal["user", "assistant", "system", "tool"]): The role of the sender.
            message_id (str | None): Optional message ID.

        Returns:
            Message: The created Message instance.

        Example:
            >>> Message.text_message("Hello!", role="user")
        """
        logger.debug("Creating message from text with role: %s", role)
        return cls(
            message_id=generate_id(message_id),
            role=role,
            content=[TextBlock(text=content)],
            timestamp=datetime.now(),
            metadata={},
        )

    @classmethod
    def tool_message(
        cls,
        content: list[ContentBlock],
        message_id: str | None = None,
        meta: dict[str, Any] | None = None,
    ) -> "Message":
        """
        Create a tool message, optionally marking it as an error.

        Args:
            content (list[ContentBlock]): The message content blocks.
            message_id (str | None): Optional message ID.
            meta (dict[str, Any] | None): Optional metadata.

        Returns:
            Message: The created tool message instance.

        Example:
            >>> Message.tool_message([ToolResultBlock(...)], message_id="tool1")
        """
        res = content
        msg_id = generate_id(message_id)
        return cls(
            message_id=msg_id,
            role="tool",
            content=res,
            timestamp=datetime.now(),
            metadata=meta or {},
        )

    # --- Convenience helpers ---
    def text(self) -> str:
        """
        Best-effort text extraction from content blocks.

        Returns:
            str: Concatenated text from TextBlock and ToolResultBlock outputs.

        Example:
            >>> msg.text()
            'Hello!Result text.'
        """
        parts: list[str] = []
        for block in self.content:
            if isinstance(block, TextBlock):
                parts.append(block.text)
            elif isinstance(block, ToolResultBlock) and isinstance(block.output, str):
                parts.append(block.output)
        return "".join(parts)

    def attach_media(
        self,
        media: MediaRef,
        as_type: Literal["image", "audio", "video", "document"],
    ) -> None:
        """
        Append a media block to the content.

        If content was text, creates a block list. Supports image, audio, video, and document types.

        Args:
            media (MediaRef): Reference to media content.
            as_type (Literal["image", "audio", "video", "document"]): Type of media block to append.

        Returns:
            None

        Raises:
            ValueError: If an unsupported media type is provided.

        Example:
            >>> msg.attach_media(media_ref, as_type="image")
        """
        block: ContentBlock
        if as_type == "image":
            block = ImageBlock(media=media)
        elif as_type == "audio":
            block = AudioBlock(media=media)
        elif as_type == "video":
            block = VideoBlock(media=media)
        elif as_type == "document":
            block = DocumentBlock(media=media)
        else:
            raise ValueError(f"Unsupported media type: {as_type}")

        if isinstance(self.content, str):
            self.content = [TextBlock(text=self.content), block]
        elif isinstance(self.content, list):
            self.content.append(block)
        else:
            self.content = [block]
Attributes
content instance-attribute
content
delta class-attribute instance-attribute
delta = False
message_id class-attribute instance-attribute
message_id = Field(default_factory=lambda: generate_id(None))
metadata class-attribute instance-attribute
metadata = Field(default_factory=dict)
raw class-attribute instance-attribute
raw = None
reasoning class-attribute instance-attribute
reasoning = None
role instance-attribute
role
timestamp class-attribute instance-attribute
timestamp = Field(default_factory=now)
tools_calls class-attribute instance-attribute
tools_calls = None
usages class-attribute instance-attribute
usages = None
Functions
attach_media
attach_media(media, as_type)

Append a media block to the content.

If content was text, creates a block list. Supports image, audio, video, and document types.

Parameters:

Name Type Description Default
media MediaRef

Reference to media content.

required
as_type Literal['image', 'audio', 'video', 'document']

Type of media block to append.

required

Returns:

Type Description
None

None

Raises:

Type Description
ValueError

If an unsupported media type is provided.

Example

msg.attach_media(media_ref, as_type="image")

Source code in pyagenity/utils/message.py
def attach_media(
    self,
    media: MediaRef,
    as_type: Literal["image", "audio", "video", "document"],
) -> None:
    """
    Append a media block to the content.

    If content was text, creates a block list. Supports image, audio, video, and document types.

    Args:
        media (MediaRef): Reference to media content.
        as_type (Literal["image", "audio", "video", "document"]): Type of media block to append.

    Returns:
        None

    Raises:
        ValueError: If an unsupported media type is provided.

    Example:
        >>> msg.attach_media(media_ref, as_type="image")
    """
    block: ContentBlock
    if as_type == "image":
        block = ImageBlock(media=media)
    elif as_type == "audio":
        block = AudioBlock(media=media)
    elif as_type == "video":
        block = VideoBlock(media=media)
    elif as_type == "document":
        block = DocumentBlock(media=media)
    else:
        raise ValueError(f"Unsupported media type: {as_type}")

    if isinstance(self.content, str):
        self.content = [TextBlock(text=self.content), block]
    elif isinstance(self.content, list):
        self.content.append(block)
    else:
        self.content = [block]
text
text()

Best-effort text extraction from content blocks.

Returns:

Name Type Description
str str

Concatenated text from TextBlock and ToolResultBlock outputs.

Example

msg.text()
'Hello!Result text.'

Source code in pyagenity/utils/message.py
def text(self) -> str:
    """
    Best-effort text extraction from content blocks.

    Returns:
        str: Concatenated text from TextBlock and ToolResultBlock outputs.

    Example:
        >>> msg.text()
        'Hello!Result text.'
    """
    parts: list[str] = []
    for block in self.content:
        if isinstance(block, TextBlock):
            parts.append(block.text)
        elif isinstance(block, ToolResultBlock) and isinstance(block.output, str):
            parts.append(block.output)
    return "".join(parts)
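To make the extraction behavior concrete, here is a minimal, self-contained sketch of the `text()` logic. `TextBlock` and `ToolResultBlock` below are simplified stand-ins for the pyagenity models (only the fields the method reads), and `extract_text` mirrors the loop shown in the source above.

```python
from dataclasses import dataclass
from typing import Any

# Simplified stand-ins for pyagenity's content blocks; only the fields
# that text() inspects are modeled here.
@dataclass
class TextBlock:
    text: str

@dataclass
class ToolResultBlock:
    call_id: str
    output: Any = None

def extract_text(blocks) -> str:
    """Mirror of Message.text(): concatenate TextBlock text and string tool outputs."""
    parts: list[str] = []
    for block in blocks:
        if isinstance(block, TextBlock):
            parts.append(block.text)
        elif isinstance(block, ToolResultBlock) and isinstance(block.output, str):
            parts.append(block.output)
    return "".join(parts)

blocks = [
    TextBlock("Hello!"),
    ToolResultBlock("call_1", "Result text."),
    ToolResultBlock("call_2", {"k": 1}),  # non-string output is skipped
]
print(extract_text(blocks))  # → Hello!Result text.
```

Note that non-string tool outputs (dicts, media references) are silently skipped, which is why the docstring calls this "best-effort" extraction.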
text_message classmethod
text_message(content, role='user', message_id=None)

Create a Message instance from plain text.

Parameters:

Name Type Description Default
content str

The message content.

required
role Literal['user', 'assistant', 'system', 'tool']

The role of the sender.

'user'
message_id str | None

Optional message ID.

None

Returns:

Name Type Description
Message Message

The created Message instance.

Example

Message.text_message("Hello!", role="user")

Source code in pyagenity/utils/message.py
@classmethod
def text_message(
    cls,
    content: str,
    role: Literal["user", "assistant", "system", "tool"] = "user",
    message_id: str | None = None,
) -> "Message":
    """
    Create a Message instance from plain text.

    Args:
        content (str): The message content.
        role (Literal["user", "assistant", "system", "tool"]): The role of the sender.
        message_id (str | None): Optional message ID.

    Returns:
        Message: The created Message instance.

    Example:
        >>> Message.text_message("Hello!", role="user")
    """
    logger.debug("Creating message from text with role: %s", role)
    return cls(
        message_id=generate_id(message_id),
        role=role,
        content=[TextBlock(text=content)],
        timestamp=datetime.now(),
        metadata={},
    )
tool_message classmethod
tool_message(content, message_id=None, meta=None)

Create a tool message, optionally marking it as an error.

Parameters:

Name Type Description Default
content list[ContentBlock]

The message content blocks.

required
message_id str | None

Optional message ID.

None
meta dict[str, Any] | None

Optional metadata.

None

Returns:

Name Type Description
Message Message

The created tool message instance.

Example

Message.tool_message([ToolResultBlock(...)], message_id="tool1")

Source code in pyagenity/utils/message.py
@classmethod
def tool_message(
    cls,
    content: list[ContentBlock],
    message_id: str | None = None,
    meta: dict[str, Any] | None = None,
) -> "Message":
    """
    Create a tool message, optionally marking it as an error.

    Args:
        content (list[ContentBlock]): The message content blocks.
        message_id (str | None): Optional message ID.
        meta (dict[str, Any] | None): Optional metadata.

    Returns:
        Message: The created tool message instance.

    Example:
        >>> Message.tool_message([ToolResultBlock(...)], message_id="tool1")
    """
    res = content
    msg_id = generate_id(message_id)
    return cls(
        message_id=msg_id,
        role="tool",
        content=res,
        timestamp=datetime.now(),
        metadata=meta or {},
    )
OnErrorCallback

Bases: ABC

Abstract base class for on_error callbacks.

Called when an error occurs during invocation. Allows for error handling and logging.

Methods:

Name Description
__call__

Execute the on_error callback.

Source code in pyagenity/utils/callbacks.py
class OnErrorCallback(ABC):
    """Abstract base class for on_error callbacks.

    Called when an error occurs during invocation.
    Allows for error handling and logging.
    """

    @abstractmethod
    async def __call__(
        self, context: CallbackContext, input_data: Any, error: Exception
    ) -> Any | None:
        """Execute the on_error callback.

        Args:
            context: Context information about the invocation
            input_data: The input data that caused the error
            error: The exception that occurred

        Returns:
            Optional recovery value or None to re-raise the error

        Raises:
            Exception: If error handling fails or if the error should be re-raised
        """
        ...
Functions
__call__ abstractmethod async
__call__(context, input_data, error)

Execute the on_error callback.

Parameters:

Name Type Description Default
context CallbackContext

Context information about the invocation

required
input_data Any

The input data that caused the error

required
error Exception

The exception that occurred

required

Returns:

Type Description
Any | None

Optional recovery value or None to re-raise the error

Raises:

Type Description
Exception

If error handling fails or if the error should be re-raised

Source code in pyagenity/utils/callbacks.py
@abstractmethod
async def __call__(
    self, context: CallbackContext, input_data: Any, error: Exception
) -> Any | None:
    """Execute the on_error callback.

    Args:
        context: Context information about the invocation
        input_data: The input data that caused the error
        error: The exception that occurred

    Returns:
        Optional recovery value or None to re-raise the error

    Raises:
        Exception: If error handling fails or if the error should be re-raised
    """
    ...
ReasoningBlock

Bases: BaseModel

Reasoning content block for messages.

Attributes:

Name Type Description
type Literal['reasoning']

Block type discriminator.

summary str

Summary of reasoning.

details list[str] | None

Detailed reasoning steps.

Source code in pyagenity/utils/message.py
class ReasoningBlock(BaseModel):
    """
    Reasoning content block for messages.

    Attributes:
        type (Literal["reasoning"]): Block type discriminator.
        summary (str): Summary of reasoning.
        details (list[str] | None): Detailed reasoning steps.
    """

    type: Literal["reasoning"] = "reasoning"
    summary: str
    details: list[str] | None = None
Attributes
details class-attribute instance-attribute
details = None
summary instance-attribute
summary
type class-attribute instance-attribute
type = 'reasoning'
ResponseGranularity

Bases: StrEnum

Response granularity options for agent graph outputs.

Values

FULL: State, latest messages.
PARTIAL: Context, summary, latest messages.
LOW: Only latest messages.

Attributes:

Name Type Description
FULL
LOW
PARTIAL
Source code in pyagenity/utils/constants.py
class ResponseGranularity(StrEnum):
    """
    Response granularity options for agent graph outputs.

    Values:
        FULL: State, latest messages.
        PARTIAL: Context, summary, latest messages.
        LOW: Only latest messages.
    """

    FULL = "full"
    PARTIAL = "partial"
    LOW = "low"
Attributes
FULL class-attribute instance-attribute
FULL = 'full'
LOW class-attribute instance-attribute
LOW = 'low'
PARTIAL class-attribute instance-attribute
PARTIAL = 'partial'
StorageLevel

Message storage levels for agent state persistence.

Attributes:

Name Type Description
ALL

Save everything including tool calls.

MEDIUM

Only AI and human messages.

LOW

Only first human and last AI message.

Source code in pyagenity/utils/constants.py
class StorageLevel:
    """
    Message storage levels for agent state persistence.

    Attributes:
        ALL: Save everything including tool calls.
        MEDIUM: Only AI and human messages.
        LOW: Only first human and last AI message.
    """

    ALL = "all"
    MEDIUM = "medium"
    LOW = "low"
Attributes
ALL class-attribute instance-attribute
ALL = 'all'
LOW class-attribute instance-attribute
LOW = 'low'
MEDIUM class-attribute instance-attribute
MEDIUM = 'medium'
TextBlock

Bases: BaseModel

Text content block for messages.

Attributes:

Name Type Description
type Literal['text']

Block type discriminator.

text str

Text content.

annotations list[AnnotationRef]

List of annotation references.

Source code in pyagenity/utils/message.py
class TextBlock(BaseModel):
    """
    Text content block for messages.

    Attributes:
        type (Literal["text"]): Block type discriminator.
        text (str): Text content.
        annotations (list[AnnotationRef]): List of annotation references.
    """

    type: Literal["text"] = "text"
    text: str
    annotations: list[AnnotationRef] = Field(default_factory=list)
Attributes
annotations class-attribute instance-attribute
annotations = Field(default_factory=list)
text instance-attribute
text
type class-attribute instance-attribute
type = 'text'
ThreadInfo

Bases: BaseModel

Metadata and status for a thread in agent execution.

Attributes:

Name Type Description
thread_id int | str

Unique identifier for the thread.

thread_name str | None

Optional name for the thread.

user_id int | str | None

Optional user identifier associated with the thread.

metadata dict[str, Any] | None

Optional metadata for the thread.

updated_at datetime | None

Timestamp of last update.

run_id str | None

Optional run identifier for the thread execution.

Example

ThreadInfo(thread_id=1, thread_name="main", user_id=42)

Source code in pyagenity/utils/thread_info.py
class ThreadInfo(BaseModel):
    """
    Metadata and status for a thread in agent execution.

    Attributes:
        thread_id (int | str): Unique identifier for the thread.
        thread_name (str | None): Optional name for the thread.
        user_id (int | str | None): Optional user identifier associated with the thread.
        metadata (dict[str, Any] | None): Optional metadata for the thread.
        updated_at (datetime | None): Timestamp of last update.
        run_id (str | None): Optional run identifier for the thread execution.

    Example:
        >>> ThreadInfo(thread_id=1, thread_name="main", user_id=42)
    """

    thread_id: int | str
    thread_name: str | None = None
    user_id: int | str | None = None
    metadata: dict[str, Any] | None = None
    updated_at: datetime | None = None
    run_id: str | None = None
Attributes
metadata class-attribute instance-attribute
metadata = None
run_id class-attribute instance-attribute
run_id = None
thread_id instance-attribute
thread_id
thread_name class-attribute instance-attribute
thread_name = None
updated_at class-attribute instance-attribute
updated_at = None
user_id class-attribute instance-attribute
user_id = None
TokenUsages

Bases: BaseModel

Tracks token usage statistics for a message or model response.

Attributes:

Name Type Description
completion_tokens int

Number of completion tokens used.

prompt_tokens int

Number of prompt tokens used.

total_tokens int

Total tokens used.

reasoning_tokens int

Reasoning tokens used (optional).

cache_creation_input_tokens int

Cache creation input tokens (optional).

cache_read_input_tokens int

Cache read input tokens (optional).

image_tokens int | None

Image tokens for multimodal models (optional).

audio_tokens int | None

Audio tokens for multimodal models (optional).

Example

usage = TokenUsages(completion_tokens=10, prompt_tokens=20, total_tokens=30)

Source code in pyagenity/utils/message.py
class TokenUsages(BaseModel):
    """
    Tracks token usage statistics for a message or model response.

    Attributes:
        completion_tokens (int): Number of completion tokens used.
        prompt_tokens (int): Number of prompt tokens used.
        total_tokens (int): Total tokens used.
        reasoning_tokens (int): Reasoning tokens used (optional).
        cache_creation_input_tokens (int): Cache creation input tokens (optional).
        cache_read_input_tokens (int): Cache read input tokens (optional).
        image_tokens (int | None): Image tokens for multimodal models (optional).
        audio_tokens (int | None): Audio tokens for multimodal models (optional).

    Example:
        >>> usage = TokenUsages(completion_tokens=10, prompt_tokens=20, total_tokens=30)
    """

    completion_tokens: int
    prompt_tokens: int
    total_tokens: int
    reasoning_tokens: int = 0
    cache_creation_input_tokens: int = 0
    cache_read_input_tokens: int = 0
    # Optional modality-specific usage fields for multimodal models
    image_tokens: int | None = 0
    audio_tokens: int | None = 0
Attributes
audio_tokens class-attribute instance-attribute
audio_tokens = 0
cache_creation_input_tokens class-attribute instance-attribute
cache_creation_input_tokens = 0
cache_read_input_tokens class-attribute instance-attribute
cache_read_input_tokens = 0
completion_tokens instance-attribute
completion_tokens
image_tokens class-attribute instance-attribute
image_tokens = 0
prompt_tokens instance-attribute
prompt_tokens
reasoning_tokens class-attribute instance-attribute
reasoning_tokens = 0
total_tokens instance-attribute
total_tokens
ToolCallBlock

Bases: BaseModel

Tool call content block for messages.

Attributes:

Name Type Description
type Literal['tool_call']

Block type discriminator.

id str

Tool call ID.

name str

Tool name.

args dict[str, Any]

Arguments for the tool call.

tool_type str | None

Type of tool (e.g., web_search, file_search).

Source code in pyagenity/utils/message.py
class ToolCallBlock(BaseModel):
    """
    Tool call content block for messages.

    Attributes:
        type (Literal["tool_call"]): Block type discriminator.
        id (str): Tool call ID.
        name (str): Tool name.
        args (dict[str, Any]): Arguments for the tool call.
        tool_type (str | None): Type of tool (e.g., web_search, file_search).
    """

    type: Literal["tool_call"] = "tool_call"
    id: str
    name: str
    args: dict[str, Any] = Field(default_factory=dict)
    tool_type: str | None = None  # e.g., web_search, file_search, computer_use
Attributes
args class-attribute instance-attribute
args = Field(default_factory=dict)
id instance-attribute
id
name instance-attribute
name
tool_type class-attribute instance-attribute
tool_type = None
type class-attribute instance-attribute
type = 'tool_call'
ToolResultBlock

Bases: BaseModel

Tool result content block for messages.

Attributes:

Name Type Description
type Literal['tool_result']

Block type discriminator.

call_id str

Tool call ID.

output Any

Output from the tool (str, dict, MediaRef, or list of blocks).

is_error bool

Whether the result is an error.

status Literal['completed', 'failed'] | None

Status of the tool call.

Source code in pyagenity/utils/message.py
class ToolResultBlock(BaseModel):
    """
    Tool result content block for messages.

    Attributes:
        type (Literal["tool_result"]): Block type discriminator.
        call_id (str): Tool call ID.
        output (Any): Output from the tool (str, dict, MediaRef, or list of blocks).
        is_error (bool): Whether the result is an error.
        status (Literal["completed", "failed"] | None): Status of the tool call.
    """

    type: Literal["tool_result"] = "tool_result"
    call_id: str
    output: Any = None  # str | dict | MediaRef | list[ContentBlock-like]
    is_error: bool = False
    status: Literal["completed", "failed"] | None = None
Attributes
call_id instance-attribute
call_id
is_error class-attribute instance-attribute
is_error = False
output class-attribute instance-attribute
output = None
status class-attribute instance-attribute
status = None
type class-attribute instance-attribute
type = 'tool_result'
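The two blocks pair up through the shared call identifier: a `ToolResultBlock.call_id` refers back to the `ToolCallBlock.id` it answers. The sketch below uses simplified dataclass stand-ins (field names mirror the models documented above, but this is not the library code).

```python
from dataclasses import dataclass, field
from typing import Any

# Stand-in sketches of the two block models, reduced to the fields
# involved in call/result pairing.
@dataclass
class ToolCallBlock:
    id: str
    name: str
    args: dict = field(default_factory=dict)

@dataclass
class ToolResultBlock:
    call_id: str
    output: Any = None
    is_error: bool = False

call = ToolCallBlock(id="call_1", name="web_search", args={"q": "weather"})
# The result echoes the call's id so consumers can match them up.
result = ToolResultBlock(call_id=call.id, output="Sunny, 21C")
print(result.call_id == call.id)  # → True
```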

Functions

add_messages
add_messages(left, right)

Adds messages to the list, avoiding duplicates by message_id.

Parameters:

Name Type Description Default
left
list[Message]

Existing list of messages.

required
right
list[Message]

New messages to add.

required

Returns:

Type Description
list[Message]

list[Message]: Combined list with unique messages.

Example

add_messages([msg1], [msg2, msg1])
[msg1, msg2]

Source code in pyagenity/utils/reducers.py
def add_messages(left: list[Message], right: list[Message]) -> list[Message]:
    """
    Adds messages to the list, avoiding duplicates by message_id.

    Args:
        left (list[Message]): Existing list of messages.
        right (list[Message]): New messages to add.

    Returns:
        list[Message]: Combined list with unique messages.

    Example:
        >>> add_messages([msg1], [msg2, msg1])
        [msg1, msg2]
    """
    left_ids = {msg.message_id for msg in left}
    right = [msg for msg in right if msg.message_id not in left_ids]
    return left + right
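The dedup semantics can be demonstrated with a minimal stand-in: left-hand messages win, and right-hand messages whose `message_id` is already present are dropped while order is preserved. `Msg` below is an illustrative dataclass, not the pyagenity `Message` model.

```python
from dataclasses import dataclass

# Minimal stand-in for pyagenity's Message, carrying only the id the
# reducer compares on.
@dataclass
class Msg:
    message_id: str
    text: str

def add_messages(left: list[Msg], right: list[Msg]) -> list[Msg]:
    """Same logic as the reducer above: skip right-side ids already in left."""
    left_ids = {m.message_id for m in left}
    return left + [m for m in right if m.message_id not in left_ids]

a, b = Msg("1", "hi"), Msg("2", "hello")
dup = Msg("1", "different text, same id")  # dropped: id already present
merged = add_messages([a], [b, dup])
print([m.message_id for m in merged])  # → ['1', '2']
```

This makes `add_messages` safe to apply repeatedly, e.g. when re-delivering a batch that partially overlaps earlier state.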
append_items
append_items(left, right)

Appends items to a list, avoiding duplicates by item.id.

Parameters:

Name Type Description Default
left
list

Existing list of items (must have .id attribute).

required
right
list

New items to add.

required

Returns:

Name Type Description
list list

Combined list with unique items.

Example

append_items([item1], [item2, item1])
[item1, item2]

Source code in pyagenity/utils/reducers.py
def append_items(left: list, right: list) -> list:
    """
    Appends items to a list, avoiding duplicates by item.id.

    Args:
        left (list): Existing list of items (must have .id attribute).
        right (list): New items to add.

    Returns:
        list: Combined list with unique items.

    Example:
        >>> append_items([item1], [item2, item1])
        [item1, item2]
    """
    left_ids = {item.id for item in left}
    right = [item for item in right if item.id not in left_ids]
    return left + right
call_sync_or_async async
call_sync_or_async(func, *args, **kwargs)

Call a function that may be sync or async, returning its result.

If the function is synchronous, it runs in a thread pool to avoid blocking the event loop. If the result is awaitable, it is awaited before returning.

Parameters:

Name Type Description Default
func
Callable[..., Any]

The function to call.

required
*args

Positional arguments for the function.

()
**kwargs

Keyword arguments for the function.

{}

Returns:

Name Type Description
Any Any

The result of the function call, awaited if necessary.

Source code in pyagenity/utils/callable_utils.py
async def call_sync_or_async(func: Callable[..., Any], *args, **kwargs) -> Any:
    """
    Call a function that may be sync or async, returning its result.

    If the function is synchronous, it runs in a thread pool to avoid blocking
    the event loop. If the result is awaitable, it is awaited before returning.

    Args:
        func (Callable[..., Any]): The function to call.
        *args: Positional arguments for the function.
        **kwargs: Keyword arguments for the function.

    Returns:
        Any: The result of the function call, awaited if necessary.
    """
    if _is_async_callable(func):
        return await func(*args, **kwargs)

    # Call sync function in a thread pool
    result = await asyncio.to_thread(func, *args, **kwargs)
    # If the result is awaitable, await it
    if inspect.isawaitable(result):
        return await result
    return result
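The same sync/async dispatch pattern can be sketched in plain asyncio. One assumption: where the library uses its own `_is_async_callable` check, the sketch substitutes `inspect.iscoroutinefunction`, which covers the common case but not every async callable.

```python
import asyncio
import inspect
from typing import Any, Callable

async def call_sync_or_async(func: Callable[..., Any], *args, **kwargs) -> Any:
    # Await coroutine functions directly.
    if inspect.iscoroutinefunction(func):
        return await func(*args, **kwargs)
    # Run sync callables in a worker thread so the event loop is not blocked.
    result = await asyncio.to_thread(func, *args, **kwargs)
    # A sync wrapper may still hand back an awaitable; await it if so.
    if inspect.isawaitable(result):
        return await result
    return result

def sync_double(x: int) -> int:
    return x * 2

async def async_double(x: int) -> int:
    return x * 2

async def main():
    return (
        await call_sync_or_async(sync_double, 3),
        await call_sync_or_async(async_double, 4),
    )

results = asyncio.run(main())
print(results)  # → (6, 8)
```

This lets node functions in a graph be written as either `def` or `async def` without callers caring which.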
convert_messages
convert_messages(system_prompts, state=None, extra_messages=None)

Convert system prompts, agent state, and extra messages to a list of dicts for LLM/tool payloads.

Parameters:

Name Type Description Default
system_prompts
list[dict[str, Any]]

List of system prompt dicts.

required
state
AgentState | None

Optional agent state containing context and summary.

None
extra_messages
list[Message] | None

Optional extra messages to include.

None

Returns:

Type Description
list[dict[str, Any]]

list[dict[str, Any]]: List of message dicts for payloads.

Raises:

Type Description
ValueError

If system_prompts is None.

Source code in pyagenity/utils/converter.py
def convert_messages(
    system_prompts: list[dict[str, Any]],
    state: Union["AgentState", None] = None,
    extra_messages: list[Message] | None = None,
) -> list[dict[str, Any]]:
    """
    Convert system prompts, agent state, and extra messages to a list of dicts for
    LLM/tool payloads.

    Args:
        system_prompts (list[dict[str, Any]]): List of system prompt dicts.
        state (AgentState | None): Optional agent state containing context and summary.
        extra_messages (list[Message] | None): Optional extra messages to include.

    Returns:
        list[dict[str, Any]]: List of message dicts for payloads.

    Raises:
        ValueError: If system_prompts is None.
    """
    if system_prompts is None:
        logger.error("System prompts are None")
        raise ValueError("System prompts cannot be None")

    res = []
    res += system_prompts

    if state and state.context_summary:
        summary = {
            "role": "assistant",
            "content": state.context_summary if state.context_summary else "",
        }
        res.append(summary)

    if state and state.context:
        for msg in state.context:
            res.append(_convert_dict(msg))

    if extra_messages:
        for msg in extra_messages:
            res.append(_convert_dict(msg))

    logger.debug("Number of Converted messages: %s", len(res))
    return res
register_after_invoke
register_after_invoke(invocation_type, callback)

Register an after_invoke callback on the global callback manager.

Parameters:

Name Type Description Default
invocation_type
InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback
AfterInvokeCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_after_invoke(
    invocation_type: InvocationType, callback: AfterInvokeCallbackType
) -> None:
    """
    Register an after_invoke callback on the global callback manager.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (AfterInvokeCallbackType): The callback to register.
    """
    default_callback_manager.register_after_invoke(invocation_type, callback)
register_before_invoke
register_before_invoke(invocation_type, callback)

Register a before_invoke callback on the global callback manager.

Parameters:

Name Type Description Default
invocation_type
InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback
BeforeInvokeCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_before_invoke(
    invocation_type: InvocationType, callback: BeforeInvokeCallbackType
) -> None:
    """
    Register a before_invoke callback on the global callback manager.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (BeforeInvokeCallbackType): The callback to register.
    """
    default_callback_manager.register_before_invoke(invocation_type, callback)
register_on_error
register_on_error(invocation_type, callback)

Register an on_error callback on the global callback manager.

Parameters:

Name Type Description Default
invocation_type
InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback
OnErrorCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_on_error(invocation_type: InvocationType, callback: OnErrorCallbackType) -> None:
    """
    Register an on_error callback on the global callback manager.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (OnErrorCallbackType): The callback to register.
    """
    default_callback_manager.register_on_error(invocation_type, callback)
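Each registered before_invoke callback receives the previous callback's result, so a chain of them forms a small input pipeline. The following standalone sketch mirrors that chaining logic without importing PyAgenity; the real callbacks also receive a `CallbackContext`, which this sketch passes as `None`:

```python
import asyncio


async def run_before_chain(callbacks, input_data):
    """Mirror of execute_before_invoke's chaining: each callback sees the prior result."""
    current = input_data
    for cb in callbacks:
        result = cb(None, current)          # context omitted in this sketch
        if hasattr(result, "__await__"):    # same awaitable check as the source
            result = await result
        current = result
    return current


def strip(ctx, text):
    return text.strip()


def upper(ctx, text):
    return text.upper()


out = asyncio.run(run_before_chain([strip, upper], "  hello "))
print(out)  # HELLO
```

Because the chain threads results through each callback in registration order, registering validation before normalization (or vice versa) changes behavior; order matters.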
replace_messages
replace_messages(left, right)

Replaces the entire message list with a new one.

Parameters:

- left (list[Message], required): Existing list of messages (ignored).
- right (list[Message], required): New list of messages.

Returns:

- list[Message]: The new message list.

Example:

>>> replace_messages([msg1], [msg2])
[msg2]

Source code in pyagenity/utils/reducers.py
def replace_messages(left: list[Message], right: list[Message]) -> list[Message]:
    """
    Replaces the entire message list with a new one.

    Args:
        left (list[Message]): Existing list of messages (ignored).
        right (list[Message]): New list of messages.

    Returns:
        list[Message]: The new message list.

    Example:
        >>> replace_messages([msg1], [msg2])
        [msg2]
    """
    return right
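Reducers are deliberately tiny: the graph runtime calls them with the existing and incoming values and stores whatever they return. A standalone sketch contrasting the replace policy above with an append-style reducer (the append helper's name here is illustrative, not necessarily PyAgenity's):

```python
def add_messages(left, right):
    """Append-style reducer: keep existing history and extend it (illustrative name)."""
    return [*left, *right]


def replace_messages(left, right):
    """Replace-style reducer, reproduced from pyagenity/utils/reducers.py."""
    return right


history = ["msg1"]
incoming = ["msg2"]
print(add_messages(history, incoming))      # ['msg1', 'msg2']
print(replace_messages(history, incoming))  # ['msg2']
```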
replace_value
replace_value(left, right)

Replaces a value with another.

Parameters:

- left (required): Existing value (ignored).
- right (required): New value to use.

Returns:

- Any: The new value.

Example:

>>> replace_value(1, 2)
2

Source code in pyagenity/utils/reducers.py
def replace_value(left, right):
    """
    Replaces a value with another.

    Args:
        left: Existing value (ignored).
        right: New value to use.

    Returns:
        Any: The new value.

    Example:
        >>> replace_value(1, 2)
        2
    """
    return right
run_coroutine
run_coroutine(func)

Run an async coroutine from a sync context safely.

Source code in pyagenity/utils/callable_utils.py
def run_coroutine(func: Coroutine) -> Any:
    """Run an async coroutine from a sync context safely."""
    # Always try to get/create an event loop and use thread-safe execution
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop running, create one
        return asyncio.run(func)

    # Loop is running, use thread-safe execution
    fut = asyncio.run_coroutine_threadsafe(func, loop)
    return fut.result()
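The no-running-loop branch is the common path when calling from plain synchronous code. This standalone sketch reproduces the helper and exercises that branch (PyAgenity is not imported):

```python
import asyncio
from collections.abc import Coroutine
from typing import Any


def run_coroutine(func: Coroutine) -> Any:
    """Same strategy as the source: asyncio.run when no loop is active,
    run_coroutine_threadsafe when a loop is already running."""
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: create one for the call
        return asyncio.run(func)
    fut = asyncio.run_coroutine_threadsafe(func, loop)
    return fut.result()


async def compute() -> int:
    await asyncio.sleep(0)
    return 42


print(run_coroutine(compute()))  # 42
```

Note that the running-loop branch only avoids deadlock when the caller is on a different thread from the loop itself; `fut.result()` blocks the calling thread until the loop finishes the coroutine.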

Modules

background_task_manager

Background task manager for async operations in PyAgenity.

This module provides BackgroundTaskManager, which tracks and manages asyncio background tasks, ensuring proper cleanup and error logging.

Classes:

Name Description
BackgroundTaskManager

Manages asyncio background tasks for agent operations.

TaskMetadata

Metadata for tracking background tasks.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Classes
BackgroundTaskManager

Manages asyncio background tasks for agent operations.

Tracks created tasks, ensures cleanup, and logs errors from background execution. Enhanced with cancellation, timeouts, and metadata tracking.

Methods:

Name Description
__init__

Initialize the BackgroundTaskManager.

cancel_all

Cancel all tracked background tasks.

create_task

Create and track a background asyncio task.

get_task_count

Get the number of active background tasks.

get_task_info

Get information about all active tasks.

wait_for_all

Wait for all tracked background tasks to complete.

Source code in pyagenity/utils/background_task_manager.py
class BackgroundTaskManager:
    """
    Manages asyncio background tasks for agent operations.

    Tracks created tasks, ensures cleanup, and logs errors from background execution.
    Enhanced with cancellation, timeouts, and metadata tracking.
    """

    def __init__(self):
        """
        Initialize the BackgroundTaskManager.
        """
        self._tasks: set[asyncio.Task] = set()
        self._task_metadata: dict[asyncio.Task, TaskMetadata] = {}

    def create_task(
        self,
        coro: Coroutine,
        *,
        name: str = "background_task",
        timeout: float | None = None,
        context: dict[str, Any] | None = None,
    ) -> asyncio.Task:
        """
        Create and track a background asyncio task.

        Args:
            coro (Coroutine): The coroutine to run in the background.
            name (str): Human-readable name for the task.
            timeout (Optional[float]): Timeout in seconds for the task.
            context (Optional[dict]): Additional context for logging.

        Returns:
            asyncio.Task: The created task.
        """
        metrics.counter("background_task_manager.tasks_created").inc()

        task = asyncio.create_task(coro, name=name)
        metadata = TaskMetadata(
            name=name, created_at=time.time(), timeout=timeout, context=context or {}
        )

        self._tasks.add(task)
        self._task_metadata[task] = metadata
        task.add_done_callback(self._task_done_callback)

        # Set up timeout if specified
        if timeout:
            self._setup_timeout(task, timeout)

        logger.debug(
            "Created background task: %s (timeout=%s)",
            name,
            timeout,
            extra={"task_context": context},
        )

        return task

    def _setup_timeout(self, task: asyncio.Task, timeout: float) -> None:
        """Set up timeout cancellation for a task."""

        async def timeout_canceller():
            try:
                await asyncio.sleep(timeout)
                if not task.done():
                    metadata = self._task_metadata.get(task)
                    task_name = metadata.name if metadata else "unknown"
                    logger.warning(
                        "Background task '%s' timed out after %s seconds", task_name, timeout
                    )
                    task.cancel()
                    metrics.counter("background_task_manager.tasks_timed_out").inc()
            except asyncio.CancelledError:
                pass  # Parent task was cancelled, this is expected

        # Create the timeout task but don't track it (avoid recursive tracking)
        timeout_task = asyncio.create_task(timeout_canceller())
        # Add a callback to clean up the timeout task reference
        timeout_task.add_done_callback(lambda t: None)

    def _task_done_callback(self, task: asyncio.Task) -> None:
        """
        Remove completed task and log exceptions if any.

        Args:
            task (asyncio.Task): The completed asyncio task.
        """
        metadata = self._task_metadata.pop(task, None)
        self._tasks.discard(task)

        task_name = metadata.name if metadata else "unknown"
        duration = time.time() - metadata.created_at if metadata else 0.0

        try:
            task.result()  # raises if task failed
            metrics.counter("background_task_manager.tasks_completed").inc()
            logger.debug(
                "Background task '%s' completed successfully (duration=%.2fs)",
                task_name,
                duration,
                extra={"task_context": metadata.context if metadata else {}},
            )
        except asyncio.CancelledError:
            metrics.counter("background_task_manager.tasks_cancelled").inc()
            logger.debug("Background task '%s' was cancelled", task_name)
        except Exception as e:
            metrics.counter("background_task_manager.tasks_failed").inc()
            error_msg = (
                f"Background task raised an exception - {task_name}: {e} (duration={duration:.2f}s)"
            )
            logger.error(
                error_msg,
                exc_info=e,
                extra={"task_context": metadata.context if metadata else {}},
            )

    async def cancel_all(self) -> None:
        """
        Cancel all tracked background tasks.

        Returns:
            None
        """
        if not self._tasks:
            return

        logger.info("Cancelling %d background tasks...", len(self._tasks))

        for task in self._tasks.copy():
            if not task.done():
                task.cancel()

        # Wait a short time for cancellations to process
        await asyncio.sleep(0.1)

    async def wait_for_all(
        self, timeout: float | None = None, return_exceptions: bool = False
    ) -> None:
        """
        Wait for all tracked background tasks to complete.

        Args:
            timeout (float | None): Maximum time to wait in seconds.
            return_exceptions (bool): If True, exceptions are returned as results instead of raised.

        Returns:
            None
        """
        if not self._tasks:
            return

        logger.info("Waiting for %d background tasks to finish...", len(self._tasks))

        try:
            if timeout:
                await asyncio.wait_for(
                    asyncio.gather(*self._tasks, return_exceptions=return_exceptions),
                    timeout=timeout,
                )
            else:
                await asyncio.gather(*self._tasks, return_exceptions=return_exceptions)
            logger.info("All background tasks finished.")
        except TimeoutError:
            logger.warning("Timeout waiting for background tasks, some may still be running")
            metrics.counter("background_task_manager.wait_timeout").inc()

    def get_task_count(self) -> int:
        """Get the number of active background tasks."""
        return len(self._tasks)

    def get_task_info(self) -> list[dict[str, Any]]:
        """Get information about all active tasks."""
        current_time = time.time()
        return [
            {
                "name": metadata.name,
                "age_seconds": current_time - metadata.created_at,
                "timeout": metadata.timeout,
                "context": metadata.context,
                "done": task.done(),
                "cancelled": task.cancelled() if task.done() else False,
            }
            for task, metadata in self._task_metadata.items()
        ]
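The manager's lifecycle — track on create, discard on completion, gather on shutdown — can be exercised with plain asyncio. This sketch keeps only the tracking skeleton and omits the metrics, logging, timeouts, and metadata of the real class:

```python
import asyncio


class MiniTaskManager:
    """Tracking skeleton of BackgroundTaskManager (no metrics/logging/timeouts)."""

    def __init__(self):
        self._tasks: set[asyncio.Task] = set()

    def create_task(self, coro, *, name="background_task"):
        task = asyncio.create_task(coro, name=name)
        self._tasks.add(task)
        # Same discard-on-completion pattern as _task_done_callback
        task.add_done_callback(self._tasks.discard)
        return task

    async def wait_for_all(self):
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)


async def main():
    mgr = MiniTaskManager()
    results = []

    async def work(n):
        await asyncio.sleep(0)
        results.append(n)

    for i in range(3):
        mgr.create_task(work(i), name=f"work-{i}")
    await mgr.wait_for_all()
    # After gather, the done callbacks have emptied the tracking set
    return sorted(results), len(mgr._tasks)


print(asyncio.run(main()))
```

The done-callback keeps the task set from growing without bound, which is the main reason to route task creation through a manager rather than calling `asyncio.create_task` directly.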
Functions
__init__
__init__()

Initialize the BackgroundTaskManager.

Source code in pyagenity/utils/background_task_manager.py
def __init__(self):
    """
    Initialize the BackgroundTaskManager.
    """
    self._tasks: set[asyncio.Task] = set()
    self._task_metadata: dict[asyncio.Task, TaskMetadata] = {}
cancel_all async
cancel_all()

Cancel all tracked background tasks.

Returns:

- None

Source code in pyagenity/utils/background_task_manager.py
async def cancel_all(self) -> None:
    """
    Cancel all tracked background tasks.

    Returns:
        None
    """
    if not self._tasks:
        return

    logger.info("Cancelling %d background tasks...", len(self._tasks))

    for task in self._tasks.copy():
        if not task.done():
            task.cancel()

    # Wait a short time for cancellations to process
    await asyncio.sleep(0.1)
create_task
create_task(coro, *, name='background_task', timeout=None, context=None)

Create and track a background asyncio task.

Parameters:

- coro (Coroutine, required): The coroutine to run in the background.
- name (str, default 'background_task'): Human-readable name for the task.
- timeout (Optional[float], default None): Timeout in seconds for the task.
- context (Optional[dict], default None): Additional context for logging.

Returns:

- asyncio.Task: The created task.

Source code in pyagenity/utils/background_task_manager.py
def create_task(
    self,
    coro: Coroutine,
    *,
    name: str = "background_task",
    timeout: float | None = None,
    context: dict[str, Any] | None = None,
) -> asyncio.Task:
    """
    Create and track a background asyncio task.

    Args:
        coro (Coroutine): The coroutine to run in the background.
        name (str): Human-readable name for the task.
        timeout (Optional[float]): Timeout in seconds for the task.
        context (Optional[dict]): Additional context for logging.

    Returns:
        asyncio.Task: The created task.
    """
    metrics.counter("background_task_manager.tasks_created").inc()

    task = asyncio.create_task(coro, name=name)
    metadata = TaskMetadata(
        name=name, created_at=time.time(), timeout=timeout, context=context or {}
    )

    self._tasks.add(task)
    self._task_metadata[task] = metadata
    task.add_done_callback(self._task_done_callback)

    # Set up timeout if specified
    if timeout:
        self._setup_timeout(task, timeout)

    logger.debug(
        "Created background task: %s (timeout=%s)",
        name,
        timeout,
        extra={"task_context": context},
    )

    return task
get_task_count
get_task_count()

Get the number of active background tasks.

Source code in pyagenity/utils/background_task_manager.py
def get_task_count(self) -> int:
    """Get the number of active background tasks."""
    return len(self._tasks)
get_task_info
get_task_info()

Get information about all active tasks.

Source code in pyagenity/utils/background_task_manager.py
def get_task_info(self) -> list[dict[str, Any]]:
    """Get information about all active tasks."""
    current_time = time.time()
    return [
        {
            "name": metadata.name,
            "age_seconds": current_time - metadata.created_at,
            "timeout": metadata.timeout,
            "context": metadata.context,
            "done": task.done(),
            "cancelled": task.cancelled() if task.done() else False,
        }
        for task, metadata in self._task_metadata.items()
    ]
wait_for_all async
wait_for_all(timeout=None, return_exceptions=False)

Wait for all tracked background tasks to complete.

Parameters:

- timeout (float | None, default None): Maximum time to wait in seconds.
- return_exceptions (bool, default False): If True, exceptions are returned as results instead of raised.

Returns:

- None

Source code in pyagenity/utils/background_task_manager.py
async def wait_for_all(
    self, timeout: float | None = None, return_exceptions: bool = False
) -> None:
    """
    Wait for all tracked background tasks to complete.

    Args:
        timeout (float | None): Maximum time to wait in seconds.
        return_exceptions (bool): If True, exceptions are returned as results instead of raised.

    Returns:
        None
    """
    if not self._tasks:
        return

    logger.info("Waiting for %d background tasks to finish...", len(self._tasks))

    try:
        if timeout:
            await asyncio.wait_for(
                asyncio.gather(*self._tasks, return_exceptions=return_exceptions),
                timeout=timeout,
            )
        else:
            await asyncio.gather(*self._tasks, return_exceptions=return_exceptions)
        logger.info("All background tasks finished.")
    except TimeoutError:
        logger.warning("Timeout waiting for background tasks, some may still be running")
        metrics.counter("background_task_manager.wait_timeout").inc()
TaskMetadata dataclass

Metadata for tracking background tasks.

Methods:

Name Description
__init__

Attributes:

Name Type Description
context dict[str, Any] | None
created_at float
name str
timeout float | None
Source code in pyagenity/utils/background_task_manager.py
@dataclass
class TaskMetadata:
    """Metadata for tracking background tasks."""

    name: str
    created_at: float
    timeout: float | None = None
    context: dict[str, Any] | None = None
Attributes
context class-attribute instance-attribute
context = None
created_at instance-attribute
created_at
name instance-attribute
name
timeout class-attribute instance-attribute
timeout = None
Functions
__init__
__init__(name, created_at, timeout=None, context=None)
Modules
callable_utils

Utilities for calling sync or async functions in PyAgenity.

This module provides helpers to detect async callables and to invoke functions that may be synchronous or asynchronous, handling thread pool execution and awaitables.

Functions:

Name Description
call_sync_or_async

Call a function that may be sync or async, returning its result.

run_coroutine

Run an async coroutine from a sync context safely.

Functions
call_sync_or_async async
call_sync_or_async(func, *args, **kwargs)

Call a function that may be sync or async, returning its result.

If the function is synchronous, it runs in a thread pool to avoid blocking the event loop. If the result is awaitable, it is awaited before returning.

Parameters:

- func (Callable[..., Any], required): The function to call.
- *args: Positional arguments for the function.
- **kwargs: Keyword arguments for the function.

Returns:

- Any: The result of the function call, awaited if necessary.

Source code in pyagenity/utils/callable_utils.py
async def call_sync_or_async(func: Callable[..., Any], *args, **kwargs) -> Any:
    """
    Call a function that may be sync or async, returning its result.

    If the function is synchronous, it runs in a thread pool to avoid blocking
    the event loop. If the result is awaitable, it is awaited before returning.

    Args:
        func (Callable[..., Any]): The function to call.
        *args: Positional arguments for the function.
        **kwargs: Keyword arguments for the function.

    Returns:
        Any: The result of the function call, awaited if necessary.
    """
    if _is_async_callable(func):
        return await func(*args, **kwargs)

    # Call sync function in a thread pool
    result = await asyncio.to_thread(func, *args, **kwargs)
    # If the result is awaitable, await it
    if inspect.isawaitable(result):
        return await result
    return result
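The dispatch above can be reproduced standalone. In this sketch, `_is_async_callable` is a simplified stand-in for the module's private helper (the real one may also handle partials and callable objects), and PyAgenity is not imported:

```python
import asyncio
import inspect
from collections.abc import Callable
from typing import Any


def _is_async_callable(func) -> bool:
    # Simplified stand-in for the module's private helper
    return inspect.iscoroutinefunction(func)


async def call_sync_or_async(func: Callable[..., Any], *args, **kwargs) -> Any:
    if _is_async_callable(func):
        return await func(*args, **kwargs)
    # Sync function: run in a worker thread to keep the event loop unblocked
    result = await asyncio.to_thread(func, *args, **kwargs)
    if inspect.isawaitable(result):
        return await result
    return result


def sync_double(x):
    return x * 2


async def async_double(x):
    return x * 2


async def main():
    return (await call_sync_or_async(sync_double, 3),
            await call_sync_or_async(async_double, 4))


print(asyncio.run(main()))  # (6, 8)
```

Node functions written either way can therefore be invoked through one code path, which is what lets PyAgenity accept both sync and async user callables.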
run_coroutine
run_coroutine(func)

Run an async coroutine from a sync context safely.

Source code in pyagenity/utils/callable_utils.py
def run_coroutine(func: Coroutine) -> Any:
    """Run an async coroutine from a sync context safely."""
    # Always try to get/create an event loop and use thread-safe execution
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop running, create one
        return asyncio.run(func)

    # Loop is running, use thread-safe execution
    fut = asyncio.run_coroutine_threadsafe(func, loop)
    return fut.result()
callbacks

Callback system for PyAgenity.

This module provides a comprehensive callback framework that allows users to define their own validation logic and custom behavior at key points in the execution flow:

  • before_invoke: Called before AI/TOOL/MCP invocation for input validation and modification
  • after_invoke: Called after AI/TOOL/MCP invocation for output validation and modification
  • on_error: Called when errors occur during invocation for error handling and logging

The system is generic and type-safe, supporting different callback types for different invocation contexts.

Classes:

Name Description
AfterInvokeCallback

Abstract base class for after_invoke callbacks.

BeforeInvokeCallback

Abstract base class for before_invoke callbacks.

CallbackContext

Context information passed to callbacks.

CallbackManager

Manages registration and execution of callbacks for different invocation types.

InvocationType

Types of invocations that can trigger callbacks.

OnErrorCallback

Abstract base class for on_error callbacks.

Functions:

Name Description
register_after_invoke

Register an after_invoke callback on the global callback manager.

register_before_invoke

Register a before_invoke callback on the global callback manager.

register_on_error

Register an on_error callback on the global callback manager.

Attributes:

Name Type Description
AfterInvokeCallbackType
BeforeInvokeCallbackType
OnErrorCallbackType
default_callback_manager
logger
Attributes
AfterInvokeCallbackType module-attribute
AfterInvokeCallbackType = Union[AfterInvokeCallback[Any, Any], Callable[[CallbackContext, Any, Any], Union[Any, Awaitable[Any]]]]
BeforeInvokeCallbackType module-attribute
BeforeInvokeCallbackType = Union[BeforeInvokeCallback[Any, Any], Callable[[CallbackContext, Any], Union[Any, Awaitable[Any]]]]
OnErrorCallbackType module-attribute
OnErrorCallbackType = Union[OnErrorCallback, Callable[[CallbackContext, Any, Exception], Union[Any | None, Awaitable[Any | None]]]]
default_callback_manager module-attribute
default_callback_manager = CallbackManager()
logger module-attribute
logger = getLogger(__name__)
Classes
AfterInvokeCallback

Bases: ABC

Abstract base class for after_invoke callbacks.

Called after the AI model, tool, or MCP function is invoked. Allows for output validation and modification.

Methods:

Name Description
__call__

Execute the after_invoke callback.

Source code in pyagenity/utils/callbacks.py
class AfterInvokeCallback[T, R](ABC):
    """Abstract base class for after_invoke callbacks.

    Called after the AI model, tool, or MCP function is invoked.
    Allows for output validation and modification.
    """

    @abstractmethod
    async def __call__(self, context: CallbackContext, input_data: T, output_data: Any) -> Any | R:
        """Execute the after_invoke callback.

        Args:
            context: Context information about the invocation
            input_data: The original input data that was sent
            output_data: The output data returned from the invocation

        Returns:
            Modified output data (can be same type or different type)

        Raises:
            Exception: If validation fails or modification cannot be performed
        """
        ...
Functions
__call__ abstractmethod async
__call__(context, input_data, output_data)

Execute the after_invoke callback.

Parameters:

- context (CallbackContext, required): Context information about the invocation.
- input_data (T, required): The original input data that was sent.
- output_data (Any, required): The output data returned from the invocation.

Returns:

- Any | R: Modified output data (can be the same or a different type).

Raises:

- Exception: If validation fails or modification cannot be performed.

Source code in pyagenity/utils/callbacks.py
@abstractmethod
async def __call__(self, context: CallbackContext, input_data: T, output_data: Any) -> Any | R:
    """Execute the after_invoke callback.

    Args:
        context: Context information about the invocation
        input_data: The original input data that was sent
        output_data: The output data returned from the invocation

    Returns:
        Modified output data (can be same type or different type)

    Raises:
        Exception: If validation fails or modification cannot be performed
    """
    ...
BeforeInvokeCallback

Bases: ABC

Abstract base class for before_invoke callbacks.

Called before the AI model, tool, or MCP function is invoked. Allows for input validation and modification.

Methods:

Name Description
__call__

Execute the before_invoke callback.

Source code in pyagenity/utils/callbacks.py
class BeforeInvokeCallback[T, R](ABC):
    """Abstract base class for before_invoke callbacks.

    Called before the AI model, tool, or MCP function is invoked.
    Allows for input validation and modification.
    """

    @abstractmethod
    async def __call__(self, context: CallbackContext, input_data: T) -> T | R:
        """Execute the before_invoke callback.

        Args:
            context: Context information about the invocation
            input_data: The input data about to be sent to the invocation

        Returns:
            Modified input data (can be same type or different type)

        Raises:
            Exception: If validation fails or modification cannot be performed
        """
        ...
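A concrete callback just implements the async `__call__`. This standalone sketch mirrors the ABC (using `typing.Generic` for portability rather than the PEP 695 `class X[T, R]` syntax in the source) and rejects empty tool input; `RejectEmptyInput` is a hypothetical example, not a PyAgenity class:

```python
from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

T = TypeVar("T")
R = TypeVar("R")


class BeforeInvokeCallback(ABC, Generic[T, R]):
    """Mirror of the ABC in pyagenity/utils/callbacks.py."""

    @abstractmethod
    async def __call__(self, context: Any, input_data: T) -> T | R: ...


class RejectEmptyInput(BeforeInvokeCallback[dict, dict]):
    """Hypothetical callback: fail fast if a tool is called with no arguments."""

    async def __call__(self, context: Any, input_data: dict) -> dict:
        if not input_data:
            raise ValueError("tool called with no arguments")
        return input_data


cb = RejectEmptyInput()
print(asyncio.run(cb(None, {"query": "weather"})))  # {'query': 'weather'}
```

An instance like this would then be registered with `register_before_invoke(InvocationType.TOOL, cb)` so it runs before every tool invocation.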
Functions
__call__ abstractmethod async
__call__(context, input_data)

Execute the before_invoke callback.

Parameters:

- context (CallbackContext, required): Context information about the invocation.
- input_data (T, required): The input data about to be sent to the invocation.

Returns:

- T | R: Modified input data (can be the same or a different type).

Raises:

- Exception: If validation fails or modification cannot be performed.

Source code in pyagenity/utils/callbacks.py
@abstractmethod
async def __call__(self, context: CallbackContext, input_data: T) -> T | R:
    """Execute the before_invoke callback.

    Args:
        context: Context information about the invocation
        input_data: The input data about to be sent to the invocation

    Returns:
        Modified input data (can be same type or different type)

    Raises:
        Exception: If validation fails or modification cannot be performed
    """
    ...
CallbackContext dataclass

Context information passed to callbacks.

Methods:

Name Description
__init__

Attributes:

Name Type Description
function_name str | None
invocation_type InvocationType
metadata dict[str, Any] | None
node_name str
Source code in pyagenity/utils/callbacks.py
@dataclass
class CallbackContext:
    """Context information passed to callbacks."""

    invocation_type: InvocationType
    node_name: str
    function_name: str | None = None
    metadata: dict[str, Any] | None = None
Attributes
function_name class-attribute instance-attribute
function_name = None
invocation_type instance-attribute
invocation_type
metadata class-attribute instance-attribute
metadata = None
node_name instance-attribute
node_name
Functions
__init__
__init__(invocation_type, node_name, function_name=None, metadata=None)
CallbackManager

Manages registration and execution of callbacks for different invocation types.

Handles before_invoke, after_invoke, and on_error callbacks for AI, TOOL, and MCP invocations.

Methods:

Name Description
__init__

Initialize the CallbackManager with empty callback registries.

clear_callbacks

Clear callbacks for a specific invocation type or all types.

execute_after_invoke

Execute all after_invoke callbacks for the given context.

execute_before_invoke

Execute all before_invoke callbacks for the given context.

execute_on_error

Execute all on_error callbacks for the given context.

get_callback_counts

Get count of registered callbacks by type for debugging.

register_after_invoke

Register an after_invoke callback for a specific invocation type.

register_before_invoke

Register a before_invoke callback for a specific invocation type.

register_on_error

Register an on_error callback for a specific invocation type.

Source code in pyagenity/utils/callbacks.py
class CallbackManager:
    """
    Manages registration and execution of callbacks for different invocation types.

    Handles before_invoke, after_invoke, and on_error callbacks for AI, TOOL, and MCP invocations.
    """

    def __init__(self):
        """
        Initialize the CallbackManager with empty callback registries.
        """
        self._before_callbacks: dict[InvocationType, list[BeforeInvokeCallbackType]] = {
            InvocationType.AI: [],
            InvocationType.TOOL: [],
            InvocationType.MCP: [],
        }
        self._after_callbacks: dict[InvocationType, list[AfterInvokeCallbackType]] = {
            InvocationType.AI: [],
            InvocationType.TOOL: [],
            InvocationType.MCP: [],
        }
        self._error_callbacks: dict[InvocationType, list[OnErrorCallbackType]] = {
            InvocationType.AI: [],
            InvocationType.TOOL: [],
            InvocationType.MCP: [],
        }

    def register_before_invoke(
        self, invocation_type: InvocationType, callback: BeforeInvokeCallbackType
    ) -> None:
        """
        Register a before_invoke callback for a specific invocation type.

        Args:
            invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
            callback (BeforeInvokeCallbackType): The callback to register.
        """
        self._before_callbacks[invocation_type].append(callback)

    def register_after_invoke(
        self, invocation_type: InvocationType, callback: AfterInvokeCallbackType
    ) -> None:
        """
        Register an after_invoke callback for a specific invocation type.

        Args:
            invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
            callback (AfterInvokeCallbackType): The callback to register.
        """
        self._after_callbacks[invocation_type].append(callback)

    def register_on_error(
        self, invocation_type: InvocationType, callback: OnErrorCallbackType
    ) -> None:
        """
        Register an on_error callback for a specific invocation type.

        Args:
            invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
            callback (OnErrorCallbackType): The callback to register.
        """
        self._error_callbacks[invocation_type].append(callback)

    async def execute_before_invoke(self, context: CallbackContext, input_data: Any) -> Any:
        """
        Execute all before_invoke callbacks for the given context.

        Args:
            context (CallbackContext): Context information about the invocation.
            input_data (Any): The input data to be validated or modified.

        Returns:
            Any: The modified input data after all callbacks.

        Raises:
            Exception: If any callback fails.
        """
        current_data = input_data

        for callback in self._before_callbacks[context.invocation_type]:
            try:
                if isinstance(callback, BeforeInvokeCallback):
                    current_data = await callback(context, current_data)
                elif callable(callback):
                    result = callback(context, current_data)
                    if hasattr(result, "__await__"):
                        current_data = await result
                    else:
                        current_data = result
            except Exception as e:
                await self.execute_on_error(context, input_data, e)
                raise

        return current_data

    async def execute_after_invoke(
        self, context: CallbackContext, input_data: Any, output_data: Any
    ) -> Any:
        """
        Execute all after_invoke callbacks for the given context.

        Args:
            context (CallbackContext): Context information about the invocation.
            input_data (Any): The original input data sent to the invocation.
            output_data (Any): The output data returned from the invocation.

        Returns:
            Any: The modified output data after all callbacks.

        Raises:
            Exception: If any callback fails.
        """
        current_output = output_data

        for callback in self._after_callbacks[context.invocation_type]:
            try:
                if isinstance(callback, AfterInvokeCallback):
                    current_output = await callback(context, input_data, current_output)
                elif callable(callback):
                    result = callback(context, input_data, current_output)
                    if hasattr(result, "__await__"):
                        current_output = await result
                    else:
                        current_output = result
            except Exception as e:
                await self.execute_on_error(context, input_data, e)
                raise

        return current_output

    async def execute_on_error(
        self, context: CallbackContext, input_data: Any, error: Exception
    ) -> Message | None:
        """
        Execute all on_error callbacks for the given context.

        Args:
            context (CallbackContext): Context information about the invocation.
            input_data (Any): The input data that caused the error.
            error (Exception): The exception that occurred.

        Returns:
            Message | None: Recovery value from callbacks, or None if not handled.
        """
        recovery_value = None

        for callback in self._error_callbacks[context.invocation_type]:
            try:
                result = None
                if isinstance(callback, OnErrorCallback):
                    result = await callback(context, input_data, error)
                elif callable(callback):
                    result = callback(context, input_data, error)
                    if hasattr(result, "__await__"):
                        result = await result  # type: ignore

                if isinstance(result, Message) or result is None:
                    recovery_value = result
            except Exception as exc:
                logger.exception("Error callback failed: %s", exc)
                continue

        return recovery_value

    def clear_callbacks(self, invocation_type: InvocationType | None = None) -> None:
        """
        Clear callbacks for a specific invocation type or all types.

        Args:
            invocation_type (InvocationType | None): The invocation type to clear, or None for all.
        """
        if invocation_type:
            self._before_callbacks[invocation_type].clear()
            self._after_callbacks[invocation_type].clear()
            self._error_callbacks[invocation_type].clear()
        else:
            for inv_type in InvocationType:
                self._before_callbacks[inv_type].clear()
                self._after_callbacks[inv_type].clear()
                self._error_callbacks[inv_type].clear()

    def get_callback_counts(self) -> dict[str, dict[str, int]]:
        """
        Get count of registered callbacks by type for debugging.

        Returns:
            dict[str, dict[str, int]]: Counts of callbacks for each invocation type.
        """
        return {
            inv_type.value: {
                "before_invoke": len(self._before_callbacks[inv_type]),
                "after_invoke": len(self._after_callbacks[inv_type]),
                "on_error": len(self._error_callbacks[inv_type]),
            }
            for inv_type in InvocationType
        }
Functions
__init__
__init__()

Initialize the CallbackManager with empty callback registries.

Source code in pyagenity/utils/callbacks.py
def __init__(self):
    """
    Initialize the CallbackManager with empty callback registries.
    """
    self._before_callbacks: dict[InvocationType, list[BeforeInvokeCallbackType]] = {
        InvocationType.AI: [],
        InvocationType.TOOL: [],
        InvocationType.MCP: [],
    }
    self._after_callbacks: dict[InvocationType, list[AfterInvokeCallbackType]] = {
        InvocationType.AI: [],
        InvocationType.TOOL: [],
        InvocationType.MCP: [],
    }
    self._error_callbacks: dict[InvocationType, list[OnErrorCallbackType]] = {
        InvocationType.AI: [],
        InvocationType.TOOL: [],
        InvocationType.MCP: [],
    }
clear_callbacks
clear_callbacks(invocation_type=None)

Clear callbacks for a specific invocation type or all types.

Parameters:

Name Type Description Default
invocation_type InvocationType | None

The invocation type to clear, or None for all.

None
Source code in pyagenity/utils/callbacks.py
def clear_callbacks(self, invocation_type: InvocationType | None = None) -> None:
    """
    Clear callbacks for a specific invocation type or all types.

    Args:
        invocation_type (InvocationType | None): The invocation type to clear, or None for all.
    """
    if invocation_type:
        self._before_callbacks[invocation_type].clear()
        self._after_callbacks[invocation_type].clear()
        self._error_callbacks[invocation_type].clear()
    else:
        for inv_type in InvocationType:
            self._before_callbacks[inv_type].clear()
            self._after_callbacks[inv_type].clear()
            self._error_callbacks[inv_type].clear()
execute_after_invoke async
execute_after_invoke(context, input_data, output_data)

Execute all after_invoke callbacks for the given context.

Parameters:

Name Type Description Default
context CallbackContext

Context information about the invocation.

required
input_data Any

The original input data sent to the invocation.

required
output_data Any

The output data returned from the invocation.

required

Returns:

Name Type Description
Any Any

The modified output data after all callbacks.

Raises:

Type Description
Exception

If any callback fails.

Source code in pyagenity/utils/callbacks.py
async def execute_after_invoke(
    self, context: CallbackContext, input_data: Any, output_data: Any
) -> Any:
    """
    Execute all after_invoke callbacks for the given context.

    Args:
        context (CallbackContext): Context information about the invocation.
        input_data (Any): The original input data sent to the invocation.
        output_data (Any): The output data returned from the invocation.

    Returns:
        Any: The modified output data after all callbacks.

    Raises:
        Exception: If any callback fails.
    """
    current_output = output_data

    for callback in self._after_callbacks[context.invocation_type]:
        try:
            if isinstance(callback, AfterInvokeCallback):
                current_output = await callback(context, input_data, current_output)
            elif callable(callback):
                result = callback(context, input_data, current_output)
                if hasattr(result, "__await__"):
                    current_output = await result
                else:
                    current_output = result
        except Exception as e:
            await self.execute_on_error(context, input_data, e)
            raise

    return current_output
execute_before_invoke async
execute_before_invoke(context, input_data)

Execute all before_invoke callbacks for the given context.

Parameters:

Name Type Description Default
context CallbackContext

Context information about the invocation.

required
input_data Any

The input data to be validated or modified.

required

Returns:

Name Type Description
Any Any

The modified input data after all callbacks.

Raises:

Type Description
Exception

If any callback fails.

Source code in pyagenity/utils/callbacks.py
async def execute_before_invoke(self, context: CallbackContext, input_data: Any) -> Any:
    """
    Execute all before_invoke callbacks for the given context.

    Args:
        context (CallbackContext): Context information about the invocation.
        input_data (Any): The input data to be validated or modified.

    Returns:
        Any: The modified input data after all callbacks.

    Raises:
        Exception: If any callback fails.
    """
    current_data = input_data

    for callback in self._before_callbacks[context.invocation_type]:
        try:
            if isinstance(callback, BeforeInvokeCallback):
                current_data = await callback(context, current_data)
            elif callable(callback):
                result = callback(context, current_data)
                if hasattr(result, "__await__"):
                    current_data = await result
                else:
                    current_data = result
        except Exception as e:
            await self.execute_on_error(context, input_data, e)
            raise

    return current_data
execute_on_error async
execute_on_error(context, input_data, error)

Execute all on_error callbacks for the given context.

Parameters:

Name Type Description Default
context CallbackContext

Context information about the invocation.

required
input_data Any

The input data that caused the error.

required
error Exception

The exception that occurred.

required

Returns:

Type Description
Message | None

Message | None: Recovery value from callbacks, or None if not handled.

Source code in pyagenity/utils/callbacks.py
async def execute_on_error(
    self, context: CallbackContext, input_data: Any, error: Exception
) -> Message | None:
    """
    Execute all on_error callbacks for the given context.

    Args:
        context (CallbackContext): Context information about the invocation.
        input_data (Any): The input data that caused the error.
        error (Exception): The exception that occurred.

    Returns:
        Message | None: Recovery value from callbacks, or None if not handled.
    """
    recovery_value = None

    for callback in self._error_callbacks[context.invocation_type]:
        try:
            result = None
            if isinstance(callback, OnErrorCallback):
                result = await callback(context, input_data, error)
            elif callable(callback):
                result = callback(context, input_data, error)
                if hasattr(result, "__await__"):
                    result = await result  # type: ignore

            if isinstance(result, Message) or result is None:
                recovery_value = result
        except Exception as exc:
            logger.exception("Error callback failed: %s", exc)
            continue

    return recovery_value
get_callback_counts
get_callback_counts()

Get count of registered callbacks by type for debugging.

Returns:

Type Description
dict[str, dict[str, int]]

dict[str, dict[str, int]]: Counts of callbacks for each invocation type.

Source code in pyagenity/utils/callbacks.py
def get_callback_counts(self) -> dict[str, dict[str, int]]:
    """
    Get count of registered callbacks by type for debugging.

    Returns:
        dict[str, dict[str, int]]: Counts of callbacks for each invocation type.
    """
    return {
        inv_type.value: {
            "before_invoke": len(self._before_callbacks[inv_type]),
            "after_invoke": len(self._after_callbacks[inv_type]),
            "on_error": len(self._error_callbacks[inv_type]),
        }
        for inv_type in InvocationType
    }
register_after_invoke
register_after_invoke(invocation_type, callback)

Register an after_invoke callback for a specific invocation type.

Parameters:

Name Type Description Default
invocation_type InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback AfterInvokeCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_after_invoke(
    self, invocation_type: InvocationType, callback: AfterInvokeCallbackType
) -> None:
    """
    Register an after_invoke callback for a specific invocation type.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (AfterInvokeCallbackType): The callback to register.
    """
    self._after_callbacks[invocation_type].append(callback)
register_before_invoke
register_before_invoke(invocation_type, callback)

Register a before_invoke callback for a specific invocation type.

Parameters:

Name Type Description Default
invocation_type InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback BeforeInvokeCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_before_invoke(
    self, invocation_type: InvocationType, callback: BeforeInvokeCallbackType
) -> None:
    """
    Register a before_invoke callback for a specific invocation type.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (BeforeInvokeCallbackType): The callback to register.
    """
    self._before_callbacks[invocation_type].append(callback)
register_on_error
register_on_error(invocation_type, callback)

Register an on_error callback for a specific invocation type.

Parameters:

Name Type Description Default
invocation_type InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback OnErrorCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_on_error(
    self, invocation_type: InvocationType, callback: OnErrorCallbackType
) -> None:
    """
    Register an on_error callback for a specific invocation type.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (OnErrorCallbackType): The callback to register.
    """
    self._error_callbacks[invocation_type].append(callback)
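The registration and execution flow above can be sketched with a minimal, self-contained stand-in. `MiniCallbackManager` below is an illustrative reduction of `CallbackManager` (it keeps only `before_invoke`, and the callback signature is simplified); it mirrors how the real implementation accepts both sync and async callables and awaits coroutine results.

```python
# Minimal sketch of the register/execute pattern above. All names here are
# local stand-ins, not imports from pyagenity.
import asyncio
from enum import Enum


class InvocationType(Enum):
    AI = "ai"
    TOOL = "tool"
    MCP = "mcp"


class MiniCallbackManager:
    def __init__(self):
        # One callback list per invocation type, as in CallbackManager.__init__.
        self._before = {t: [] for t in InvocationType}

    def register_before_invoke(self, inv_type, callback):
        self._before[inv_type].append(callback)

    async def execute_before_invoke(self, inv_type, input_data):
        current = input_data
        for cb in self._before[inv_type]:
            result = cb(inv_type, current)
            # Await coroutines; pass plain return values through unchanged.
            if hasattr(result, "__await__"):
                result = await result
            current = result
        return current


def redact(_ctx, data):          # sync callback
    return data.replace("secret", "***")


async def annotate(_ctx, data):  # async callback
    return f"[tool] {data}"


mgr = MiniCallbackManager()
mgr.register_before_invoke(InvocationType.TOOL, redact)
mgr.register_before_invoke(InvocationType.TOOL, annotate)
out = asyncio.run(mgr.execute_before_invoke(InvocationType.TOOL, "secret input"))
print(out)  # [tool] *** input
```

Callbacks run in registration order, each receiving the previous callback's output, so a pipeline of validators and transformers composes naturally.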
InvocationType

Bases: Enum

Types of invocations that can trigger callbacks.

Attributes:

Name Type Description
AI
MCP
TOOL
Source code in pyagenity/utils/callbacks.py
class InvocationType(Enum):
    """Types of invocations that can trigger callbacks."""

    AI = "ai"
    TOOL = "tool"
    MCP = "mcp"
Attributes
AI class-attribute instance-attribute
AI = 'ai'
MCP class-attribute instance-attribute
MCP = 'mcp'
TOOL class-attribute instance-attribute
TOOL = 'tool'
OnErrorCallback

Bases: ABC

Abstract base class for on_error callbacks.

Called when an error occurs during invocation. Allows for error handling and logging.

Methods:

Name Description
__call__

Execute the on_error callback.

Source code in pyagenity/utils/callbacks.py
class OnErrorCallback(ABC):
    """Abstract base class for on_error callbacks.

    Called when an error occurs during invocation.
    Allows for error handling and logging.
    """

    @abstractmethod
    async def __call__(
        self, context: CallbackContext, input_data: Any, error: Exception
    ) -> Any | None:
        """Execute the on_error callback.

        Args:
            context: Context information about the invocation
            input_data: The input data that caused the error
            error: The exception that occurred

        Returns:
            Optional recovery value or None to re-raise the error

        Raises:
            Exception: If error handling fails or if the error should be re-raised
        """
        ...
Functions
__call__ abstractmethod async
__call__(context, input_data, error)

Execute the on_error callback.

Parameters:

Name Type Description Default
context CallbackContext

Context information about the invocation

required
input_data Any

The input data that caused the error

required
error Exception

The exception that occurred

required

Returns:

Type Description
Any | None

Optional recovery value or None to re-raise the error

Raises:

Type Description
Exception

If error handling fails or if the error should be re-raised

Source code in pyagenity/utils/callbacks.py
@abstractmethod
async def __call__(
    self, context: CallbackContext, input_data: Any, error: Exception
) -> Any | None:
    """Execute the on_error callback.

    Args:
        context: Context information about the invocation
        input_data: The input data that caused the error
        error: The exception that occurred

    Returns:
        Optional recovery value or None to re-raise the error

    Raises:
        Exception: If error handling fails or if the error should be re-raised
    """
    ...
Functions
register_after_invoke
register_after_invoke(invocation_type, callback)

Register an after_invoke callback on the global callback manager.

Parameters:

Name Type Description Default
invocation_type InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback AfterInvokeCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_after_invoke(
    invocation_type: InvocationType, callback: AfterInvokeCallbackType
) -> None:
    """
    Register an after_invoke callback on the global callback manager.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (AfterInvokeCallbackType): The callback to register.
    """
    default_callback_manager.register_after_invoke(invocation_type, callback)
register_before_invoke
register_before_invoke(invocation_type, callback)

Register a before_invoke callback on the global callback manager.

Parameters:

Name Type Description Default
invocation_type InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback BeforeInvokeCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_before_invoke(
    invocation_type: InvocationType, callback: BeforeInvokeCallbackType
) -> None:
    """
    Register a before_invoke callback on the global callback manager.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (BeforeInvokeCallbackType): The callback to register.
    """
    default_callback_manager.register_before_invoke(invocation_type, callback)
register_on_error
register_on_error(invocation_type, callback)

Register an on_error callback on the global callback manager.

Parameters:

Name Type Description Default
invocation_type InvocationType

The type of invocation (AI, TOOL, MCP).

required
callback OnErrorCallbackType

The callback to register.

required
Source code in pyagenity/utils/callbacks.py
def register_on_error(invocation_type: InvocationType, callback: OnErrorCallbackType) -> None:
    """
    Register an on_error callback on the global callback manager.

    Args:
        invocation_type (InvocationType): The type of invocation (AI, TOOL, MCP).
        callback (OnErrorCallbackType): The callback to register.
    """
    default_callback_manager.register_on_error(invocation_type, callback)
command

Command API for AgentGraph in PyAgenity.

This module provides the Command class, which allows nodes to combine state updates with control flow, similar to LangGraph's Command API. Nodes can update agent state and direct graph execution to specific nodes or graphs.

Classes:

Name Description
Command

Command object that combines state updates with control flow.

Attributes:

Name Type Description
StateT
Attributes
StateT module-attribute
StateT = TypeVar('StateT', bound='AgentState')
Classes
Command

Command object that combines state updates with control flow.

Allows nodes to update agent state and direct graph execution to specific nodes or graphs. Similar to LangGraph's Command API.

Methods:

Name Description
__init__

Initialize a Command object.

__repr__

Return a string representation of the Command object.

Attributes:

Name Type Description
PARENT
goto
graph
state
update
Source code in pyagenity/utils/command.py
class Command[StateT: AgentState]:
    """
    Command object that combines state updates with control flow.

    Allows nodes to update agent state and direct graph execution to specific nodes or graphs.
    Similar to LangGraph's Command API.
    """

    PARENT = "PARENT"

    def __init__(
        self,
        update: Union["StateT", None, Message, str, "BaseConverter"] = None,
        goto: str | None = None,
        graph: str | None = None,
        state: StateT | None = None,
    ):
        """
        Initialize a Command object.

        Args:
            update (StateT | None | Message | str | BaseConverter): State update to apply.
            goto (str | None): Next node to execute (node name or END).
            graph (str | None): Which graph to navigate to (None for current, PARENT for parent).
            state (StateT | None): Optional agent state to attach.
        """
        self.update = update
        self.goto = goto
        self.graph = graph
        self.state = state

    def __repr__(self) -> str:
        """
        Return a string representation of the Command object.

        Returns:
            str: String representation of the Command.
        """
        return (
            f"Command(update={self.update}, goto={self.goto}, \n"
            f" graph={self.graph}, state={self.state})"
        )
Attributes
PARENT class-attribute instance-attribute
PARENT = 'PARENT'
goto instance-attribute
goto = goto
graph instance-attribute
graph = graph
state instance-attribute
state = state
update instance-attribute
update = update
Functions
__init__
__init__(update=None, goto=None, graph=None, state=None)

Initialize a Command object.

Parameters:

Name Type Description Default
update StateT | None | Message | str | BaseConverter

State update to apply.

None
goto str | None

Next node to execute (node name or END).

None
graph str | None

Which graph to navigate to (None for current, PARENT for parent).

None
state StateT | None

Optional agent state to attach.

None
Source code in pyagenity/utils/command.py
def __init__(
    self,
    update: Union["StateT", None, Message, str, "BaseConverter"] = None,
    goto: str | None = None,
    graph: str | None = None,
    state: StateT | None = None,
):
    """
    Initialize a Command object.

    Args:
        update (StateT | None | Message | str | BaseConverter): State update to apply.
        goto (str | None): Next node to execute (node name or END).
        graph (str | None): Which graph to navigate to (None for current, PARENT for parent).
        state (StateT | None): Optional agent state to attach.
    """
    self.update = update
    self.goto = goto
    self.graph = graph
    self.state = state
__repr__
__repr__()

Return a string representation of the Command object.

Returns:

Name Type Description
str str

String representation of the Command.

Source code in pyagenity/utils/command.py
def __repr__(self) -> str:
    """
    Return a string representation of the Command object.

    Returns:
        str: String representation of the Command.
    """
    return (
        f"Command(update={self.update}, goto={self.goto}, \n"
        f" graph={self.graph}, state={self.state})"
    )
constants

Constants and enums for PyAgenity agent graph execution and messaging.

This module defines special node names, message storage levels, execution states, and response granularity options for agent workflows.

Classes:

Name Description
ExecutionState

Graph execution states for agent workflows.

ResponseGranularity

Response granularity options for agent graph outputs.

StorageLevel

Message storage levels for agent state persistence.

Attributes:

Name Type Description
END Literal['__end__']
START Literal['__start__']
Attributes
END module-attribute
END = '__end__'
START module-attribute
START = '__start__'
Classes
ExecutionState

Bases: StrEnum

Graph execution states for agent workflows.

Values

RUNNING: Execution is in progress.
PAUSED: Execution is paused.
COMPLETED: Execution completed successfully.
ERROR: Execution encountered an error.
INTERRUPTED: Execution was interrupted.
ABORTED: Execution was aborted.
IDLE: Execution is idle.

Attributes:

Name Type Description
ABORTED
COMPLETED
ERROR
IDLE
INTERRUPTED
PAUSED
RUNNING
Source code in pyagenity/utils/constants.py
class ExecutionState(StrEnum):
    """
    Graph execution states for agent workflows.

    Values:
        RUNNING: Execution is in progress.
        PAUSED: Execution is paused.
        COMPLETED: Execution completed successfully.
        ERROR: Execution encountered an error.
        INTERRUPTED: Execution was interrupted.
        ABORTED: Execution was aborted.
        IDLE: Execution is idle.
    """

    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    ERROR = "error"
    INTERRUPTED = "interrupted"
    ABORTED = "aborted"
    IDLE = "idle"
Attributes
ABORTED class-attribute instance-attribute
ABORTED = 'aborted'
COMPLETED class-attribute instance-attribute
COMPLETED = 'completed'
ERROR class-attribute instance-attribute
ERROR = 'error'
IDLE class-attribute instance-attribute
IDLE = 'idle'
INTERRUPTED class-attribute instance-attribute
INTERRUPTED = 'interrupted'
PAUSED class-attribute instance-attribute
PAUSED = 'paused'
RUNNING class-attribute instance-attribute
RUNNING = 'running'
ResponseGranularity

Bases: StrEnum

Response granularity options for agent graph outputs.

Values

FULL: State, latest messages.
PARTIAL: Context, summary, latest messages.
LOW: Only latest messages.

Attributes:

Name Type Description
FULL
LOW
PARTIAL
Source code in pyagenity/utils/constants.py
class ResponseGranularity(StrEnum):
    """
    Response granularity options for agent graph outputs.

    Values:
        FULL: State, latest messages.
        PARTIAL: Context, summary, latest messages.
        LOW: Only latest messages.
    """

    FULL = "full"
    PARTIAL = "partial"
    LOW = "low"
Attributes
FULL class-attribute instance-attribute
FULL = 'full'
LOW class-attribute instance-attribute
LOW = 'low'
PARTIAL class-attribute instance-attribute
PARTIAL = 'partial'
StorageLevel

Message storage levels for agent state persistence.

Attributes:

Name Type Description
ALL

Save everything including tool calls.

MEDIUM

Only AI and human messages.

LOW

Only first human and last AI message.

Source code in pyagenity/utils/constants.py
class StorageLevel:
    """
    Message storage levels for agent state persistence.

    Attributes:
        ALL: Save everything including tool calls.
        MEDIUM: Only AI and human messages.
        LOW: Only first human and last AI message.
    """

    ALL = "all"
    MEDIUM = "medium"
    LOW = "low"
Attributes
ALL class-attribute instance-attribute
ALL = 'all'
LOW class-attribute instance-attribute
LOW = 'low'
MEDIUM class-attribute instance-attribute
MEDIUM = 'medium'
converter

Message conversion utilities for PyAgenity agent graphs.

This module provides helpers to convert Message objects and agent state into dicts suitable for LLM and tool invocation payloads.

Functions:

Name Description
convert_messages

Convert system prompts, agent state, and extra messages to a list of dicts for LLM/tool payloads.

Attributes:

Name Type Description
logger
Attributes
logger module-attribute
logger = getLogger(__name__)
Functions
convert_messages
convert_messages(system_prompts, state=None, extra_messages=None)

Convert system prompts, agent state, and extra messages to a list of dicts for LLM/tool payloads.

Parameters:

Name Type Description Default
system_prompts list[dict[str, Any]]

List of system prompt dicts.

required
state AgentState | None

Optional agent state containing context and summary.

None
extra_messages list[Message] | None

Optional extra messages to include.

None

Returns:

Type Description
list[dict[str, Any]]

list[dict[str, Any]]: List of message dicts for payloads.

Raises:

Type Description
ValueError

If system_prompts is None.

Source code in pyagenity/utils/converter.py
def convert_messages(
    system_prompts: list[dict[str, Any]],
    state: Union["AgentState", None] = None,
    extra_messages: list[Message] | None = None,
) -> list[dict[str, Any]]:
    """
    Convert system prompts, agent state, and extra messages to a list of dicts for
    LLM/tool payloads.

    Args:
        system_prompts (list[dict[str, Any]]): List of system prompt dicts.
        state (AgentState | None): Optional agent state containing context and summary.
        extra_messages (list[Message] | None): Optional extra messages to include.

    Returns:
        list[dict[str, Any]]: List of message dicts for payloads.

    Raises:
        ValueError: If system_prompts is None.
    """
    if system_prompts is None:
        logger.error("System prompts are None")
        raise ValueError("System prompts cannot be None")

    res = []
    res += system_prompts

    if state and state.context_summary:
        summary = {
            "role": "assistant",
            "content": state.context_summary if state.context_summary else "",
        }
        res.append(summary)

    if state and state.context:
        for msg in state.context:
            res.append(_convert_dict(msg))

    if extra_messages:
        for msg in extra_messages:
            res.append(_convert_dict(msg))

    logger.debug("Number of Converted messages: %s", len(res))
    return res
id_generator

ID Generator Module

This module provides various strategies for generating unique identifiers. Each generator implements the BaseIDGenerator interface and specifies the type and size of IDs it produces.

Classes:

Name Description
AsyncIDGenerator

ID generator that produces UUID version 4 strings asynchronously.

BaseIDGenerator

Abstract base class for ID generation strategies.

BigIntIDGenerator

ID generator that produces big integer IDs based on current time in nanoseconds.

DefaultIDGenerator

Default ID generator that returns empty strings.

HexIDGenerator

ID generator that produces hexadecimal strings.

IDType

Enumeration of supported ID types.

IntIDGenerator

ID generator that produces 32-bit random integers.

ShortIDGenerator

ID generator that produces short alphanumeric strings.

TimestampIDGenerator

ID generator that produces integer IDs based on current time in microseconds.

UUIDGenerator

ID generator that produces UUID version 4 strings.

Classes
AsyncIDGenerator

Bases: BaseIDGenerator

ID generator that produces UUID version 4 strings asynchronously.

UUIDs are 128-bit identifiers that are virtually guaranteed to be unique across space and time. The generated strings are 36 characters long (32 hexadecimal digits + 4 hyphens in the format xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx). This generator provides an asynchronous interface for generating UUIDs.

Methods:

Name Description
generate

Asynchronously generate a new UUID4 string.

Attributes:

Name Type Description
id_type IDType

Return the type of ID generated by this generator.

Source code in pyagenity/utils/id_generator.py
class AsyncIDGenerator(BaseIDGenerator):
    """
    ID generator that produces UUID version 4 strings asynchronously.

    UUIDs are 128-bit identifiers that are virtually guaranteed to be unique
    across space and time. The generated strings are 36 characters long
    (32 hexadecimal digits + 4 hyphens in the format xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx).
    This generator provides an asynchronous interface for generating UUIDs.
    """

    @property
    def id_type(self) -> IDType:
        """
        Return the type of ID generated by this generator.

        Returns:
            IDType: The type of ID (STRING).
        """
        return IDType.STRING

    async def generate(self) -> str:
        """
        Asynchronously generate a new UUID4 string.

        Returns:
            str: A 36-character UUID string (e.g., '550e8400-e29b-41d4-a716-446655440000').
        """
        # Simulate async operation (e.g., if fetching from an external service)
        return str(uuid.uuid4())
Attributes
id_type property
id_type

Return the type of ID generated by this generator.

Returns:

Name Type Description
IDType IDType

The type of ID (STRING).

Functions
generate async
generate()

Asynchronously generate a new UUID4 string.

Returns:

Name Type Description
str str

A 36-character UUID string (e.g., '550e8400-e29b-41d4-a716-446655440000').

Source code in pyagenity/utils/id_generator.py
async def generate(self) -> str:
    """
    Asynchronously generate a new UUID4 string.

    Returns:
        str: A 36-character UUID string (e.g., '550e8400-e29b-41d4-a716-446655440000').
    """
    # Simulate async operation (e.g., if fetching from an external service)
    return str(uuid.uuid4())
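Because `generate` is a coroutine here, callers must await it. A minimal usage sketch (the generator is re-declared locally so the snippet runs standalone; in practice you would import it from `pyagenity.utils.id_generator`):

```python
import asyncio
import uuid


class AsyncIDGenerator:
    """Local stand-in mirroring the AsyncIDGenerator shown above."""

    async def generate(self) -> str:
        # Async interface; useful if IDs ever come from an external service.
        return str(uuid.uuid4())


async def make_id() -> str:
    gen = AsyncIDGenerator()
    return await gen.generate()


new_id = asyncio.run(make_id())
print(new_id)  # e.g. '550e8400-e29b-41d4-a716-446655440000'
```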
BaseIDGenerator

Bases: ABC

Abstract base class for ID generation strategies.

All ID generators must implement the id_type property and generate method.

Methods:

Name Description
generate

Generate a new unique ID.

Attributes:

Name Type Description
id_type IDType

Return the type of ID generated by this generator.

Source code in pyagenity/utils/id_generator.py
class BaseIDGenerator(ABC):
    """Abstract base class for ID generation strategies.

    All ID generators must implement the id_type property and generate method.
    """

    @property
    @abstractmethod
    def id_type(self) -> IDType:
        """Return the type of ID generated by this generator.

        Returns:
            IDType: The type of ID (STRING, INTEGER, or BIGINT).
        """
        raise NotImplementedError("id_type method must be implemented")

    @abstractmethod
    def generate(self) -> str | int | Awaitable[str | int]:
        """Generate a new unique ID.

        Returns:
            str | int: A new unique identifier of the appropriate type.
        """
        raise NotImplementedError("generate method must be implemented")
Attributes
id_type abstractmethod property
id_type

Return the type of ID generated by this generator.

Returns:

Name Type Description
IDType IDType

The type of ID (STRING, INTEGER, or BIGINT).

Functions
generate abstractmethod
generate()

Generate a new unique ID.

Returns:

Type Description
str | int | Awaitable[str | int]

str | int: A new unique identifier of the appropriate type.

Source code in pyagenity/utils/id_generator.py
@abstractmethod
def generate(self) -> str | int | Awaitable[str | int]:
    """Generate a new unique ID.

    Returns:
        str | int: A new unique identifier of the appropriate type.
    """
    raise NotImplementedError("generate method must be implemented")
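A custom strategy only needs to satisfy this contract. The sketch below re-declares a minimal version of the base class so it runs standalone; `PrefixedIDGenerator` is a hypothetical example, not part of PyAgenity:

```python
import enum
import secrets
from abc import ABC, abstractmethod


class IDType(enum.Enum):  # local stand-in for pyagenity's IDType
    STRING = "string"
    INTEGER = "integer"
    BIGINT = "bigint"


class BaseIDGenerator(ABC):  # mirrors the abstract contract shown above
    @property
    @abstractmethod
    def id_type(self) -> IDType: ...

    @abstractmethod
    def generate(self): ...


class PrefixedIDGenerator(BaseIDGenerator):
    """Hypothetical custom generator: 'msg-' plus 6 random hex characters."""

    @property
    def id_type(self) -> IDType:
        return IDType.STRING

    def generate(self) -> str:
        # secrets.token_hex(3) yields 3 random bytes as 6 hex digits.
        return "msg-" + secrets.token_hex(3)


gen = PrefixedIDGenerator()
sample = gen.generate()
print(sample)  # e.g. 'msg-a1b2c3'
```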
BigIntIDGenerator

Bases: BaseIDGenerator

ID generator that produces big integer IDs based on current time in nanoseconds.

Generates IDs by multiplying current Unix timestamp by 1e9, resulting in large integers that are sortable by creation time. Typical size is 19-20 digits.

Methods:

Name Description
generate

Generate a new big integer ID based on current nanoseconds.

Attributes:

Name Type Description
id_type IDType
Source code in pyagenity/utils/id_generator.py
class BigIntIDGenerator(BaseIDGenerator):
    """ID generator that produces big integer IDs based on current time in nanoseconds.

    Generates IDs by multiplying current Unix timestamp by 1e9, resulting in
    large integers that are sortable by creation time. Typical size is 19-20 digits.
    """

    @property
    def id_type(self) -> IDType:
        return IDType.BIGINT

    def generate(self) -> int:
        """Generate a new big integer ID based on current nanoseconds.

        Returns:
            int: A large integer (19-20 digits) representing nanoseconds since Unix epoch.
        """
        # Use current time in nanoseconds for higher uniqueness
        return int(time.time() * 1_000_000_000)
Attributes
id_type property
id_type
Functions
generate
generate()

Generate a new big integer ID based on current nanoseconds.

Returns:

Name Type Description
int int

A large integer (19-20 digits) representing nanoseconds since Unix epoch.

Source code in pyagenity/utils/id_generator.py
def generate(self) -> int:
    """Generate a new big integer ID based on current nanoseconds.

    Returns:
        int: A large integer (19-20 digits) representing nanoseconds since Unix epoch.
    """
    # Use current time in nanoseconds for higher uniqueness
    return int(time.time() * 1_000_000_000)
DefaultIDGenerator

Bases: BaseIDGenerator

Default ID generator that returns empty strings.

This generator is a placeholder that defers to framework defaults. It currently returns empty strings; when a generator returns an empty string, the framework substitutes its default ID strategy, which is UUID4-based.

Methods:

Name Description
generate

Generate a default ID (currently empty string).

Attributes:

Name Type Description
id_type IDType
Source code in pyagenity/utils/id_generator.py
class DefaultIDGenerator(BaseIDGenerator):
    """Default ID generator that returns empty strings.

    This generator is intended as a placeholder that can be configured
    to use framework defaults (typically UUID-based). Currently returns
    empty strings. If empty string is returned, the framework will use its default
    UUID-based generator. If the framework is not configured to use
    UUID generation, it will fall back to UUID4.
    """

    @property
    def id_type(self) -> IDType:
        return IDType.STRING

    def generate(self) -> str:
        """Generate a default ID (currently empty string).

        If empty string is returned, the framework will use its default
        UUID-based generator. If the framework is not configured to use
        UUID generation, it will fall back to UUID4.

        Returns:
            str: An empty string (framework will substitute with UUID).
        """
        # if you keep empty, then it will be used default
        # framework default which is UUID based
        # if framework not using then uuid 4 will be used
        return ""
Attributes
id_type property
id_type
Functions
generate
generate()

Generate a default ID (currently empty string).

When an empty string is returned, the framework substitutes its default ID strategy, which is UUID4-based.

Returns:

Name Type Description
str str

An empty string (framework will substitute with UUID).

Source code in pyagenity/utils/id_generator.py
def generate(self) -> str:
    """Generate a default ID (currently empty string).

    If empty string is returned, the framework will use its default
    UUID-based generator. If the framework is not configured to use
    UUID generation, it will fall back to UUID4.

    Returns:
        str: An empty string (framework will substitute with UUID).
    """
    # if you keep empty, then it will be used default
    # framework default which is UUID based
    # if framework not using then uuid 4 will be used
    return ""
HexIDGenerator

Bases: BaseIDGenerator

ID generator that produces hexadecimal strings.

Generates cryptographically secure random hex strings of 32 characters (representing 16 random bytes). Each character is a hexadecimal digit (0-9, a-f).

Methods:

Name Description
generate

Generate a new 32-character hexadecimal string.

Attributes:

Name Type Description
id_type IDType
Source code in pyagenity/utils/id_generator.py
class HexIDGenerator(BaseIDGenerator):
    """ID generator that produces hexadecimal strings.

    Generates cryptographically secure random hex strings of 32 characters
    (representing 16 random bytes). Each character is a hexadecimal digit (0-9, a-f).
    """

    @property
    def id_type(self) -> IDType:
        return IDType.STRING

    def generate(self) -> str:
        """Generate a new 32-character hexadecimal string.

        Returns:
            str: A 32-character hex string (e.g., '1a2b3c4d5e6f7890abcdef1234567890').
        """
        return secrets.token_hex(16)
Attributes
id_type property
id_type
Functions
generate
generate()

Generate a new 32-character hexadecimal string.

Returns:

Name Type Description
str str

A 32-character hex string (e.g., '1a2b3c4d5e6f7890abcdef1234567890').

Source code in pyagenity/utils/id_generator.py
def generate(self) -> str:
    """Generate a new 32-character hexadecimal string.

    Returns:
        str: A 32-character hex string (e.g., '1a2b3c4d5e6f7890abcdef1234567890').
    """
    return secrets.token_hex(16)
IDType

Bases: StrEnum

Enumeration of supported ID types.

Attributes:

Name Type Description
BIGINT
INTEGER
STRING
Source code in pyagenity/utils/id_generator.py
class IDType(enum.StrEnum):
    """Enumeration of supported ID types."""

    STRING = "string"  # String-based IDs
    INTEGER = "integer"  # Integer-based IDs
    BIGINT = "bigint"  # Big integer IDs
Attributes
BIGINT class-attribute instance-attribute
BIGINT = 'bigint'
INTEGER class-attribute instance-attribute
INTEGER = 'integer'
STRING class-attribute instance-attribute
STRING = 'string'
IntIDGenerator

Bases: BaseIDGenerator

ID generator that produces 32-bit random integers.

Generates cryptographically secure random integers using secrets.randbits(32). Values range from 0 to 2^32 - 1 (4,294,967,295).

Methods:

Name Description
generate

Generate a new 32-bit random integer.

Attributes:

Name Type Description
id_type IDType
Source code in pyagenity/utils/id_generator.py
class IntIDGenerator(BaseIDGenerator):
    """ID generator that produces 32-bit random integers.

    Generates cryptographically secure random integers using secrets.randbits(32).
    Values range from 0 to 2^32 - 1 (4,294,967,295).
    """

    @property
    def id_type(self) -> IDType:
        return IDType.INTEGER

    def generate(self) -> int:
        """Generate a new 32-bit random integer.

        Returns:
            int: A random integer between 0 and 4,294,967,295 (inclusive).
        """
        return secrets.randbits(32)
Attributes
id_type property
id_type
Functions
generate
generate()

Generate a new 32-bit random integer.

Returns:

Name Type Description
int int

A random integer between 0 and 4,294,967,295 (inclusive).

Source code in pyagenity/utils/id_generator.py
def generate(self) -> int:
    """Generate a new 32-bit random integer.

    Returns:
        int: A random integer between 0 and 4,294,967,295 (inclusive).
    """
    return secrets.randbits(32)
ShortIDGenerator

Bases: BaseIDGenerator

ID generator that produces short alphanumeric strings.

Generates 8-character strings using uppercase/lowercase letters and digits. Each character is randomly chosen from 62 possible characters (26 + 26 + 10). Total possible combinations: 62^8 ≈ 2.18 x 10^14.

Methods:

Name Description
generate

Generate a new 8-character alphanumeric string.

Attributes:

Name Type Description
id_type IDType
Source code in pyagenity/utils/id_generator.py
class ShortIDGenerator(BaseIDGenerator):
    """ID generator that produces short alphanumeric strings.

    Generates 8-character strings using uppercase/lowercase letters and digits.
    Each character is randomly chosen from 62 possible characters (26 + 26 + 10).
    Total possible combinations: 62^8 ≈ 2.18 x 10^14.
    """

    @property
    def id_type(self) -> IDType:
        return IDType.STRING

    def generate(self) -> str:
        """Generate a new 8-character alphanumeric string.

        Returns:
            str: An 8-character string containing letters and digits
                 (e.g., 'Ab3XyZ9k').
        """
        alphabet = string.ascii_letters + string.digits
        return "".join(secrets.choice(alphabet) for _ in range(8))
Attributes
id_type property
id_type
Functions
generate
generate()

Generate a new 8-character alphanumeric string.

Returns:

Name Type Description
str str

An 8-character string containing letters and digits (e.g., 'Ab3XyZ9k').

Source code in pyagenity/utils/id_generator.py
def generate(self) -> str:
    """Generate a new 8-character alphanumeric string.

    Returns:
        str: An 8-character string containing letters and digits
             (e.g., 'Ab3XyZ9k').
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(8))
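As a back-of-the-envelope check on the 62^8 figure, the birthday approximation estimates the collision risk when generating many short IDs (a sketch, not part of the library):

```python
import math

# The short-ID space: 8 positions, 62 symbols each.
space = 62 ** 8
print(space)  # 218340105584896, i.e. ~2.18e14

# Birthday approximation: the probability of at least one collision among
# n uniformly random IDs is roughly 1 - exp(-n*(n-1) / (2*space)).
n = 1_000_000
p_collision = 1 - math.exp(-n * (n - 1) / (2 * space))
print(f"{p_collision:.4f}")  # ~0.0023, i.e. ~0.2% for a million IDs
```

So short IDs are fine for modest volumes, but high-volume systems should prefer UUIDs or hex IDs from the larger generators above.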
TimestampIDGenerator

Bases: BaseIDGenerator

ID generator that produces integer IDs based on current time in microseconds.

Generates IDs by multiplying current Unix timestamp by 1e6, resulting in integers that are sortable by creation time. Typical size is 16-17 digits.

Methods:

Name Description
generate

Generate a new integer ID based on current microseconds.

Attributes:

Name Type Description
id_type IDType
Source code in pyagenity/utils/id_generator.py
class TimestampIDGenerator(BaseIDGenerator):
    """ID generator that produces integer IDs based on current time in microseconds.

    Generates IDs by multiplying current Unix timestamp by 1e6, resulting in
    integers that are sortable by creation time. Typical size is 16-17 digits.
    """

    @property
    def id_type(self) -> IDType:
        return IDType.INTEGER

    def generate(self) -> int:
        """Generate a new integer ID based on current microseconds.

        Returns:
            int: An integer (16-17 digits) representing microseconds since Unix epoch.
        """
        return int(time.time() * 1000000)
Attributes
id_type property
id_type
Functions
generate
generate()

Generate a new integer ID based on current microseconds.

Returns:

Name Type Description
int int

An integer (16-17 digits) representing microseconds since Unix epoch.

Source code in pyagenity/utils/id_generator.py
def generate(self) -> int:
    """Generate a new integer ID based on current microseconds.

    Returns:
        int: An integer (16-17 digits) representing microseconds since Unix epoch.
    """
    return int(time.time() * 1000000)
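The behavior can be illustrated with the stdlib alone. Note the trade-off: time-based IDs sort by creation time, but two calls within the same microsecond produce identical values, so they are not collision-free:

```python
import time


def timestamp_id() -> int:
    # Mirrors TimestampIDGenerator.generate(): microseconds since the epoch.
    return int(time.time() * 1_000_000)


ids = [timestamp_id() for _ in range(5)]
# IDs are non-decreasing, so they sort by creation time...
assert ids == sorted(ids)
# ...but calls within the same microsecond tick yield the SAME value,
# trading uniqueness guarantees for sortability.
print(ids[0])  # 16-17 digit integer, e.g. 1700000000000000
```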
UUIDGenerator

Bases: BaseIDGenerator

ID generator that produces UUID version 4 strings.

UUIDs are 128-bit identifiers that are virtually guaranteed to be unique across space and time. The generated strings are 36 characters long (32 hexadecimal digits + 4 hyphens in the format xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx).

Methods:

Name Description
generate

Generate a new UUID4 string.

Attributes:

Name Type Description
id_type IDType
Source code in pyagenity/utils/id_generator.py
class UUIDGenerator(BaseIDGenerator):
    """ID generator that produces UUID version 4 strings.

    UUIDs are 128-bit identifiers that are virtually guaranteed to be unique
    across space and time. The generated strings are 36 characters long
    (32 hexadecimal digits + 4 hyphens in the format xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx).
    """

    @property
    def id_type(self) -> IDType:
        return IDType.STRING

    def generate(self) -> str:
        """Generate a new UUID4 string.

        Returns:
            str: A 36-character UUID string (e.g., '550e8400-e29b-41d4-a716-446655440000').
        """
        return str(uuid.uuid4())
Attributes
id_type property
id_type
Functions
generate
generate()

Generate a new UUID4 string.

Returns:

Name Type Description
str str

A 36-character UUID string (e.g., '550e8400-e29b-41d4-a716-446655440000').

Source code in pyagenity/utils/id_generator.py
def generate(self) -> str:
    """Generate a new UUID4 string.

    Returns:
        str: A 36-character UUID string (e.g., '550e8400-e29b-41d4-a716-446655440000').
    """
    return str(uuid.uuid4())
logging

Centralized logging configuration for PyAgenity.

This module provides logging configuration that can be imported and used throughout the project. Each module should use:

import logging
logger = logging.getLogger(__name__)

This ensures proper hierarchical logging with module-specific loggers.

Typical usage example

from pyagenity.utils.logging import configure_logging
configure_logging(level=logging.DEBUG)

Functions:

Name Description
configure_logging

Configures the root logger for the PyAgenity project.

Functions
configure_logging
configure_logging(level=logging.INFO, format_string=None, handler=None)

Configures the root logger for the PyAgenity project.

This function sets up logging for all modules under the 'pyagenity' namespace. It ensures that logs are formatted consistently and sent to the appropriate handler.

Parameters:

Name Type Description Default
level int

Logging level (e.g., logging.INFO, logging.DEBUG). Defaults to logging.INFO.

INFO
format_string str

Custom format string for log messages. If None, uses a default format: "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s".

None
handler Handler

Custom logging handler. If None, uses StreamHandler to stdout.

None

Returns:

Type Description
None

None

Example

configure_logging(level=logging.DEBUG)
logger = logging.getLogger("pyagenity.module")
logger.info("This is an info message.")

Source code in pyagenity/utils/logging.py
def configure_logging(
    level: int = logging.INFO,
    format_string: str | None = None,
    handler: logging.Handler | None = None,
) -> None:
    """
    Configures the root logger for the PyAgenity project.

    This function sets up logging for all modules under the 'pyagenity' namespace.
    It ensures that logs are formatted consistently and sent to the appropriate handler.

    Args:
        level (int, optional): Logging level (e.g., logging.INFO, logging.DEBUG).
            Defaults to logging.INFO.
        format_string (str, optional): Custom format string for log messages.
            If None, uses a default format: "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s".
        handler (logging.Handler, optional): Custom logging handler. If None,
            uses StreamHandler to stdout.

    Returns:
        None

    Raises:
        None

    Example:
        >>> configure_logging(level=logging.DEBUG)
        >>> logger = logging.getLogger("pyagenity.module")
        >>> logger.info("This is an info message.")
    """
    if format_string is None:
        format_string = "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s"

    if handler is None:
        handler = logging.StreamHandler(sys.stdout)

    formatter = logging.Formatter(format_string)
    handler.setFormatter(formatter)

    # Configure root logger for pyagenity
    root_logger = logging.getLogger("pyagenity")
    root_logger.setLevel(level)

    # Only add handler if none exists to avoid duplicates
    if not root_logger.handlers:
        root_logger.addHandler(handler)

    # Prevent propagation to avoid duplicate logs
    root_logger.propagate = False
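For reference, the effect of configure_logging can be reproduced with the stdlib alone; this is a sketch of what the function does, not a call into pyagenity:

```python
import logging
import sys

# Equivalent stdlib setup to configure_logging(level=logging.DEBUG):
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
    logging.Formatter("[%(asctime)s] %(levelname)-8s %(name)s: %(message)s")
)

root_logger = logging.getLogger("pyagenity")
root_logger.setLevel(logging.DEBUG)
if not root_logger.handlers:  # avoid duplicate handlers on re-configuration
    root_logger.addHandler(handler)
root_logger.propagate = False  # keep pyagenity logs out of the root logger

# Child loggers inherit the configuration hierarchically.
child = logging.getLogger("pyagenity.graph")
child.debug("debug message from a module logger")
```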
message

Message and content block primitives for agent graphs.

This module defines the core message representation, multimodal content blocks, token usage tracking, and utility functions for agent graph communication.

Classes:

Name Description
TokenUsages

Tracks token usage statistics for a message or model response.

MediaRef

Reference to media content (image/audio/video/document/data).

AnnotationRef

Reference to annotation metadata.

Message

Represents a message in a conversation, including content, role, metadata, and token usage.

Functions:

Name Description
generate_id

Generates a message or tool call ID based on DI context and type.

Attributes:

Name Type Description
ContentBlock
logger
Attributes
ContentBlock module-attribute
ContentBlock = Annotated[Union[TextBlock, ImageBlock, AudioBlock, VideoBlock, DocumentBlock, DataBlock, ToolCallBlock, ToolResultBlock, ReasoningBlock, AnnotationBlock, ErrorBlock], Field(discriminator='type')]
logger module-attribute
logger = getLogger(__name__)
Classes
AnnotationBlock

Bases: BaseModel

Annotation content block for messages.

Attributes:

Name Type Description
type Literal['annotation']

Block type discriminator.

kind Literal['citation', 'note']

Kind of annotation.

refs list[AnnotationRef]

List of annotation references.

spans list[tuple[int, int]] | None

Spans covered by the annotation.

Source code in pyagenity/utils/message.py
class AnnotationBlock(BaseModel):
    """
    Annotation content block for messages.

    Attributes:
        type (Literal["annotation"]): Block type discriminator.
        kind (Literal["citation", "note"]): Kind of annotation.
        refs (list[AnnotationRef]): List of annotation references.
        spans (list[tuple[int, int]] | None): Spans covered by the annotation.
    """

    type: Literal["annotation"] = "annotation"
    kind: Literal["citation", "note"] = "citation"
    refs: list[AnnotationRef] = Field(default_factory=list)
    spans: list[tuple[int, int]] | None = None
Attributes
kind class-attribute instance-attribute
kind = 'citation'
refs class-attribute instance-attribute
refs = Field(default_factory=list)
spans class-attribute instance-attribute
spans = None
type class-attribute instance-attribute
type = 'annotation'
AnnotationRef

Bases: BaseModel

Reference to annotation metadata (e.g., citation, note).

Attributes:

Name Type Description
url str | None

URL to annotation source.

file_id str | None

Provider-managed file ID.

page int | None

Page number (if applicable).

index int | None

Index within the annotation source.

title str | None

Title of the annotation.

Source code in pyagenity/utils/message.py
class AnnotationRef(BaseModel):
    """
    Reference to annotation metadata (e.g., citation, note).

    Attributes:
        url (str | None): URL to annotation source.
        file_id (str | None): Provider-managed file ID.
        page (int | None): Page number (if applicable).
        index (int | None): Index within the annotation source.
        title (str | None): Title of the annotation.
    """

    url: str | None = None
    file_id: str | None = None
    page: int | None = None
    index: int | None = None
    title: str | None = None
Attributes
file_id class-attribute instance-attribute
file_id = None
index class-attribute instance-attribute
index = None
page class-attribute instance-attribute
page = None
title class-attribute instance-attribute
title = None
url class-attribute instance-attribute
url = None
AudioBlock

Bases: BaseModel

Audio content block for messages.

Attributes:

Name Type Description
type Literal['audio']

Block type discriminator.

media MediaRef

Reference to audio media.

transcript str | None

Transcript of audio.

sample_rate int | None

Sample rate in Hz.

channels int | None

Number of audio channels.

Source code in pyagenity/utils/message.py
class AudioBlock(BaseModel):
    """
    Audio content block for messages.

    Attributes:
        type (Literal["audio"]): Block type discriminator.
        media (MediaRef): Reference to audio media.
        transcript (str | None): Transcript of audio.
        sample_rate (int | None): Sample rate in Hz.
        channels (int | None): Number of audio channels.
    """

    type: Literal["audio"] = "audio"
    media: MediaRef
    transcript: str | None = None
    sample_rate: int | None = None
    channels: int | None = None
Attributes
channels class-attribute instance-attribute
channels = None
media instance-attribute
media
sample_rate class-attribute instance-attribute
sample_rate = None
transcript class-attribute instance-attribute
transcript = None
type class-attribute instance-attribute
type = 'audio'
DataBlock

Bases: BaseModel

Data content block for messages.

Attributes:

Name Type Description
type Literal['data']

Block type discriminator.

mime_type str

MIME type of the data.

data_base64 str | None

Base64-encoded data.

media MediaRef | None

Reference to associated media.

Source code in pyagenity/utils/message.py
class DataBlock(BaseModel):
    """
    Data content block for messages.

    Attributes:
        type (Literal["data"]): Block type discriminator.
        mime_type (str): MIME type of the data.
        data_base64 (str | None): Base64-encoded data.
        media (MediaRef | None): Reference to associated media.
    """

    type: Literal["data"] = "data"
    mime_type: str
    data_base64: str | None = None
    media: MediaRef | None = None
Attributes
data_base64 class-attribute instance-attribute
data_base64 = None
media class-attribute instance-attribute
media = None
mime_type instance-attribute
mime_type
type class-attribute instance-attribute
type = 'data'
DocumentBlock

Bases: BaseModel

Document content block for messages.

Attributes:

Name Type Description
type Literal['document']

Block type discriminator.

media MediaRef

Reference to document media.

pages list[int] | None

List of page numbers.

excerpt str | None

Excerpt from the document.

Source code in pyagenity/utils/message.py
class DocumentBlock(BaseModel):
    """
    Document content block for messages.

    Attributes:
        type (Literal["document"]): Block type discriminator.
        media (MediaRef): Reference to document media.
        pages (list[int] | None): List of page numbers.
        excerpt (str | None): Excerpt from the document.
    """

    type: Literal["document"] = "document"
    media: MediaRef
    pages: list[int] | None = None
    excerpt: str | None = None
Attributes
excerpt class-attribute instance-attribute
excerpt = None
media instance-attribute
media
pages class-attribute instance-attribute
pages = None
type class-attribute instance-attribute
type = 'document'
ErrorBlock

Bases: BaseModel

Error content block for messages.

Attributes:

Name Type Description
type Literal['error']

Block type discriminator.

message str

Error message.

code str | None

Error code.

data dict[str, Any] | None

Additional error data.

Source code in pyagenity/utils/message.py
class ErrorBlock(BaseModel):
    """
    Error content block for messages.

    Attributes:
        type (Literal["error"]): Block type discriminator.
        message (str): Error message.
        code (str | None): Error code.
        data (dict[str, Any] | None): Additional error data.
    """

    type: Literal["error"] = "error"
    message: str
    code: str | None = None
    data: dict[str, Any] | None = None
Attributes
code class-attribute instance-attribute
code = None
data class-attribute instance-attribute
data = None
message instance-attribute
message
type class-attribute instance-attribute
type = 'error'
ImageBlock

Bases: BaseModel

Image content block for messages.

Attributes:

Name Type Description
type Literal['image']

Block type discriminator.

media MediaRef

Reference to image media.

alt_text str | None

Alternative text for accessibility.

bbox list[float] | None

Bounding box coordinates [x1, y1, x2, y2].

Source code in pyagenity/utils/message.py
class ImageBlock(BaseModel):
    """
    Image content block for messages.

    Attributes:
        type (Literal["image"]): Block type discriminator.
        media (MediaRef): Reference to image media.
        alt_text (str | None): Alternative text for accessibility.
        bbox (list[float] | None): Bounding box coordinates [x1, y1, x2, y2].
    """

    type: Literal["image"] = "image"
    media: MediaRef
    alt_text: str | None = None
    bbox: list[float] | None = None  # [x1,y1,x2,y2] if applicable
Attributes
alt_text class-attribute instance-attribute
alt_text = None
bbox class-attribute instance-attribute
bbox = None
media instance-attribute
media
type class-attribute instance-attribute
type = 'image'
MediaRef

Bases: BaseModel

Reference to media content (image/audio/video/document/data).

Prefer referencing by URL or provider file_id over inlining base64 for large payloads.

Attributes:

Name Type Description
kind Literal['url', 'file_id', 'data']

Type of reference.

url str | None

URL to media content.

file_id str | None

Provider-managed file ID.

data_base64 str | None

Base64-encoded data (small payloads only).

mime_type str | None

MIME type of the media.

size_bytes int | None

Size in bytes.

sha256 str | None

SHA256 hash of the media.

filename str | None

Filename of the media.

width int | None

Image width (if applicable).

height int | None

Image height (if applicable).

duration_ms int | None

Duration in milliseconds (if applicable).

page int | None

Page number (if applicable).

Source code in pyagenity/utils/message.py
class MediaRef(BaseModel):
    """
    Reference to media content (image/audio/video/document/data).

    Prefer referencing by URL or provider file_id over inlining base64 for large payloads.

    Attributes:
        kind (Literal["url", "file_id", "data"]): Type of reference.
        url (str | None): URL to media content.
        file_id (str | None): Provider-managed file ID.
        data_base64 (str | None): Base64-encoded data (small payloads only).
        mime_type (str | None): MIME type of the media.
        size_bytes (int | None): Size in bytes.
        sha256 (str | None): SHA256 hash of the media.
        filename (str | None): Filename of the media.
        width (int | None): Image width (if applicable).
        height (int | None): Image height (if applicable).
        duration_ms (int | None): Duration in milliseconds (if applicable).
        page (int | None): Page number (if applicable).
    """

    kind: Literal["url", "file_id", "data"] = "url"
    url: str | None = None  # http(s) or data: URL
    file_id: str | None = None  # provider-managed ID (e.g., OpenAI/Gemini)
    data_base64: str | None = None  # small payloads only
    mime_type: str | None = None
    size_bytes: int | None = None
    sha256: str | None = None
    filename: str | None = None
    # Media-specific hints
    width: int | None = None
    height: int | None = None
    duration_ms: int | None = None
    page: int | None = None
Attributes
data_base64 class-attribute instance-attribute
data_base64 = None
duration_ms class-attribute instance-attribute
duration_ms = None
file_id class-attribute instance-attribute
file_id = None
filename class-attribute instance-attribute
filename = None
height class-attribute instance-attribute
height = None
kind class-attribute instance-attribute
kind = 'url'
mime_type class-attribute instance-attribute
mime_type = None
page class-attribute instance-attribute
page = None
sha256 class-attribute instance-attribute
sha256 = None
size_bytes class-attribute instance-attribute
size_bytes = None
url class-attribute instance-attribute
url = None
width class-attribute instance-attribute
width = None
Message

Bases: BaseModel

Represents a message in a conversation, including content, role, metadata, and token usage.

Attributes:

Name Type Description
message_id str | int

Unique identifier for the message.

role Literal['user', 'assistant', 'system', 'tool']

The role of the message sender.

content list[ContentBlock]

The message content blocks.

delta bool

Indicates if this is a delta/partial message.

tools_calls list[dict[str, Any]] | None

Tool call information, if any.

reasoning str | None

Reasoning or explanation, if any.

timestamp datetime | None

Timestamp of the message.

metadata dict[str, Any]

Additional metadata.

usages TokenUsages | None

Token usage statistics.

raw dict[str, Any] | None

Raw data, if any.

Example

msg = Message(message_id="abc123", role="user", content=[TextBlock(text="Hello!")])

Methods:

Name Description
attach_media

Append a media block to the content.

text

Best-effort text extraction from content blocks.

text_message

Create a Message instance from plain text.

tool_message

Create a tool message from content blocks; error results are flagged via ToolResultBlock.is_error.

Source code in pyagenity/utils/message.py
class Message(BaseModel):
    """
    Represents a message in a conversation, including content, role, metadata, and token usage.

    Attributes:
        message_id (str | int): Unique identifier for the message.
        role (Literal["user", "assistant", "system", "tool"]): The role of the message sender.
        content (list[ContentBlock]): The message content blocks.
        delta (bool): Indicates if this is a delta/partial message.
        tools_calls (list[dict[str, Any]] | None): Tool call information, if any.
        reasoning (str | None): Reasoning or explanation, if any.
        timestamp (datetime | None): Timestamp of the message.
        metadata (dict[str, Any]): Additional metadata.
        usages (TokenUsages | None): Token usage statistics.
        raw (dict[str, Any] | None): Raw data, if any.

    Example:
        >>> msg = Message(message_id="abc123", role="user", content=[TextBlock(text="Hello!")])
    """

    message_id: str | int = Field(default_factory=lambda: generate_id(None))
    role: Literal["user", "assistant", "system", "tool"]
    content: list[ContentBlock]
    delta: bool = False  # Indicates if this is a delta/partial message
    tools_calls: list[dict[str, Any]] | None = None
    reasoning: str | None = None  # Remove it
    timestamp: datetime | None = Field(default_factory=datetime.now)
    metadata: dict[str, Any] = Field(default_factory=dict)
    usages: TokenUsages | None = None
    raw: dict[str, Any] | None = None

    @classmethod
    def text_message(
        cls,
        content: str,
        role: Literal["user", "assistant", "system", "tool"] = "user",
        message_id: str | None = None,
    ) -> "Message":
        """
        Create a Message instance from plain text.

        Args:
            content (str): The message content.
            role (Literal["user", "assistant", "system", "tool"]): The role of the sender.
            message_id (str | None): Optional message ID.

        Returns:
            Message: The created Message instance.

        Example:
            >>> Message.text_message("Hello!", role="user")
        """
        logger.debug("Creating message from text with role: %s", role)
        return cls(
            message_id=generate_id(message_id),
            role=role,
            content=[TextBlock(text=content)],
            timestamp=datetime.now(),
            metadata={},
        )

    @classmethod
    def tool_message(
        cls,
        content: list[ContentBlock],
        message_id: str | None = None,
        meta: dict[str, Any] | None = None,
    ) -> "Message":
        """
        Create a tool message, optionally marking it as an error.

        Args:
            content (list[ContentBlock]): The message content blocks.
            message_id (str | None): Optional message ID.
            meta (dict[str, Any] | None): Optional metadata.

        Returns:
            Message: The created tool message instance.

        Example:
            >>> Message.tool_message([ToolResultBlock(...)], message_id="tool1")
        """
        res = content
        msg_id = generate_id(message_id)
        return cls(
            message_id=msg_id,
            role="tool",
            content=res,
            timestamp=datetime.now(),
            metadata=meta or {},
        )

    # --- Convenience helpers ---
    def text(self) -> str:
        """
        Best-effort text extraction from content blocks.

        Returns:
            str: Concatenated text from TextBlock and ToolResultBlock outputs.

        Example:
            >>> msg.text()
            'Hello!Result text.'
        """
        parts: list[str] = []
        for block in self.content:
            if isinstance(block, TextBlock):
                parts.append(block.text)
            elif isinstance(block, ToolResultBlock) and isinstance(block.output, str):
                parts.append(block.output)
        return "".join(parts)

    def attach_media(
        self,
        media: MediaRef,
        as_type: Literal["image", "audio", "video", "document"],
    ) -> None:
        """
        Append a media block to the content.

        If content was text, creates a block list. Supports image, audio, video, and document types.

        Args:
            media (MediaRef): Reference to media content.
            as_type (Literal["image", "audio", "video", "document"]): Type of media block to append.

        Returns:
            None

        Raises:
            ValueError: If an unsupported media type is provided.

        Example:
            >>> msg.attach_media(media_ref, as_type="image")
        """
        block: ContentBlock
        if as_type == "image":
            block = ImageBlock(media=media)
        elif as_type == "audio":
            block = AudioBlock(media=media)
        elif as_type == "video":
            block = VideoBlock(media=media)
        elif as_type == "document":
            block = DocumentBlock(media=media)
        else:
            raise ValueError(f"Unsupported media type: {as_type}")

        if isinstance(self.content, str):
            self.content = [TextBlock(text=self.content), block]
        elif isinstance(self.content, list):
            self.content.append(block)
        else:
            self.content = [block]
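The best-effort extraction performed by `Message.text()` reduces to one rule: keep `TextBlock` text and string-valued `ToolResultBlock` outputs, skip everything else. A standalone sketch of that rule, using hypothetical dataclass stand-ins for the real pydantic blocks:

```python
from dataclasses import dataclass


@dataclass
class TextBlockSketch:
    # Hypothetical stand-in for TextBlock.
    text: str


@dataclass
class ToolResultBlockSketch:
    # Hypothetical stand-in for ToolResultBlock.
    call_id: str
    output: object = None


def extract_text(blocks: list) -> str:
    """Mirror Message.text(): concatenate text blocks and string tool outputs."""
    parts: list[str] = []
    for block in blocks:
        if isinstance(block, TextBlockSketch):
            parts.append(block.text)
        elif isinstance(block, ToolResultBlockSketch) and isinstance(block.output, str):
            parts.append(block.output)
    return "".join(parts)
```

Note the join uses no separator, matching the library's behavior, so adjacent blocks run together unless they carry their own whitespace.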
Attributes
content instance-attribute
content
delta class-attribute instance-attribute
delta = False
message_id class-attribute instance-attribute
message_id = Field(default_factory=lambda: generate_id(None))
metadata class-attribute instance-attribute
metadata = Field(default_factory=dict)
raw class-attribute instance-attribute
raw = None
reasoning class-attribute instance-attribute
reasoning = None
role instance-attribute
role
timestamp class-attribute instance-attribute
timestamp = Field(default_factory=now)
tools_calls class-attribute instance-attribute
tools_calls = None
usages class-attribute instance-attribute
usages = None
Functions
attach_media
attach_media(media, as_type)

Append a media block to the content.

If content was text, creates a block list. Supports image, audio, video, and document types.

Parameters:

Name Type Description Default
media MediaRef

Reference to media content.

required
as_type Literal['image', 'audio', 'video', 'document']

Type of media block to append.

required

Returns:

Type Description
None

None

Raises:

Type Description
ValueError

If an unsupported media type is provided.

Example

msg.attach_media(media_ref, as_type="image")

Source code in pyagenity/utils/message.py
def attach_media(
    self,
    media: MediaRef,
    as_type: Literal["image", "audio", "video", "document"],
) -> None:
    """
    Append a media block to the content.

    If content was text, creates a block list. Supports image, audio, video, and document types.

    Args:
        media (MediaRef): Reference to media content.
        as_type (Literal["image", "audio", "video", "document"]): Type of media block to append.

    Returns:
        None

    Raises:
        ValueError: If an unsupported media type is provided.

    Example:
        >>> msg.attach_media(media_ref, as_type="image")
    """
    block: ContentBlock
    if as_type == "image":
        block = ImageBlock(media=media)
    elif as_type == "audio":
        block = AudioBlock(media=media)
    elif as_type == "video":
        block = VideoBlock(media=media)
    elif as_type == "document":
        block = DocumentBlock(media=media)
    else:
        raise ValueError(f"Unsupported media type: {as_type}")

    if isinstance(self.content, str):
        self.content = [TextBlock(text=self.content), block]
    elif isinstance(self.content, list):
        self.content.append(block)
    else:
        self.content = [block]
text
text()

Best-effort text extraction from content blocks.

Returns:

Name Type Description
str str

Concatenated text from TextBlock and ToolResultBlock outputs.

Example

msg.text()
'Hello!Result text.'

Source code in pyagenity/utils/message.py
def text(self) -> str:
    """
    Best-effort text extraction from content blocks.

    Returns:
        str: Concatenated text from TextBlock and ToolResultBlock outputs.

    Example:
        >>> msg.text()
        'Hello!Result text.'
    """
    parts: list[str] = []
    for block in self.content:
        if isinstance(block, TextBlock):
            parts.append(block.text)
        elif isinstance(block, ToolResultBlock) and isinstance(block.output, str):
            parts.append(block.output)
    return "".join(parts)
text_message classmethod
text_message(content, role='user', message_id=None)

Create a Message instance from plain text.

Parameters:

Name Type Description Default
content str

The message content.

required
role Literal['user', 'assistant', 'system', 'tool']

The role of the sender.

'user'
message_id str | None

Optional message ID.

None

Returns:

Name Type Description
Message Message

The created Message instance.

Example

Message.text_message("Hello!", role="user")

Source code in pyagenity/utils/message.py
@classmethod
def text_message(
    cls,
    content: str,
    role: Literal["user", "assistant", "system", "tool"] = "user",
    message_id: str | None = None,
) -> "Message":
    """
    Create a Message instance from plain text.

    Args:
        content (str): The message content.
        role (Literal["user", "assistant", "system", "tool"]): The role of the sender.
        message_id (str | None): Optional message ID.

    Returns:
        Message: The created Message instance.

    Example:
        >>> Message.text_message("Hello!", role="user")
    """
    logger.debug("Creating message from text with role: %s", role)
    return cls(
        message_id=generate_id(message_id),
        role=role,
        content=[TextBlock(text=content)],
        timestamp=datetime.now(),
        metadata={},
    )
tool_message classmethod
tool_message(content, message_id=None, meta=None)

Create a tool message from content blocks; error results are flagged via ToolResultBlock.is_error.

Parameters:

Name Type Description Default
content list[ContentBlock]

The message content blocks.

required
message_id str | None

Optional message ID.

None
meta dict[str, Any] | None

Optional metadata.

None

Returns:

Name Type Description
Message Message

The created tool message instance.

Example

Message.tool_message([ToolResultBlock(...)], message_id="tool1")

Source code in pyagenity/utils/message.py
@classmethod
def tool_message(
    cls,
    content: list[ContentBlock],
    message_id: str | None = None,
    meta: dict[str, Any] | None = None,
) -> "Message":
    """
    Create a tool message, optionally marking it as an error.

    Args:
        content (list[ContentBlock]): The message content blocks.
        message_id (str | None): Optional message ID.
        meta (dict[str, Any] | None): Optional metadata.

    Returns:
        Message: The created tool message instance.

    Example:
        >>> Message.tool_message([ToolResultBlock(...)], message_id="tool1")
    """
    res = content
    msg_id = generate_id(message_id)
    return cls(
        message_id=msg_id,
        role="tool",
        content=res,
        timestamp=datetime.now(),
        metadata=meta or {},
    )
ReasoningBlock

Bases: BaseModel

Reasoning content block for messages.

Attributes:

Name Type Description
type Literal['reasoning']

Block type discriminator.

summary str

Summary of reasoning.

details list[str] | None

Detailed reasoning steps.

Source code in pyagenity/utils/message.py
class ReasoningBlock(BaseModel):
    """
    Reasoning content block for messages.

    Attributes:
        type (Literal["reasoning"]): Block type discriminator.
        summary (str): Summary of reasoning.
        details (list[str] | None): Detailed reasoning steps.
    """

    type: Literal["reasoning"] = "reasoning"
    summary: str
    details: list[str] | None = None
Attributes
details class-attribute instance-attribute
details = None
summary instance-attribute
summary
type class-attribute instance-attribute
type = 'reasoning'
TextBlock

Bases: BaseModel

Text content block for messages.

Attributes:

Name Type Description
type Literal['text']

Block type discriminator.

text str

Text content.

annotations list[AnnotationRef]

List of annotation references.

Source code in pyagenity/utils/message.py
class TextBlock(BaseModel):
    """
    Text content block for messages.

    Attributes:
        type (Literal["text"]): Block type discriminator.
        text (str): Text content.
        annotations (list[AnnotationRef]): List of annotation references.
    """

    type: Literal["text"] = "text"
    text: str
    annotations: list[AnnotationRef] = Field(default_factory=list)
Attributes
annotations class-attribute instance-attribute
annotations = Field(default_factory=list)
text instance-attribute
text
type class-attribute instance-attribute
type = 'text'
TokenUsages

Bases: BaseModel

Tracks token usage statistics for a message or model response.

Attributes:

Name Type Description
completion_tokens int

Number of completion tokens used.

prompt_tokens int

Number of prompt tokens used.

total_tokens int

Total tokens used.

reasoning_tokens int

Reasoning tokens used (optional).

cache_creation_input_tokens int

Cache creation input tokens (optional).

cache_read_input_tokens int

Cache read input tokens (optional).

image_tokens int | None

Image tokens for multimodal models (optional).

audio_tokens int | None

Audio tokens for multimodal models (optional).

Example

usage = TokenUsages(completion_tokens=10, prompt_tokens=20, total_tokens=30)

Source code in pyagenity/utils/message.py
class TokenUsages(BaseModel):
    """
    Tracks token usage statistics for a message or model response.

    Attributes:
        completion_tokens (int): Number of completion tokens used.
        prompt_tokens (int): Number of prompt tokens used.
        total_tokens (int): Total tokens used.
        reasoning_tokens (int): Reasoning tokens used (optional).
        cache_creation_input_tokens (int): Cache creation input tokens (optional).
        cache_read_input_tokens (int): Cache read input tokens (optional).
        image_tokens (int | None): Image tokens for multimodal models (optional).
        audio_tokens (int | None): Audio tokens for multimodal models (optional).

    Example:
        >>> usage = TokenUsages(completion_tokens=10, prompt_tokens=20, total_tokens=30)
    """

    completion_tokens: int
    prompt_tokens: int
    total_tokens: int
    reasoning_tokens: int = 0
    cache_creation_input_tokens: int = 0
    cache_read_input_tokens: int = 0
    # Optional modality-specific usage fields for multimodal models
    image_tokens: int | None = 0
    audio_tokens: int | None = 0
Attributes
audio_tokens class-attribute instance-attribute
audio_tokens = 0
cache_creation_input_tokens class-attribute instance-attribute
cache_creation_input_tokens = 0
cache_read_input_tokens class-attribute instance-attribute
cache_read_input_tokens = 0
completion_tokens instance-attribute
completion_tokens
image_tokens class-attribute instance-attribute
image_tokens = 0
prompt_tokens instance-attribute
prompt_tokens
reasoning_tokens class-attribute instance-attribute
reasoning_tokens = 0
total_tokens instance-attribute
total_tokens
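A common use of per-message usage records is aggregating them across a conversation. A minimal sketch, with a hypothetical `UsageSketch` stand-in covering only the three core counters:

```python
from dataclasses import dataclass


@dataclass
class UsageSketch:
    # Hypothetical stand-in for TokenUsages' core fields.
    completion_tokens: int = 0
    prompt_tokens: int = 0
    total_tokens: int = 0


def merge_usage(a: UsageSketch, b: UsageSketch) -> UsageSketch:
    """Fold two usage records into a running total."""
    return UsageSketch(
        completion_tokens=a.completion_tokens + b.completion_tokens,
        prompt_tokens=a.prompt_tokens + b.prompt_tokens,
        total_tokens=a.total_tokens + b.total_tokens,
    )
```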
ToolCallBlock

Bases: BaseModel

Tool call content block for messages.

Attributes:

Name Type Description
type Literal['tool_call']

Block type discriminator.

id str

Tool call ID.

name str

Tool name.

args dict[str, Any]

Arguments for the tool call.

tool_type str | None

Type of tool (e.g., web_search, file_search).

Source code in pyagenity/utils/message.py
class ToolCallBlock(BaseModel):
    """
    Tool call content block for messages.

    Attributes:
        type (Literal["tool_call"]): Block type discriminator.
        id (str): Tool call ID.
        name (str): Tool name.
        args (dict[str, Any]): Arguments for the tool call.
        tool_type (str | None): Type of tool (e.g., web_search, file_search).
    """

    type: Literal["tool_call"] = "tool_call"
    id: str
    name: str
    args: dict[str, Any] = Field(default_factory=dict)
    tool_type: str | None = None  # e.g., web_search, file_search, computer_use
Attributes
args class-attribute instance-attribute
args = Field(default_factory=dict)
id instance-attribute
id
name instance-attribute
name
tool_type class-attribute instance-attribute
tool_type = None
type class-attribute instance-attribute
type = 'tool_call'
ToolResultBlock

Bases: BaseModel

Tool result content block for messages.

Attributes:

Name Type Description
type Literal['tool_result']

Block type discriminator.

call_id str

Tool call ID.

output Any

Output from the tool (str, dict, MediaRef, or list of blocks).

is_error bool

Whether the result is an error.

status Literal['completed', 'failed'] | None

Status of the tool call.

Source code in pyagenity/utils/message.py
class ToolResultBlock(BaseModel):
    """
    Tool result content block for messages.

    Attributes:
        type (Literal["tool_result"]): Block type discriminator.
        call_id (str): Tool call ID.
        output (Any): Output from the tool (str, dict, MediaRef, or list of blocks).
        is_error (bool): Whether the result is an error.
        status (Literal["completed", "failed"] | None): Status of the tool call.
    """

    type: Literal["tool_result"] = "tool_result"
    call_id: str
    output: Any = None  # str | dict | MediaRef | list[ContentBlock-like]
    is_error: bool = False
    status: Literal["completed", "failed"] | None = None
Attributes
call_id instance-attribute
call_id
is_error class-attribute instance-attribute
is_error = False
output class-attribute instance-attribute
output = None
status class-attribute instance-attribute
status = None
type class-attribute instance-attribute
type = 'tool_result'
VideoBlock

Bases: BaseModel

Video content block for messages.

Attributes:

Name Type Description
type Literal['video']

Block type discriminator.

media MediaRef

Reference to video media.

thumbnail MediaRef | None

Reference to thumbnail image.

Source code in pyagenity/utils/message.py
class VideoBlock(BaseModel):
    """
    Video content block for messages.

    Attributes:
        type (Literal["video"]): Block type discriminator.
        media (MediaRef): Reference to video media.
        thumbnail (MediaRef | None): Reference to thumbnail image.
    """

    type: Literal["video"] = "video"
    media: MediaRef
    thumbnail: MediaRef | None = None
Attributes
media instance-attribute
media
thumbnail class-attribute instance-attribute
thumbnail = None
type class-attribute instance-attribute
type = 'video'
Functions
generate_id
generate_id(default_id)

Generate a message or tool call ID based on DI context and type.

Parameters:

Name Type Description Default
default_id str | int | None

Default ID to use if provided and matches type.

required

Returns:

Type Description
str | int

str | int: Generated or provided ID, type determined by DI context.

Example

generate_id("abc123")
'abc123'
generate_id(None)
'a-uuid-string'

Source code in pyagenity/utils/message.py
def generate_id(default_id: str | int | None) -> str | int:
    """
    Generate a message or tool call ID based on DI context and type.

    Args:
        default_id (str | int | None): Default ID to use if provided and matches type.

    Returns:
        str | int: Generated or provided ID, type determined by DI context.

    Raises:
        None

    Example:
        >>> generate_id("abc123")
        'abc123'
        >>> generate_id(None)
        'a-uuid-string'
    """
    id_type = InjectQ.get_instance().try_get("generated_id_type", "string")
    generated_id = InjectQ.get_instance().try_get("generated_id", None)

    # if user provided an awaitable, resolve it
    if isinstance(generated_id, Awaitable):

        async def wait_for_id():
            return await generated_id

        generated_id = asyncio.run(wait_for_id())

    if generated_id:
        return generated_id

    if default_id:
        if id_type == "string" and isinstance(default_id, str):
            return default_id
        if id_type in ("int", "bigint") and isinstance(default_id, int):
            return default_id

    # if not matched or default_id is None, generate new id
    logger.debug(
        "Generating new id of type: %s. Default ID not provided or not matched %s",
        id_type,
        default_id,
    )

    if id_type == "int":
        return uuid4().int >> 32
    if id_type == "bigint":
        return uuid4().int >> 64
    return str(uuid4())
metrics

Lightweight metrics instrumentation utilities.

Design goals
  • Zero dependency by default.
  • Cheap no-op when disabled.
  • Pluggable exporter (e.g., Prometheus scrape formatting) later.
Usage

from pyagenity.utils.metrics import counter, timer
counter('messages_written_total').inc()
with timer('db_write_latency_ms'):
    ...

Classes:

Name Description
Counter
TimerMetric

Functions:

Name Description
counter
enable_metrics
snapshot

Return a point-in-time snapshot of metrics (thread-safe copy).

timer
Classes
Counter dataclass

Methods:

Name Description
__init__
inc

Attributes:

Name Type Description
name str
value int
Source code in pyagenity/utils/metrics.py
@dataclass
class Counter:
    name: str
    value: int = 0

    def inc(self, amount: int = 1) -> None:
        if not _ENABLED:
            return
        with _LOCK:
            self.value += amount
Attributes
name instance-attribute
name
value class-attribute instance-attribute
value = 0
Functions
__init__
__init__(name, value=0)
inc
inc(amount=1)
Source code in pyagenity/utils/metrics.py
def inc(self, amount: int = 1) -> None:
    if not _ENABLED:
        return
    with _LOCK:
        self.value += amount
TimerMetric dataclass

Methods:

Name Description
__init__
observe

Attributes:

Name Type Description
avg_ms float
count int
max_ms float
name str
total_ms float
Source code in pyagenity/utils/metrics.py
@dataclass
class TimerMetric:
    name: str
    count: int = 0
    total_ms: float = 0.0
    max_ms: float = 0.0

    def observe(self, duration_ms: float) -> None:
        if not _ENABLED:
            return
        with _LOCK:
            self.count += 1
            self.total_ms += duration_ms
            self.max_ms = max(self.max_ms, duration_ms)

    @property
    def avg_ms(self) -> float:
        if self.count == 0:
            return 0.0
        return self.total_ms / self.count
Attributes
avg_ms property
avg_ms
count class-attribute instance-attribute
count = 0
max_ms class-attribute instance-attribute
max_ms = 0.0
name instance-attribute
name
total_ms class-attribute instance-attribute
total_ms = 0.0
Functions
__init__
__init__(name, count=0, total_ms=0.0, max_ms=0.0)
observe
observe(duration_ms)
Source code in pyagenity/utils/metrics.py
def observe(self, duration_ms: float) -> None:
    if not _ENABLED:
        return
    with _LOCK:
        self.count += 1
        self.total_ms += duration_ms
        self.max_ms = max(self.max_ms, duration_ms)
Functions
counter
counter(name)
Source code in pyagenity/utils/metrics.py
def counter(name: str) -> Counter:
    with _LOCK:
        c = _COUNTERS.get(name)
        if c is None:
            c = Counter(name)
            _COUNTERS[name] = c
        return c
enable_metrics
enable_metrics(value)
Source code in pyagenity/utils/metrics.py
def enable_metrics(value: bool) -> None:  # simple toggle; acceptable global
    # Intentionally keeps a module-level switch—call sites cheap check.
    globals()["_ENABLED"] = value
snapshot
snapshot()

Return a point-in-time snapshot of metrics (thread-safe copy).

Source code in pyagenity/utils/metrics.py
def snapshot() -> dict:
    """Return a point-in-time snapshot of metrics (thread-safe copy)."""
    with _LOCK:
        return {
            "counters": {k: v.value for k, v in _COUNTERS.items()},
            "timers": {
                k: {
                    "count": t.count,
                    "avg_ms": t.avg_ms,
                    "max_ms": t.max_ms,
                }
                for k, t in _TIMERS.items()
            },
        }
timer
timer(name)
Source code in pyagenity/utils/metrics.py
def timer(name: str) -> _TimerCtx:  # convenience factory
    metric = _TIMERS.get(name)
    if metric is None:
        with _LOCK:
            metric = _TIMERS.get(name)
            if metric is None:
                metric = TimerMetric(name)
                _TIMERS[name] = metric
    return _TimerCtx(metric)
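The counter/timer/snapshot surface described above can be approximated in a few lines. This is a hypothetical standalone sketch: the names `inc` and `timed` are illustrative only (the module's actual API is `counter(name).inc()` and `timer(name)`), and the always-on behavior omits the `_ENABLED` toggle:

```python
import threading
import time
from contextlib import contextmanager

_LOCK = threading.Lock()
_COUNTERS: dict[str, int] = {}
_TIMERS: dict[str, list[float]] = {}


def inc(name: str, amount: int = 1) -> None:
    # Counter increments are guarded by one shared lock.
    with _LOCK:
        _COUNTERS[name] = _COUNTERS.get(name, 0) + amount


@contextmanager
def timed(name: str):
    # Record the elapsed wall time of the with-block, in milliseconds.
    start = time.perf_counter()
    try:
        yield
    finally:
        with _LOCK:
            _TIMERS.setdefault(name, []).append((time.perf_counter() - start) * 1000)


def metrics_snapshot() -> dict:
    """Point-in-time, thread-safe copy, mirroring snapshot()'s output shape."""
    with _LOCK:
        return {
            "counters": dict(_COUNTERS),
            "timers": {
                k: {"count": len(v), "avg_ms": sum(v) / len(v), "max_ms": max(v)}
                for k, v in _TIMERS.items()
            },
        }
```

The real module stores aggregates (count/total/max) rather than raw samples, which keeps memory constant regardless of call volume; the sketch keeps samples only for brevity.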
reducers

Reducer utilities for merging and replacing lists and values in agent state.

This module provides generic and message-specific reducers for combining lists, replacing values, and appending items while avoiding duplicates.

Functions:

Name Description
add_messages

Adds messages to a list, avoiding duplicates by message_id.

replace_messages

Replaces the entire message list.

append_items

Appends items to a list, avoiding duplicates by id.

replace_value

Replaces a value with another.

Classes
Functions
add_messages
add_messages(left, right)

Adds messages to the list, avoiding duplicates by message_id.

Parameters:

Name Type Description Default
left list[Message]

Existing list of messages.

required
right list[Message]

New messages to add.

required

Returns:

Type Description
list[Message]

list[Message]: Combined list with unique messages.

Example

add_messages([msg1], [msg2, msg1])
[msg1, msg2]

Source code in pyagenity/utils/reducers.py
def add_messages(left: list[Message], right: list[Message]) -> list[Message]:
    """
    Adds messages to the list, avoiding duplicates by message_id.

    Args:
        left (list[Message]): Existing list of messages.
        right (list[Message]): New messages to add.

    Returns:
        list[Message]: Combined list with unique messages.

    Example:
        >>> add_messages([msg1], [msg2, msg1])
        [msg1, msg2]
    """
    left_ids = {msg.message_id for msg in left}
    right = [msg for msg in right if msg.message_id not in left_ids]
    return left + right
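The dedup-by-`message_id` rule used by `add_messages` can be exercised without pydantic; `Msg` below is a hypothetical stand-in carrying only the field the reducer inspects:

```python
from dataclasses import dataclass


@dataclass
class Msg:
    # Hypothetical stand-in with only the field add_messages needs.
    message_id: str


def add_messages_sketch(left: list[Msg], right: list[Msg]) -> list[Msg]:
    """Keep right-hand messages whose message_id is not already in left."""
    seen = {m.message_id for m in left}
    return left + [m for m in right if m.message_id not in seen]
```

Because only `right` is filtered, existing messages in `left` always win on ID conflicts and their order is preserved.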
append_items
append_items(left, right)

Appends items to a list, avoiding duplicates by item.id.

Parameters:

Name Type Description Default
left list

Existing list of items (must have .id attribute).

required
right list

New items to add.

required

Returns:

Name Type Description
list list

Combined list with unique items.

Example

append_items([item1], [item2, item1])
[item1, item2]

Source code in pyagenity/utils/reducers.py
def append_items(left: list, right: list) -> list:
    """
    Appends items to a list, avoiding duplicates by item.id.

    Args:
        left (list): Existing list of items (must have .id attribute).
        right (list): New items to add.

    Returns:
        list: Combined list with unique items.

    Example:
        >>> append_items([item1], [item2, item1])
        [item1, item2]
    """
    left_ids = {item.id for item in left}
    right = [item for item in right if item.id not in left_ids]
    return left + right
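Unlike `add_messages`, this reducer works with any object exposing an `.id` attribute. A minimal sketch, using a hypothetical `ToolCall` namedtuple as the item type:

```python
from collections import namedtuple

# Any object with an `.id` attribute works; ToolCall is a hypothetical example.
ToolCall = namedtuple("ToolCall", ["id", "name"])


def append_items(left, right):
    # Same dedup-by-id logic as the reducer above.
    left_ids = {item.id for item in left}
    return left + [item for item in right if item.id not in left_ids]


calls = append_items(
    [ToolCall("a", "search")],
    [ToolCall("b", "fetch"), ToolCall("a", "search")],
)
print([c.id for c in calls])  # ['a', 'b']
```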
replace_messages
replace_messages(left, right)

Replaces the entire message list with a new one.

Parameters:

Name Type Description Default
left list[Message]

Existing list of messages (ignored).

required
right list[Message]

New list of messages.

required

Returns:

Type Description
list[Message]

list[Message]: The new message list.

Example

>>> replace_messages([msg1], [msg2])
[msg2]

Source code in pyagenity/utils/reducers.py
def replace_messages(left: list[Message], right: list[Message]) -> list[Message]:
    """
    Replaces the entire message list with a new one.

    Args:
        left (list[Message]): Existing list of messages (ignored).
        right (list[Message]): New list of messages.

    Returns:
        list[Message]: The new message list.

    Example:
        >>> replace_messages([msg1], [msg2])
        [msg2]
    """
    return right
replace_value
replace_value(left, right)

Replaces a value with another.

Parameters:

Name Type Description Default
left

Existing value (ignored).

required
right

New value to use.

required

Returns:

Name Type Description
Any

The new value.

Example

>>> replace_value(1, 2)
2

Source code in pyagenity/utils/reducers.py
def replace_value(left, right):
    """
    Replaces a value with another.

    Args:
        left: Existing value (ignored).
        right: New value to use.

    Returns:
        Any: The new value.

    Example:
        >>> replace_value(1, 2)
        2
    """
    return right
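Reducers like `replace_value` become useful when a graph merges a state update field by field, dispatching each field to its own reducer. A sketch of that dispatch, with hypothetical field names (the real wiring lives in `pyagenity.graph`):

```python
def replace_value(left, right):
    # Last-write-wins: the existing value is discarded.
    return right


# Per-field merge strategies: accumulate "context", overwrite "summary".
reducers = {"context": lambda left, right: left + right, "summary": replace_value}

state = {"context": ["step 1"], "summary": "draft"}
update = {"context": ["step 2"], "summary": "final"}

merged = {key: reducers[key](state[key], update[key]) for key in state}
print(merged)  # {'context': ['step 1', 'step 2'], 'summary': 'final'}
```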
thread_info

Thread metadata and status tracking for agent graphs.

This module defines the ThreadInfo model, which tracks thread identity, user, metadata, status, and timestamps for agent graph execution and orchestration.

Classes:

Name Description
ThreadInfo

Metadata and status for a thread in agent execution.

Classes
ThreadInfo

Bases: BaseModel

Metadata and status for a thread in agent execution.

Attributes:

Name Type Description
thread_id int | str

Unique identifier for the thread.

thread_name str | None

Optional name for the thread.

user_id int | str | None

Optional user identifier associated with the thread.

metadata dict[str, Any] | None

Optional metadata for the thread.

updated_at datetime | None

Timestamp of last update.

run_id str | None

Optional run identifier for the thread execution.

Example

>>> ThreadInfo(thread_id=1, thread_name="main", user_id=42)

Source code in pyagenity/utils/thread_info.py
class ThreadInfo(BaseModel):
    """
    Metadata and status for a thread in agent execution.

    Attributes:
        thread_id (int | str): Unique identifier for the thread.
        thread_name (str | None): Optional name for the thread.
        user_id (int | str | None): Optional user identifier associated with the thread.
        metadata (dict[str, Any] | None): Optional metadata for the thread.
        updated_at (datetime | None): Timestamp of last update.
        run_id (str | None): Optional run identifier for the thread execution.

    Example:
        >>> ThreadInfo(thread_id=1, thread_name="main", user_id=42)
    """

    thread_id: int | str
    thread_name: str | None = None
    user_id: int | str | None = None
    metadata: dict[str, Any] | None = None
    updated_at: datetime | None = None
    run_id: str | None = None
Attributes
metadata class-attribute instance-attribute
metadata = None
run_id class-attribute instance-attribute
run_id = None
thread_id instance-attribute
thread_id
thread_name class-attribute instance-attribute
thread_name = None
updated_at class-attribute instance-attribute
updated_at = None
user_id class-attribute instance-attribute
user_id = None
thread_name_generator

Thread name generation utilities for AI agent conversations.

This module provides the AIThreadNameGenerator class and helper function for generating meaningful, varied, and human-friendly thread names for AI chat sessions using different patterns and themes.

Classes:

Name Description
AIThreadNameGenerator

Generates thread names using adjective-noun, action-based, and compound descriptive patterns.

Functions:

Name Description
generate_dummy_thread_name

Convenience function for generating a thread name.

Classes
AIThreadNameGenerator

Generates meaningful, varied thread names for AI conversations using different patterns and themes. Patterns include adjective-noun, action-based, and compound descriptive names.

Example

>>> AIThreadNameGenerator().generate_name()
'thoughtful-dialogue'

Methods:

Name Description
generate_action_name

Generate an action-based thread name for a more dynamic feel.

generate_compound_name

Generate a compound descriptive thread name.

generate_name

Generate a meaningful thread name using random pattern selection.

generate_simple_name

Generate a simple adjective-noun combination for a thread name.

Attributes:

Name Type Description
ACTION_PATTERNS
ADJECTIVES
COMPOUND_PATTERNS
NOUNS
Source code in pyagenity/utils/thread_name_generator.py
class AIThreadNameGenerator:
    """
    Generates meaningful, varied thread names for AI conversations using different
    patterns and themes. Patterns include adjective-noun, action-based, and compound
    descriptive names.

    Example:
        >>> AIThreadNameGenerator().generate_name()
        'thoughtful-dialogue'
    """

    # Enhanced adjectives grouped by semantic meaning
    ADJECTIVES = [
        # Intellectual
        "thoughtful",
        "insightful",
        "analytical",
        "logical",
        "strategic",
        "methodical",
        "systematic",
        "comprehensive",
        "detailed",
        "precise",
        # Creative
        "creative",
        "imaginative",
        "innovative",
        "artistic",
        "expressive",
        "original",
        "inventive",
        "inspired",
        "visionary",
        "whimsical",
        # Emotional/Social
        "engaging",
        "collaborative",
        "meaningful",
        "productive",
        "harmonious",
        "enlightening",
        "empathetic",
        "supportive",
        "encouraging",
        "uplifting",
        # Dynamic
        "dynamic",
        "energetic",
        "vibrant",
        "lively",
        "spirited",
        "active",
        "flowing",
        "adaptive",
        "responsive",
        "interactive",
        # Quality-focused
        "focused",
        "dedicated",
        "thorough",
        "meticulous",
        "careful",
        "patient",
        "persistent",
        "resilient",
        "determined",
        "ambitious",
    ]

    # Enhanced nouns with more conversational context
    NOUNS = [
        # Conversation-related
        "dialogue",
        "conversation",
        "discussion",
        "exchange",
        "chat",
        "consultation",
        "session",
        "meeting",
        "interaction",
        "communication",
        # Journey/Process
        "journey",
        "exploration",
        "adventure",
        "quest",
        "voyage",
        "expedition",
        "discovery",
        "investigation",
        "research",
        "study",
        # Conceptual
        "insight",
        "vision",
        "perspective",
        "understanding",
        "wisdom",
        "knowledge",
        "learning",
        "growth",
        "development",
        "progress",
        # Solution-oriented
        "solution",
        "approach",
        "strategy",
        "method",
        "framework",
        "plan",
        "blueprint",
        "pathway",
        "route",
        "direction",
        # Creative/Abstract
        "canvas",
        "story",
        "narrative",
        "symphony",
        "composition",
        "creation",
        "masterpiece",
        "design",
        "pattern",
        "concept",
        # Collaborative
        "partnership",
        "collaboration",
        "alliance",
        "connection",
        "bond",
        "synergy",
        "harmony",
        "unity",
        "cooperation",
        "teamwork",
    ]

    # Action-based patterns for more dynamic names
    ACTION_PATTERNS = {
        "exploring": ["ideas", "concepts", "possibilities", "mysteries", "frontiers", "depths"],
        "building": ["solutions", "understanding", "connections", "frameworks", "bridges"],
        "discovering": ["insights", "patterns", "answers", "truths", "secrets", "wisdom"],
        "crafting": ["responses", "solutions", "stories", "strategies", "experiences"],
        "navigating": ["challenges", "questions", "complexities", "territories", "paths"],
        "unlocking": ["potential", "mysteries", "possibilities", "creativity", "knowledge"],
        "weaving": ["ideas", "stories", "connections", "patterns", "narratives"],
        "illuminating": ["concepts", "mysteries", "paths", "truths", "possibilities"],
    }

    # Descriptive compound patterns
    COMPOUND_PATTERNS = [
        ("deep", ["dive", "thought", "reflection", "analysis", "exploration"]),
        ("bright", ["spark", "idea", "insight", "moment", "flash"]),
        ("fresh", ["perspective", "approach", "start", "take", "view"]),
        ("open", ["dialogue", "discussion", "conversation", "exchange", "forum"]),
        ("creative", ["flow", "spark", "burst", "stream", "wave"]),
        ("mindful", ["moment", "pause", "reflection", "consideration", "thought"]),
        ("collaborative", ["effort", "venture", "journey", "exploration", "creation"]),
    ]

    def generate_simple_name(self, separator: str = "-") -> str:
        """
        Generate a simple adjective-noun combination for a thread name.

        Args:
            separator (str): String to separate words (default: "-").

        Returns:
            str: Name like "thoughtful-dialogue" or "creative-exploration".

        Example:
            >>> AIThreadNameGenerator().generate_simple_name()
            'creative-exploration'
        """
        adj = secrets.choice(self.ADJECTIVES)
        noun = secrets.choice(self.NOUNS)
        return f"{adj}{separator}{noun}"

    def generate_action_name(self, separator: str = "-") -> str:
        """
        Generate an action-based thread name for a more dynamic feel.

        Args:
            separator (str): String to separate words (default: "-").

        Returns:
            str: Name like "exploring-ideas" or "building-understanding".

        Example:
            >>> AIThreadNameGenerator().generate_action_name()
            'building-connections'
        """
        action = secrets.choice(list(self.ACTION_PATTERNS.keys()))
        target = secrets.choice(self.ACTION_PATTERNS[action])
        return f"{action}{separator}{target}"

    def generate_compound_name(self, separator: str = "-") -> str:
        """
        Generate a compound descriptive thread name.

        Args:
            separator (str): String to separate words (default: "-").

        Returns:
            str: Name like "deep-dive" or "bright-spark".

        Example:
            >>> AIThreadNameGenerator().generate_compound_name()
            'deep-reflection'
        """
        base, options = secrets.choice(self.COMPOUND_PATTERNS)
        complement = secrets.choice(options)
        return f"{base}{separator}{complement}"

    def generate_name(self, separator: str = "-") -> str:
        """
        Generate a meaningful thread name using random pattern selection.

        Args:
            separator (str): String to separate words (default: "-").

        Returns:
            str: A meaningful thread name from various patterns.

        Example:
            >>> AIThreadNameGenerator().generate_name()
            'engaging-discussion'
        """
        # Randomly choose between different naming patterns
        pattern = secrets.choice(["simple", "action", "compound"])

        if pattern == "simple":
            return self.generate_simple_name(separator)
        if pattern == "action":
            return self.generate_action_name(separator)
        # compound
        return self.generate_compound_name(separator)
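The three-pattern approach can be condensed into a few lines. This sketch keeps the same structure with abbreviated word lists (the real class draws from the much larger `ADJECTIVES`, `NOUNS`, `ACTION_PATTERNS`, and `COMPOUND_PATTERNS` pools above):

```python
import secrets

# Abbreviated word pools; the real generator's lists are far larger.
ADJECTIVES = ["thoughtful", "creative"]
NOUNS = ["dialogue", "exploration"]
ACTIONS = {"exploring": ["ideas"], "building": ["solutions"]}
COMPOUNDS = [("deep", ["dive"]), ("bright", ["spark"])]


def generate_name(separator: str = "-") -> str:
    # Pick one of the three naming patterns, then fill it in at random.
    pattern = secrets.choice(["simple", "action", "compound"])
    if pattern == "simple":
        return f"{secrets.choice(ADJECTIVES)}{separator}{secrets.choice(NOUNS)}"
    if pattern == "action":
        action = secrets.choice(list(ACTIONS))
        return f"{action}{separator}{secrets.choice(ACTIONS[action])}"
    base, options = secrets.choice(COMPOUNDS)
    return f"{base}{separator}{secrets.choice(options)}"


name = generate_name()
print(name)  # e.g. 'thoughtful-dialogue' or 'exploring-ideas'
```

`secrets.choice` is used (rather than `random.choice`) to match the source; it requires an indexable sequence, hence the `list(ACTIONS)` conversion for the dict keys.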
Attributes
ACTION_PATTERNS class-attribute instance-attribute
ACTION_PATTERNS = {'exploring': ['ideas', 'concepts', 'possibilities', 'mysteries', 'frontiers', 'depths'], 'building': ['solutions', 'understanding', 'connections', 'frameworks', 'bridges'], 'discovering': ['insights', 'patterns', 'answers', 'truths', 'secrets', 'wisdom'], 'crafting': ['responses', 'solutions', 'stories', 'strategies', 'experiences'], 'navigating': ['challenges', 'questions', 'complexities', 'territories', 'paths'], 'unlocking': ['potential', 'mysteries', 'possibilities', 'creativity', 'knowledge'], 'weaving': ['ideas', 'stories', 'connections', 'patterns', 'narratives'], 'illuminating': ['concepts', 'mysteries', 'paths', 'truths', 'possibilities']}
ADJECTIVES class-attribute instance-attribute
ADJECTIVES = ['thoughtful', 'insightful', 'analytical', 'logical', 'strategic', 'methodical', 'systematic', 'comprehensive', 'detailed', 'precise', 'creative', 'imaginative', 'innovative', 'artistic', 'expressive', 'original', 'inventive', 'inspired', 'visionary', 'whimsical', 'engaging', 'collaborative', 'meaningful', 'productive', 'harmonious', 'enlightening', 'empathetic', 'supportive', 'encouraging', 'uplifting', 'dynamic', 'energetic', 'vibrant', 'lively', 'spirited', 'active', 'flowing', 'adaptive', 'responsive', 'interactive', 'focused', 'dedicated', 'thorough', 'meticulous', 'careful', 'patient', 'persistent', 'resilient', 'determined', 'ambitious']
COMPOUND_PATTERNS class-attribute instance-attribute
COMPOUND_PATTERNS = [('deep', ['dive', 'thought', 'reflection', 'analysis', 'exploration']), ('bright', ['spark', 'idea', 'insight', 'moment', 'flash']), ('fresh', ['perspective', 'approach', 'start', 'take', 'view']), ('open', ['dialogue', 'discussion', 'conversation', 'exchange', 'forum']), ('creative', ['flow', 'spark', 'burst', 'stream', 'wave']), ('mindful', ['moment', 'pause', 'reflection', 'consideration', 'thought']), ('collaborative', ['effort', 'venture', 'journey', 'exploration', 'creation'])]
NOUNS class-attribute instance-attribute
NOUNS = ['dialogue', 'conversation', 'discussion', 'exchange', 'chat', 'consultation', 'session', 'meeting', 'interaction', 'communication', 'journey', 'exploration', 'adventure', 'quest', 'voyage', 'expedition', 'discovery', 'investigation', 'research', 'study', 'insight', 'vision', 'perspective', 'understanding', 'wisdom', 'knowledge', 'learning', 'growth', 'development', 'progress', 'solution', 'approach', 'strategy', 'method', 'framework', 'plan', 'blueprint', 'pathway', 'route', 'direction', 'canvas', 'story', 'narrative', 'symphony', 'composition', 'creation', 'masterpiece', 'design', 'pattern', 'concept', 'partnership', 'collaboration', 'alliance', 'connection', 'bond', 'synergy', 'harmony', 'unity', 'cooperation', 'teamwork']
Functions
generate_action_name
generate_action_name(separator='-')

Generate an action-based thread name for a more dynamic feel.

Parameters:

Name Type Description Default
separator str

String to separate words (default: "-").

'-'

Returns:

Name Type Description
str str

Name like "exploring-ideas" or "building-understanding".

Example

>>> AIThreadNameGenerator().generate_action_name()
'building-connections'

Source code in pyagenity/utils/thread_name_generator.py
def generate_action_name(self, separator: str = "-") -> str:
    """
    Generate an action-based thread name for a more dynamic feel.

    Args:
        separator (str): String to separate words (default: "-").

    Returns:
        str: Name like "exploring-ideas" or "building-understanding".

    Example:
        >>> AIThreadNameGenerator().generate_action_name()
        'building-connections'
    """
    action = secrets.choice(list(self.ACTION_PATTERNS.keys()))
    target = secrets.choice(self.ACTION_PATTERNS[action])
    return f"{action}{separator}{target}"
generate_compound_name
generate_compound_name(separator='-')

Generate a compound descriptive thread name.

Parameters:

Name Type Description Default
separator str

String to separate words (default: "-").

'-'

Returns:

Name Type Description
str str

Name like "deep-dive" or "bright-spark".

Example

>>> AIThreadNameGenerator().generate_compound_name()
'deep-reflection'

Source code in pyagenity/utils/thread_name_generator.py
def generate_compound_name(self, separator: str = "-") -> str:
    """
    Generate a compound descriptive thread name.

    Args:
        separator (str): String to separate words (default: "-").

    Returns:
        str: Name like "deep-dive" or "bright-spark".

    Example:
        >>> AIThreadNameGenerator().generate_compound_name()
        'deep-reflection'
    """
    base, options = secrets.choice(self.COMPOUND_PATTERNS)
    complement = secrets.choice(options)
    return f"{base}{separator}{complement}"
generate_name
generate_name(separator='-')

Generate a meaningful thread name using random pattern selection.

Parameters:

Name Type Description Default
separator str

String to separate words (default: "-").

'-'

Returns:

Name Type Description
str str

A meaningful thread name from various patterns.

Example

>>> AIThreadNameGenerator().generate_name()
'engaging-discussion'

Source code in pyagenity/utils/thread_name_generator.py
def generate_name(self, separator: str = "-") -> str:
    """
    Generate a meaningful thread name using random pattern selection.

    Args:
        separator (str): String to separate words (default: "-").

    Returns:
        str: A meaningful thread name from various patterns.

    Example:
        >>> AIThreadNameGenerator().generate_name()
        'engaging-discussion'
    """
    # Randomly choose between different naming patterns
    pattern = secrets.choice(["simple", "action", "compound"])

    if pattern == "simple":
        return self.generate_simple_name(separator)
    if pattern == "action":
        return self.generate_action_name(separator)
    # compound
    return self.generate_compound_name(separator)
generate_simple_name
generate_simple_name(separator='-')

Generate a simple adjective-noun combination for a thread name.

Parameters:

Name Type Description Default
separator str

String to separate words (default: "-").

'-'

Returns:

Name Type Description
str str

Name like "thoughtful-dialogue" or "creative-exploration".

Example

>>> AIThreadNameGenerator().generate_simple_name()
'creative-exploration'

Source code in pyagenity/utils/thread_name_generator.py
def generate_simple_name(self, separator: str = "-") -> str:
    """
    Generate a simple adjective-noun combination for a thread name.

    Args:
        separator (str): String to separate words (default: "-").

    Returns:
        str: Name like "thoughtful-dialogue" or "creative-exploration".

    Example:
        >>> AIThreadNameGenerator().generate_simple_name()
        'creative-exploration'
    """
    adj = secrets.choice(self.ADJECTIVES)
    noun = secrets.choice(self.NOUNS)
    return f"{adj}{separator}{noun}"
Functions
generate_dummy_thread_name
generate_dummy_thread_name(separator='-')

Generate a meaningful English name for an AI chat thread.

Parameters:

Name Type Description Default
separator str

String to separate words (default: "-").

'-'

Returns:

Name Type Description
str str

A meaningful thread name like 'thoughtful-dialogue', 'exploring-ideas', or 'deep-dive'.

Example

>>> generate_dummy_thread_name()
'creative-exploration'

Source code in pyagenity/utils/thread_name_generator.py
def generate_dummy_thread_name(separator: str = "-") -> str:
    """
    Generate a meaningful English name for an AI chat thread.

    Args:
        separator (str): String to separate words (default: "-").

    Returns:
        str: A meaningful thread name like 'thoughtful-dialogue', 'exploring-ideas', or 'deep-dive'.

    Example:
        >>> generate_dummy_thread_name()
        'creative-exploration'
    """
    generator = AIThreadNameGenerator()
    return generator.generate_name(separator)