# Copyright AGNTCY Contributors (https://github.com/agntcy)
# SPDX-License-Identifier: Apache-2.0
import json
import os
import tempfile
from pathlib import Path
from typing import Optional, Union
import datamodel_code_generator
import yaml
from openapi_spec_validator import validate
from openapi_spec_validator.readers import read_from_filename
from ..agws_v0 import OASF_EXTENSION_NAME_MANIFEST, AgentManifest
from ..exceptions import ACPDescriptorValidationException
from ..models import (
AgentACPDescriptor,
AgentACPSpec,
StreamingMode,
)
def _convert_acp_spec_schema(schema_name, schema):
return json.loads(
json.dumps(schema).replace(
"#/$defs/", f"#/components/schemas/{schema_name}/$defs/"
)
)
def _gen_oas_thread_runs(acp_spec: AgentACPSpec, spec_dict):
# Manipulate the spec according to the thread capability flag in the acp_spec
if acp_spec.capabilities.threads:
if acp_spec.thread_state:
spec_dict["components"]["schemas"]["ThreadStateSchema"] = (
_convert_acp_spec_schema("ThreadStateSchema", acp_spec.thread_state)
)
# else:
# # No thread schema defined, hence no support to retrieve thread state
# del spec_dict['paths']['/threads/{thread_id}/state']
else:
# Threads are not enabled
if acp_spec.thread_state:
raise ACPDescriptorValidationException(
"Cannot define `thread_state` if `capabilities.threads` is `false`"
)
else:
# Remove all threads paths
spec_dict["tags"] = [
tag for tag in spec_dict["tags"] if tag["name"] != "Threads"
]
spec_dict["paths"] = {
k: v
for k, v in spec_dict["paths"].items()
if not k.startswith("/threads")
}
def _gen_oas_interrupts(acp_spec: AgentACPSpec, spec_dict):
# Manipulate the spec according to the interrupts capability flag in the acp_spec
if acp_spec.capabilities.interrupts:
if not acp_spec.interrupts or len(acp_spec.interrupts) == 0:
raise ACPDescriptorValidationException(
"Missing interrupt definitions with `spec.capabilities.interrupts=true`"
)
# Add the interrupt payload and resume payload types for the schemas declared in the acp_spec
interrupt_payload_schema = spec_dict["components"]["schemas"][
"InterruptPayloadSchema"
] = {
"oneOf": [],
}
resume_payload_schema = spec_dict["components"]["schemas"][
"ResumePayloadSchema"
] = {
"oneOf": [],
}
for interrupt in acp_spec.interrupts:
interrupt_payload_schema_name = (
f"{interrupt.interrupt_type}InterruptPayload"
)
interrupt_payload_schema["oneOf"].append(
{
"type": "object",
"properties": {
interrupt.interrupt_type: {
"$ref": f"#/components/schemas/{interrupt_payload_schema_name}",
},
},
"additionalProperties": False,
"required": [interrupt.interrupt_type],
}
)
spec_dict["components"]["schemas"][interrupt_payload_schema_name] = (
_convert_acp_spec_schema(
interrupt_payload_schema_name, interrupt.interrupt_payload
)
)
resume_payload_schema_name = f"{interrupt.interrupt_type}ResumePayload"
resume_payload_schema["oneOf"].append(
{
"type": "object",
"properties": {
interrupt.interrupt_type: {
"$ref": f"#/components/schemas/{resume_payload_schema_name}",
},
},
"additionalProperties": False,
"required": [interrupt.interrupt_type],
}
)
spec_dict["components"]["schemas"][resume_payload_schema_name] = (
_convert_acp_spec_schema(
resume_payload_schema_name, interrupt.resume_payload
)
)
else:
# Interrupts are not supported
if acp_spec.interrupts and len(acp_spec.interrupts) > 0:
raise ACPDescriptorValidationException(
"Interrupts defined with `spec.capabilities.interrupts=false`"
)
# Remove interrupt support from API
del spec_dict["paths"]["/runs/{run_id}"]["post"]
interrupt_ref = spec_dict["components"]["schemas"]["RunOutput"][
"discriminator"
]["mapping"]["interrupt"]
del spec_dict["components"]["schemas"]["RunOutput"]["discriminator"]["mapping"][
"interrupt"
]
spec_dict["components"]["schemas"]["RunOutput"]["oneOf"] = [
e
for e in spec_dict["components"]["schemas"]["RunOutput"]["oneOf"]
if e["$ref"] != interrupt_ref
]
def _gen_oas_streaming(acp_spec: AgentACPSpec, spec_dict):
# Manipulate the spec according to the streaming capability flag in the acp_spec
streaming_modes = []
if acp_spec.capabilities.streaming:
if acp_spec.capabilities.streaming.custom:
streaming_modes.append(StreamingMode.CUSTOM)
if acp_spec.capabilities.streaming.values:
streaming_modes.append(StreamingMode.VALUES)
# Perform the checks for custom_streaming_update
if (
StreamingMode.CUSTOM not in streaming_modes
and acp_spec.custom_streaming_update
):
raise ACPDescriptorValidationException(
"custom_streaming_update defined with `spec.capabilities.streaming.custom=false`"
)
if (
StreamingMode.CUSTOM in streaming_modes
and not acp_spec.custom_streaming_update
):
raise ACPDescriptorValidationException(
"Missing custom_streaming_update definitions with `spec.capabilities.streaming.custom=true`"
)
if len(streaming_modes) == 0:
# No streaming is supported. Removing streaming method.
del spec_dict["paths"]["/runs/{run_id}/stream"]
# Removing streaming option from RunCreate
del spec_dict["components"]["schemas"]["RunCreate"]["properties"]["stream_mode"]
return
if len(streaming_modes) == 2:
# Nothing to do
return
# If we reach this point only 1 streaming mode is supported, hence we need to restrict the APIs only to accept it and not the other.
assert len(streaming_modes) == 1
supported_mode = streaming_modes[0].value
spec_dict["components"]["schemas"]["StreamingMode"]["enum"] = [supported_mode]
spec_dict["components"]["schemas"]["RunOutputStream"]["properties"]["data"][
"$ref"
] = spec_dict["components"]["schemas"]["RunOutputStream"]["properties"]["data"][
"discriminator"
]["mapping"][supported_mode]
del spec_dict["components"]["schemas"]["RunOutputStream"]["properties"]["data"][
"oneOf"
]
del spec_dict["components"]["schemas"]["RunOutputStream"]["properties"]["data"][
"discriminator"
]
def _gen_oas_callback(acp_spec: AgentACPSpec, spec_dict):
# Manipulate the spec according to the callback capability flag in the acp_spec
if not acp_spec.capabilities.callbacks:
# No streaming is supported. Removing callback option from RunCreate
del spec_dict["components"]["schemas"]["RunCreate"]["properties"]["webhook"]
[docs]
def generate_agent_oapi_for_schemas(specs: AgentACPSpec):
spec_dict = {
"openapi": "3.1.0",
"info": {"title": "Agent Schemas", "version": "0.1.0"},
"components": {"schemas": {}},
}
spec_dict["components"]["schemas"]["InputSchema"] = _convert_acp_spec_schema(
"InputSchema", specs.input
)
spec_dict["components"]["schemas"]["OutputSchema"] = _convert_acp_spec_schema(
"OutputSchema", specs.output
)
spec_dict["components"]["schemas"]["ConfigSchema"] = _convert_acp_spec_schema(
"ConfigSchema", specs.config
)
validate(spec_dict)
return spec_dict
[docs]
def generate_agent_oapi(
agent_source: Union[AgentACPDescriptor, AgentManifest],
spec_path: Optional[str] = None,
):
if spec_path is None:
spec_path = os.getenv("ACP_SPEC_PATH", "acp-spec/openapi.json")
spec_dict, _ = read_from_filename(spec_path)
# If no exception is raised by validate(), the spec is valid.
validate(spec_dict)
agent_spec = None
if isinstance(agent_source, AgentACPDescriptor):
agent_spec = agent_source.specs
agent_name = agent_source.metadata.ref.name
elif isinstance(agent_source, AgentManifest):
for ext in agent_source.extensions:
if ext.name == OASF_EXTENSION_NAME_MANIFEST:
agent_spec = ext.data.acp
agent_name = agent_source.name
else:
raise ValueError("unknown object type for agent_source")
spec_dict["info"]["title"] = f"ACP Spec for {agent_name}"
spec_dict["components"]["schemas"]["InputSchema"] = _convert_acp_spec_schema(
"InputSchema", agent_spec.input
)
spec_dict["components"]["schemas"]["OutputSchema"] = _convert_acp_spec_schema(
"OutputSchema", agent_spec.output
)
spec_dict["components"]["schemas"]["ConfigSchema"] = _convert_acp_spec_schema(
"ConfigSchema", agent_spec.config
)
_gen_oas_thread_runs(agent_spec, spec_dict)
_gen_oas_interrupts(agent_spec, spec_dict)
_gen_oas_streaming(agent_spec, spec_dict)
_gen_oas_callback(agent_spec, spec_dict)
validate(spec_dict)
return spec_dict
[docs]
def generate_agent_models(
agent_source: Union[AgentACPDescriptor, AgentManifest],
path: str,
model_file_name: str = "models.py",
):
agent_spec = None
if isinstance(agent_source, AgentACPDescriptor):
agent_spec = generate_agent_oapi_for_schemas(agent_source.specs)
agent_name = agent_source.metadata.ref.name
elif isinstance(agent_source, AgentManifest):
for ext in agent_source.extensions:
if ext.name == OASF_EXTENSION_NAME_MANIFEST:
agent_spec = generate_agent_oapi_for_schemas(ext.data.acp)
agent_name = agent_source.name
else:
raise ValueError("unknown object type for agent_source")
if agent_spec is None:
raise ValueError("cannot find agent data")
agent_sdk_path = path
agent_models_dir = agent_sdk_path
tmp_dir = tempfile.TemporaryDirectory()
specpath = os.path.join(tmp_dir.name, "openapi.yaml")
modelspath = os.path.join(agent_models_dir, model_file_name)
os.makedirs(agent_models_dir, exist_ok=True)
with open(specpath, "w") as file:
yaml.dump(agent_spec, file, default_flow_style=False)
datamodel_code_generator.generate(
json.dumps(agent_spec),
input_filename=specpath,
input_file_type=datamodel_code_generator.InputFileType.OpenAPI,
output_model_type=datamodel_code_generator.DataModelType.PydanticV2BaseModel,
output=Path(modelspath),
disable_timestamp=True,
custom_file_header=f"# Generated from ACP Descriptor {agent_name} using datamodel_code_generator.",
keep_model_order=True,
use_double_quotes=True, # match ruff formatting
)