Skip to content

Branch join

Classes:

Name Description
BranchJoinAgent

Execute multiple branches then join.

Attributes:

Name Type Description
StateT

Attributes

StateT module-attribute

StateT = TypeVar('StateT', bound=AgentState)

Classes

BranchJoinAgent

Execute multiple branches then join.

Note: This prebuilt models branches sequentially (not true parallel execution). For each provided branch node, we add an edge branch_i -> JOIN. After the JOIN node runs, next_branch_condition decides whether to route to another branch or to END; by default it always returns END, so only the first branch executes. A more advanced version could use BackgroundTaskManager for concurrency.

Methods:

Name Description
__init__
compile
Source code in pyagenity/prebuilt/agent/branch_join.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class BranchJoinAgent[StateT: AgentState]:
    """Execute multiple branches then join.

    Note: This prebuilt models branches sequentially (not true parallel execution).
    For each provided branch node, we add edges branch_i -> JOIN. The JOIN node
    decides whether more branches remain or END. A more advanced version could
    use BackgroundTaskManager for concurrency.
    """

    def __init__(
        self,
        state: StateT | None = None,
        context_manager: BaseContextManager[StateT] | None = None,
        publisher: BasePublisher | None = None,
        id_generator: BaseIDGenerator = DefaultIDGenerator(),
        container: InjectQ | None = None,
    ):
        self._graph = StateGraph[StateT](
            state=state,
            context_manager=context_manager,
            publisher=publisher,
            id_generator=id_generator,
            container=container,
        )

    def compile(
        self,
        branches: dict[str, Callable | tuple[Callable, str]],
        join_node: Callable | tuple[Callable, str],
        next_branch_condition: Callable | None = None,
        checkpointer: BaseCheckpointer[StateT] | None = None,
        store: BaseStore | None = None,
        interrupt_before: list[str] | None = None,
        interrupt_after: list[str] | None = None,
        callback_manager: CallbackManager = CallbackManager(),
    ) -> CompiledGraph:
        if not branches:
            raise ValueError("branches must be a non-empty dict of name -> callable/tuple")

        # Add branch nodes
        branch_names = []
        for key, fn in branches.items():
            if isinstance(fn, tuple):
                branch_func, branch_name = fn
                if not callable(branch_func):
                    raise ValueError(f"Branch '{key}'[0] must be callable")
            else:
                branch_func = fn
                branch_name = key
                if not callable(branch_func):
                    raise ValueError(f"Branch '{key}' must be callable")
            self._graph.add_node(branch_name, branch_func)
            branch_names.append(branch_name)

        # Handle join_node
        if isinstance(join_node, tuple):
            join_func, join_name = join_node
            if not callable(join_func):
                raise ValueError("join_node[0] must be callable")
        else:
            join_func = join_node
            join_name = "JOIN"
            if not callable(join_func):
                raise ValueError("join_node must be callable")
        self._graph.add_node(join_name, join_func)

        # Wire branches to JOIN
        for name in branch_names:
            self._graph.add_edge(name, join_name)

        # Entry: first branch
        first = branch_names[0]
        self._graph.set_entry_point(first)

        # Decide next branch or END after join
        if next_branch_condition is None:
            # default: END after join
            def _cond(_: AgentState) -> str:
                return END

            next_branch_condition = _cond

        # next_branch_condition returns a branch name or END
        path_map = {k: k for k in branch_names}
        path_map[END] = END
        self._graph.add_conditional_edges(join_name, next_branch_condition, path_map)

        return self._graph.compile(
            checkpointer=checkpointer,
            store=store,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            callback_manager=callback_manager,
        )

Functions

__init__
__init__(state=None, context_manager=None, publisher=None, id_generator=DefaultIDGenerator(), container=None)
Source code in pyagenity/prebuilt/agent/branch_join.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def __init__(
    self,
    state: StateT | None = None,
    context_manager: BaseContextManager[StateT] | None = None,
    publisher: BasePublisher | None = None,
    id_generator: BaseIDGenerator | None = None,
    container: InjectQ | None = None,
):
    """Create the agent and its underlying ``StateGraph``.

    Args:
        state: Optional state forwarded to ``StateGraph``.
        context_manager: Optional context manager forwarded to ``StateGraph``.
        publisher: Optional publisher forwarded to ``StateGraph``.
        id_generator: ID generator forwarded to ``StateGraph``. When omitted,
            a fresh ``DefaultIDGenerator()`` is created for this instance.
        container: Optional ``InjectQ`` container forwarded to ``StateGraph``.
    """
    # A call used as a default argument is evaluated once at definition time
    # and shared by every instance; build the default per instance instead.
    if id_generator is None:
        id_generator = DefaultIDGenerator()
    self._graph = StateGraph[StateT](
        state=state,
        context_manager=context_manager,
        publisher=publisher,
        id_generator=id_generator,
        container=container,
    )
compile
compile(branches, join_node, next_branch_condition=None, checkpointer=None, store=None, interrupt_before=None, interrupt_after=None, callback_manager=CallbackManager())
Source code in pyagenity/prebuilt/agent/branch_join.py
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def compile(
    self,
    branches: dict[str, Callable | tuple[Callable, str]],
    join_node: Callable | tuple[Callable, str],
    next_branch_condition: Callable | None = None,
    checkpointer: BaseCheckpointer[StateT] | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callback_manager: CallbackManager | None = None,
) -> CompiledGraph:
    """Wire the branch and join nodes into the graph and compile it.

    Args:
        branches: Mapping of key -> branch spec. A spec is either a callable
            (registered under ``key``) or a ``(callable, node_name)`` tuple.
            Branches are added in dict order; the first one is the entry point.
        join_node: A callable (registered as ``"JOIN"``) or a
            ``(callable, node_name)`` tuple. Every branch gets an edge to it.
        next_branch_condition: Routing function evaluated after the join node.
            It should return one of the branch node names or ``END``. When
            omitted it always returns ``END``, so only the first branch runs.
        checkpointer: Forwarded to ``StateGraph.compile``.
        store: Forwarded to ``StateGraph.compile``.
        interrupt_before: Forwarded to ``StateGraph.compile``.
        interrupt_after: Forwarded to ``StateGraph.compile``.
        callback_manager: Forwarded to ``StateGraph.compile``. When omitted,
            a fresh ``CallbackManager()`` is created for this call.

    Returns:
        The ``CompiledGraph`` returned by the underlying ``StateGraph.compile``.

    Raises:
        ValueError: If ``branches`` is empty, a spec is not callable or is a
            tuple that is not ``(callable, name)``, two branches resolve to
            the same node name, or a branch name equals the join node name.
    """
    if not branches:
        raise ValueError("branches must be a non-empty dict of name -> callable/tuple")

    def _split(spec, default_name: str, label: str) -> tuple[Callable, str]:
        # Normalise a node spec into (func, node_name), validating its shape.
        if isinstance(spec, tuple):
            if len(spec) != 2:
                raise ValueError(
                    f"{label} must be a (callable, name) tuple, got {len(spec)} items"
                )
            func, name = spec
            if not callable(func):
                raise ValueError(f"{label}[0] must be callable")
        else:
            func, name = spec, default_name
            if not callable(func):
                raise ValueError(f"{label} must be callable")
        return func, name

    # Validate every argument before mutating the graph so that a bad
    # join_node (or name clash) does not leave the graph half-built.
    branch_specs = [
        _split(spec, key, f"Branch '{key}'") for key, spec in branches.items()
    ]
    join_func, join_name = _split(join_node, "JOIN", "join_node")

    branch_names = [name for _, name in branch_specs]
    seen: set[str] = set()
    for name in branch_names:
        if name == join_name:
            # Would overwrite the join node and create a JOIN -> JOIN edge.
            raise ValueError(
                f"Branch node name '{name}' collides with the join node name"
            )
        if name in seen:
            # Possible when tuple specs reuse a name or shadow another key.
            raise ValueError(f"Duplicate branch node name '{name}'")
        seen.add(name)

    # NOTE(review): every call adds nodes/edges to the same self._graph, so
    # calling compile() twice on one agent re-registers them — confirm intent.
    for func, name in branch_specs:
        self._graph.add_node(name, func)
    self._graph.add_node(join_name, join_func)

    # Wire every branch to the join node.
    for name in branch_names:
        self._graph.add_edge(name, join_name)

    # Entry: first branch (dict insertion order).
    self._graph.set_entry_point(branch_names[0])

    # Decide next branch or END after join.
    if next_branch_condition is None:
        # Default: END after join, i.e. only the entry branch executes.
        def _cond(_: AgentState) -> str:
            return END

        next_branch_condition = _cond

    # next_branch_condition returns a branch name or END.
    path_map = {name: name for name in branch_names}
    path_map[END] = END
    self._graph.add_conditional_edges(join_name, next_branch_condition, path_map)

    if callback_manager is None:
        callback_manager = CallbackManager()

    return self._graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        callback_manager=callback_manager,
    )