Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions libs/giskard-agents/src/giskard/agents/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Callable, Literal, TypeVar

import logfire_api as logfire
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel, Field, PrivateAttr, TypeAdapter, create_model

from ..context import RunContext
from ..errors.serializable import Error
Expand Down Expand Up @@ -38,6 +38,8 @@ class Tool(BaseModel):
catch: Callable[[Exception], Any] | None = Field(default=None)

run_context_param: str | None = Field(default=None)
_params_model: type[BaseModel] | None = PrivateAttr(default=None)
_return_adapter: TypeAdapter[Any] | None = PrivateAttr(default=None)

@classmethod
def from_callable(
Expand Down Expand Up @@ -96,14 +98,20 @@ def from_callable(
**fields,
)

return cls(
return_annotation = sig.return_annotation

tool_instance = cls(
name=fn.__name__,
description=description,
parameters_schema=model.model_json_schema(),
fn=fn,
run_context_param=run_context_param,
catch=catch,
)
tool_instance._params_model = model
if return_annotation is not inspect.Parameter.empty:
tool_instance._return_adapter = TypeAdapter(return_annotation)
return tool_instance

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Call the underlying function without modification.
Expand Down Expand Up @@ -150,7 +158,19 @@ async def run(
The result of calling the function.
"""

# Inject the context if the tool expects it
# Coerce dict arguments into typed objects via the Pydantic params model.
# We use getattr() instead of model_dump() to preserve coerced types
# (e.g. a raw dict becomes a BaseModel instance). Extra keys that are
# not in model_fields are dropped (Pydantic defaults to extra='ignore').
if self._params_model is not None:
validated = self._params_model.model_validate(arguments)
arguments = {
name: getattr(validated, name)
for name in arguments
if name in self._params_model.model_fields
}

# Inject the context after coercion (RunContext is excluded from the model)
if ctx and self.run_context_param:
arguments = arguments.copy()
arguments[self.run_context_param] = ctx
Expand All @@ -169,8 +189,8 @@ async def run(
logfire.error("tool.run.error", error=res)
return str(res)

if isinstance(res, BaseModel):
res = res.model_dump()
if self._return_adapter is not None:
res = self._return_adapter.dump_python(res, mode="json")

return res

Expand Down
174 changes: 174 additions & 0 deletions libs/giskard-agents/tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Tests for the tools module."""

from datetime import datetime, timezone
from typing import List
from uuid import UUID

import pytest
from giskard import agents
from giskard.agents.tools import Tool, tool
from pydantic import BaseModel


def test_tool_decorator():
Expand Down Expand Up @@ -182,3 +185,174 @@ def get_weather(self, city: str) -> str:
# The original function behavior is not modified
with pytest.raises(ValueError):
weather.get_weather("Paris")


# ---------------------------------------------------------------------------
# GAP-005: Input coercion and output serialization
# ---------------------------------------------------------------------------


class Address(BaseModel):
street: str
city: str


class Person(BaseModel):
name: str
address: Address


class TimestampedRecord(BaseModel):
id: UUID
created_at: datetime
label: str


async def test_tool_run_coerces_basemodel_input():
"""Tool.run() should coerce a dict into a BaseModel instance."""
received = {}

@tool
def process_person(person: Person) -> str:
"""Process a person.

Parameters
----------
person : Person
The person to process.
"""
received["person"] = person
return person.name

result = await process_person.run(
{"person": {"name": "Alice", "address": {"street": "123 Main", "city": "NYC"}}}
)
assert result == "Alice"
assert isinstance(received["person"], Person)
assert isinstance(received["person"].address, Address)


async def test_tool_run_coerces_optional_basemodel_input():
"""Tool.run() should coerce a dict into BaseModel | None."""
received = {}

@tool
def process_optional(person: Person | None = None) -> str:
"""Process an optional person.

Parameters
----------
person : Person | None
The person to process.
"""
received["person"] = person
return person.name if person else "nobody"

result = await process_optional.run(
{"person": {"name": "Bob", "address": {"street": "1 Elm", "city": "LA"}}}
)
assert result == "Bob"
assert isinstance(received["person"], Person)

result = await process_optional.run({"person": None})
assert result == "nobody"
assert received["person"] is None


async def test_tool_run_coerces_list_basemodel_input():
"""Tool.run() should coerce a list of dicts into list[BaseModel]."""
received = {}

@tool
def process_people(people: list[Person]) -> int:
"""Process a list of people.

Parameters
----------
people : list[Person]
People to process.
"""
received["people"] = people
return len(people)

result = await process_people.run(
{
"people": [
{"name": "A", "address": {"street": "1", "city": "X"}},
{"name": "B", "address": {"street": "2", "city": "Y"}},
]
}
)
assert result == 2
assert all(isinstance(p, Person) for p in received["people"])


async def test_tool_run_serializes_basemodel_output_json_safe():
"""Tool.run() should produce JSON-safe output for BaseModel returns."""
ts = datetime(2025, 6, 15, 12, 0, 0, tzinfo=timezone.utc)
uid = UUID("12345678-1234-5678-1234-567812345678")

@tool
def create_record(label: str) -> TimestampedRecord:
"""Create a record.

Parameters
----------
label : str
The label.
"""
return TimestampedRecord(id=uid, created_at=ts, label=label)

result = await create_record.run({"label": "test"})
assert isinstance(result, dict)
assert isinstance(result["id"], str)
assert isinstance(result["created_at"], str)
assert result["label"] == "test"


async def test_tool_run_serializes_list_basemodel_output():
"""Tool.run() should serialize list[BaseModel] to list of JSON-safe dicts."""

@tool
def list_addresses(n: int) -> list[Address]:
"""List addresses.

Parameters
----------
n : int
Count.
"""
return [Address(street=f"St {i}", city=f"City {i}") for i in range(n)]

result = await list_addresses.run({"n": 3})
assert isinstance(result, list)
assert len(result) == 3
assert all(isinstance(item, dict) for item in result)
assert result[0] == {"street": "St 0", "city": "City 0"}


@pytest.mark.parametrize(
"args,expected",
[
pytest.param({"query": "hello", "limit": 5}, "hello:5", id="str-int"),
pytest.param({"query": "x", "limit": 0}, "x:0", id="str-zero"),
],
)
async def test_tool_run_primitive_types_unchanged(args, expected):
"""Primitive types should pass through coercion and serialization unchanged."""

@tool
def search(query: str, limit: int) -> str:
"""Search.

Parameters
----------
query : str
Query text.
limit : int
Max results.
"""
return f"{query}:{limit}"

result = await search.run(args)
assert result == expected