Source code for agntcy_acp.langgraph.acp_node

# Copyright AGNTCY Contributors (https://github.com/agntcy)
# SPDX-License-Identifier: Apache-2.0
import logging
from collections.abc import MutableMapping
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamMode as LangGraphStreamMode
from langgraph.types import (
    interrupt,
)
from langgraph.utils.runnable import RunnableCallable
from pydantic import BaseModel

from agntcy_acp import (
    ACPClient,
    ApiClientConfiguration,
    AsyncACPClient,
)
from agntcy_acp.exceptions import ACPRunException
from agntcy_acp.models import (
    Config,
    RunCreateStateful,
    RunCreateStateless,
    RunError,
    RunInterrupt,
    RunOutput,
    RunResult,
    RunWaitResponseStateful,
    RunWaitResponseStateless,
    StreamingMode,
)

logger = logging.getLogger(__name__)


def _extract_element(container: Any, path: str) -> Any:
    element = container
    for path_el in path.split("."):
        element = (
            element.get(path_el)
            if isinstance(element, MutableMapping)
            else getattr(element, path_el)
        )

    if element is None:
        raise Exception(f"Unable to extract {path} from state {container}")

    return element


[docs] class ACPNode: """This class represents a Langgraph Node that holds a remote connection to an ACP Agent It can be instantiated and added to any langgraph graph. my_node = ACPNode(...) sg = StateGraph(GraphState) sg.add_node(my_node) """ def __init__( self, name: str, agent_id: str, client_config: ApiClientConfiguration, input_path: str, input_type, output_path: str, output_type, config_path: Optional[str] = None, config_type=None, auth_header: Optional[Dict] = None, use_threads: bool = False, ): """Instantiate a Langgraph node encapsulating a remote ACP agent :param name: Name of the langgraph node :param agent_id: Agent ID in the remote server :param client_config: Configuration of the ACP Client :param input_path: Dot-separated path of the ACP Agent input in the graph overall state :param input_type: Pydantic class defining the schema of the ACP Agent input :param output_path: Dot-separated path of the ACP Agent output in the graph overall state :param output_type: Pydantic class defining the schema of the ACP Agent output :param config_path: Dot-separated path of the ACP Agent config in the graph configurable :param config_type: Pydantic class defining the schema of the ACP Agent config :param auth_header: A dictionary containing auth details necessary to communicate with the node """ self.__name__ = name self.agent_id = agent_id self.clientConfig = client_config self.inputPath = input_path self.inputType = input_type self.outputPath = output_path self.outputType = output_type self.configPath = config_path self.configType = config_type self.auth_header = auth_header self.use_threads = use_threads self.previous_thread_run = {}
[docs] def get_name(self): return self.__name__
def _extract_input(self, state: Any) -> Any: if not state: return state try: if self.inputPath: state = _extract_element(state, self.inputPath) except Exception as e: raise Exception( f"ERROR in ACP Node {self.get_name()}. Unable to extract input: {e}" ) if isinstance(state, BaseModel): return state.model_dump() elif isinstance(state, MutableMapping): return state else: return {} def _extract_config(self, config: Optional[RunnableConfig]) -> Any: if not config: return config try: if not self.configPath: config = {} else: if "configurable" not in config: logger.error( f'ACP Node {self.get_name()}. Unable to extract config: missing key "configurable" in RunnableConfig' ) return None config = _extract_element(config["configurable"], self.configPath) except Exception as e: logger.info(f"ACP Node {self.get_name()}. Unable to extract config: {e}") return None if self.configType is not None: # Set defaults, etc. agent_config = self.configType.model_validate(config) else: agent_config = config if isinstance(agent_config, BaseModel): return agent_config.model_dump() elif isinstance(agent_config, MutableMapping): return agent_config else: return {} def _set_output(self, state: Any, output: Optional[Dict[str, Any]]): output_parent = state output_state = self.outputType.model_validate(output) for el in self.outputPath.split(".")[:-1]: if isinstance(output_parent, MutableMapping): output_parent = output_parent[el] elif hasattr(output_parent, el): output_parent = getattr(output_parent, el) else: raise ValueError("object missing attribute: {el}") el = self.outputPath.split(".")[-1] if isinstance(output_parent, MutableMapping): output_parent[el] = output_state elif hasattr(output_parent, el): setattr(output_parent, el, output_state) else: raise ValueError("object missing attribute: {el}") def _handle_run_output(self, state: Any, run_output: RunOutput): if isinstance(run_output.actual_instance, RunResult): run_result: RunResult = run_output.actual_instance self._set_output(state, run_result.values) elif isinstance(run_output.actual_instance, RunError): run_error: RunError = run_output.actual_instance raise ACPRunException(f"Run Failed: {run_error}") elif isinstance(run_output.actual_instance, RunInterrupt): # This case is handled in the invokes pass else: raise ACPRunException( f"ACP Server returned a unsupporteed response: {run_output}" ) return state def _generate_thread_id(self) -> str: return "abc" def _get_thread_id(self, config: RunnableConfig) -> Optional[str]: configurable = config.get("configurable", {}) thread_id = configurable.get("thread_id", None) return thread_id def _handle_interrupt( self, run_response: Union[RunWaitResponseStateless, RunWaitResponseStateful], thread_id: Optional[str], ): assert run_response.output is not None, "No output found for interrupt" if thread_id is None: raise ValueError("Thread_id is required for a interrupted run") first_key = next(iter(run_response.output.actual_instance.interrupt)) value = run_response.output.actual_instance.interrupt[first_key] previous_thread_run = self.previous_thread_run.get(thread_id, None) if previous_thread_run is None: self.previous_thread_run[thread_id] = run_response interrupt_result = interrupt(value) del self.previous_thread_run[thread_id] return interrupt_result
[docs] def invoke( self, state: dict[str, Any] | Any, config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: with ACPClient(configuration=self.clientConfig) as acp_client: thread_id = self._get_thread_id(config) if thread_id is not None and thread_id in self.previous_thread_run: run_response = self.previous_thread_run[thread_id] elif not self.use_threads: run_create = RunCreateStateless( agent_id=self.agent_id, input=self._extract_input(state), config=Config(configurable=self._extract_config(config)), ) run_response = acp_client.create_and_wait_for_stateless_run_output( run_create ) else: if thread_id is None: thread_id = self._generate_thread_id() run_create = RunCreateStateful( agent_id=self.agent_id, input=self._extract_input(state), config=Config(configurable=self._extract_config(config)), ) run_response = acp_client.create_and_wait_for_thread_run_output( thread_id, run_create ) run_output = run_response.output if run_output is None: # The spec does not require an output so count it as success. return None while isinstance(run_output.actual_instance, RunInterrupt): interrupt_result = self._handle_interrupt( run_response=run_response, thread_id=thread_id ) if not self.use_threads: resume_run = acp_client.resume_stateless_run( run_id=run_response.run.run_id, body=interrupt_result ) run_response = acp_client.wait_for_stateless_run_output( run_id=resume_run.run_id ) else: resume_run = acp_client.resume_thread_run( thread_id=thread_id, run_id=run_response.run.run_id, body=interrupt_result, ) run_response = acp_client.wait_for_thread_run_output( thread_id=thread_id, run_id=resume_run.run_id, ) run_output = run_response.output if run_output is None: # The spec does not require an output so count it as success. return None # output is the same between stateful and stateless self._handle_run_output(state, run_output) return state
[docs] async def ainvoke( self, state: dict[str, Any] | Any, config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: async with AsyncACPClient(configuration=self.clientConfig) as acp_client: thread_id = self._get_thread_id(config) if thread_id is not None and thread_id in self.previous_thread_run: run_response = self.previous_thread_run[thread_id] elif not self.use_threads: run_create = RunCreateStateless( agent_id=self.agent_id, input=self._extract_input(state), config=Config(configurable=self._extract_config(config)), ) run_response = ( await acp_client.create_and_wait_for_stateless_run_output( run_create ) ) else: if thread_id is None: thread_id = self._generate_thread_id() run_create = RunCreateStateful( agent_id=self.agent_id, input=self._extract_input(state), config=Config(configurable=self._extract_config(config)), ) run_response = await acp_client.create_and_wait_for_thread_run_output( thread_id, run_create ) run_output = run_response.output if run_output is None: # The spec does not require an output so count it as success. return None while isinstance(run_output.actual_instance, RunInterrupt): interrupt_result = self._handle_interrupt( run_response=run_response, thread_id=thread_id ) if not self.use_threads: resume_run = await acp_client.resume_stateless_run( run_id=run_response.run.run_id, body=interrupt_result, ) run_response = await acp_client.wait_for_stateless_run_output( run_id=resume_run.run_id, ) else: resume_run = await acp_client.resume_thread_run( thread_id=thread_id, run_id=run_response.run.run_id, body=interrupt_result, ) run_response = await acp_client.wait_for_thread_run_output( thread_id=thread_id, run_id=resume_run.run_id, ) run_output = run_response.output if run_output is None: # The spec does not require an output so count it as success. return None state = self._handle_run_output(state, run_output)
def _convert_stream_mode( self, stream_mode: Union[LangGraphStreamMode, List[LangGraphStreamMode], None], ) -> Union[StreamingMode, List[StreamingMode], None]: valid_acp_modes = [ member.value for name, member in StreamingMode.__members__.items() ] if stream_mode is None: return None elif isinstance(stream_mode, List): new_modes = [elem for elem in stream_mode if elem in valid_acp_modes] if not new_modes and stream_mode: raise ValueError("unsupported stream modes: {stream_mode}") return new_modes elif stream_mode in valid_acp_modes: return StreamingMode(stream_mode) else: raise ValueError(f"unsupported stream mode: {stream_mode}")
[docs] def stream( self, input: Union[Dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: Union[LangGraphStreamMode, List[LangGraphStreamMode], None] = None, **kwargs, ) -> Iterator[Dict[str, Any] | Any]: """Stream graph steps for a single input.""" # TODO: add support for interrupts thread_id = self._get_thread_id(config) if self.use_threads: if thread_id is None: thread_id = self._generate_thread_id() run_create = RunCreateStateful( agent_id=self.agent_id, input=self._extract_input(input), config=Config(configurable=self._extract_config(config)), stream_mode=self._convert_stream_mode(stream_mode), ) with ACPClient(configuration=self.clientConfig) as acp_client: for run_output in acp_client.create_and_stream_thread_run_output( thread_id, run_create ): self._handle_run_output(input, run_output.output) yield input else: run_create = RunCreateStateless( agent_id=self.agent_id, input=self._extract_input(input), config=Config(configurable=self._extract_config(config)), stream_mode=self._convert_stream_mode(stream_mode), ) with ACPClient(configuration=self.clientConfig) as acp_client: for run_output in acp_client.create_and_stream_stateless_run_output( run_create ): self._handle_run_output(input, run_output.output) yield input
[docs] async def astream( self, input: Union[Dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: Union[LangGraphStreamMode, List[LangGraphStreamMode], None] = None, **kwargs, ) -> AsyncIterator[Dict[str, Any] | Any]: """Stream graph steps for a single input.""" # TODO: add support for interrupts thread_id = self._get_thread_id(config) if self.use_threads: if thread_id is None: thread_id = self._generate_thread_id() run_create = RunCreateStateful( agent_id=self.agent_id, input=self._extract_input(input), config=Config(configurable=self._extract_config(config)), stream_mode=self._convert_stream_mode(stream_mode), ) async with AsyncACPClient(configuration=self.clientConfig) as acp_client: async for run_output in acp_client.create_and_stream_thread_run_output( thread_id, run_create ): self._handle_run_output(input, run_output.output) yield input else: run_create = RunCreateStateless( agent_id=self.agent_id, input=self._extract_input(input), config=Config(configurable=self._extract_config(config)), stream_mode=self._convert_stream_mode(stream_mode), ) async with AsyncACPClient(configuration=self.clientConfig) as acp_client: async for ( run_output ) in acp_client.create_and_stream_stateless_run_output(run_create): self._handle_run_output(input, run_output.output) yield input
def __call__(self, state, config): return RunnableCallable(self.invoke, self.ainvoke)