Skip to content

Guarded

Classes:

Name Description
GuardedAgent

Validate output and repair until valid or attempts exhausted.

Attributes:

Name Type Description
StateT

Attributes

StateT module-attribute

StateT = TypeVar('StateT', bound=AgentState)

Classes

GuardedAgent

Validate output and repair until valid or attempts exhausted.

Nodes:
- PRODUCE: main generation node
- REPAIR: correction node when validation fails

Edges:
- PRODUCE -> conditional(valid? END : REPAIR)
- REPAIR -> PRODUCE

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/guarded.py
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class GuardedAgent[StateT: AgentState]:
    """Validate output and repair until valid or attempts exhausted.

    Nodes:
    - PRODUCE: main generation node
    - REPAIR: correction node when validation fails

    Edges:
    PRODUCE -> conditional(valid? END : REPAIR)
    REPAIR -> PRODUCE
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        produce_node: Callable | tuple[Callable, str],
        repair_node: Callable | tuple[Callable, str],
        validator: Callable[[AgentState], bool],
        max_attempts: int = 2,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        # Handle produce_node
        if isinstance(produce_node, tuple):
            produce_func, produce_name = produce_node
            if not callable(produce_func):
                raise ValueError("produce_node[0] must be callable")
        else:
            produce_func = produce_node
            produce_name = "PRODUCE"
            if not callable(produce_func):
                raise ValueError("produce_node must be callable")

        # Handle repair_node
        if isinstance(repair_node, tuple):
            repair_func, repair_name = repair_node
            if not callable(repair_func):
                raise ValueError("repair_node[0] must be callable")
        else:
            repair_func = repair_node
            repair_name = "REPAIR"
            if not callable(repair_func):
                raise ValueError("repair_node must be callable")

        self._graph.add_node(produce_name, produce_func)
        self._graph.add_node(repair_name, repair_func)

        # produce -> END or REPAIR
        condition = _guard_condition_factory(validator, max_attempts)
        self._graph.add_conditional_edges(
            produce_name,
            condition,
            {repair_name: repair_name, END: END},
        )
        # repair -> produce
        self._graph.add_edge(repair_name, produce_name)

        self._graph.set_entry_point(produce_name)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )

Functions

__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/guarded.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator | None = None,
    container: InjectQ | None = None,
):
    """Create the underlying ``StateGraph`` and store it on ``self._graph``.

    Args:
        state: Forwarded unchanged to ``StateGraph``.
        context_manager: Forwarded unchanged to ``StateGraph``.
        publisher: Forwarded unchanged to ``StateGraph``.
        id_generator: ID generator handed to ``StateGraph``. When omitted, a
            new ``DefaultIDGenerator`` is created for this instance.
        container: Forwarded unchanged to ``StateGraph``.
    """
    # A default written as ``DefaultIDGenerator()`` in the signature is
    # evaluated once at definition time and shared by all instances.
    if id_generator is None:
        id_generator = DefaultIDGenerator()

    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(produce_node, repair_node, validator, max_attempts=2, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/guarded.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def compile(
    self,
    produce_node: Callable | tuple[Callable, str],
    repair_node: Callable | tuple[Callable, str],
    validator: Callable[[AgentState], bool],
    max_attempts: int = 2,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager | None = None,
) -> CompiledGraph:
    """Wire the produce/repair loop into ``self._graph`` and compile it.

    Args:
        produce_node: Callable, or ``(callable, name)`` pair, for the
            generation node. The name defaults to ``"PRODUCE"``.
        repair_node: Callable, or ``(callable, name)`` pair, for the
            correction node. The name defaults to ``"REPAIR"``.
        validator: Passed with ``max_attempts`` to
            ``_guard_condition_factory`` to build the routing condition.
        max_attempts: Passed to ``_guard_condition_factory``; presumably the
            cap on produce attempts (confirm against the factory).
        checkpointer: Forwarded unchanged to the graph's ``compile``.
        store: Forwarded unchanged to the graph's ``compile``.
        interrupt_before: Forwarded unchanged to the graph's ``compile``.
        interrupt_after: Forwarded unchanged to the graph's ``compile``.
        callback_manager: Callback manager for the compiled graph. When
            omitted, a new ``CallbackManager`` is created for this call.

    Returns:
        The result of ``self._graph.compile(...)``.

    Raises:
        ValueError: If a node (or the first item of a node tuple) is not
            callable, or a node tuple does not have exactly two items.
    """

    def _resolve(node, default_name, label):
        # Normalise a node spec to (callable, name); tuple unpacking raises
        # ValueError by itself when the tuple is not a pair.
        if isinstance(node, tuple):
            func, name = node
            if not callable(func):
                raise ValueError(f"{label}[0] must be callable")
            return func, name
        if not callable(node):
            raise ValueError(f"{label} must be callable")
        return node, default_name

    produce_func, produce_name = _resolve(produce_node, "PRODUCE", "produce_node")
    repair_func, repair_name = _resolve(repair_node, "REPAIR", "repair_node")

    # A ``CallbackManager()`` default in the signature would be created once
    # and shared by every compiled graph; create one per call instead.
    if callback_manager is None:
        callback_manager = CallbackManager()

    self._graph.add_node(produce_name, produce_func)
    self._graph.add_node(repair_name, repair_func)

    # produce -> END or REPAIR
    # NOTE(review): the factory is not given repair_name; if it returns a
    # fixed label, a custom repair name would not match this path map.
    condition = _guard_condition_factory(validator, max_attempts)
    self._graph.add_conditional_edges(
        produce_name,
        condition,
        {repair_name: repair_name, END: END},
    )
    # repair -> produce
    self._graph.add_edge(repair_name, produce_name)

    self._graph.set_entry_point(produce_name)

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )