Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/giskard-checks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ scenario2 = (
)

# Create a suite with a shared target
target_sut = lambda x: f"Echo: {x}"
target_sut = lambda inputs: f"Echo: {inputs}"
suite = Suite(name="my_suite", target=target_sut)

# Add scenarios
Expand Down
35 changes: 10 additions & 25 deletions libs/giskard-checks/src/giskard/checks/core/interaction/interact.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,16 @@
from collections.abc import AsyncGenerator
from typing import Any, cast, override

from giskard.checks.utils.injectable import ValueGenerator, ValueProvider
from giskard.core.utils import NOT_PROVIDED, NotProvided
from pydantic import Field, PrivateAttr, model_validator

from ...utils.parameter_injection import ParameterInjectionRequirement
from ...utils.value_provider import (
ValueGeneratorProvider,
ValueProvider,
)
from ..input_generator import InputGenerator
from ..types import GeneratorType, ProviderType
from .base import InteractionSpec
from .interaction import Interaction
from .trace import Trace

INJECTABLE_TRACE = ParameterInjectionRequirement(
class_info=Trace,
optional=True,
)

INJECTABLE_INPUT = ParameterInjectionRequirement(
class_info=Any,
optional=True,
)


@InteractionSpec.register("interact")
class Interact[InputType, OutputType, TraceType: Trace]( # pyright: ignore[reportMissingTypeArgument]
Expand Down Expand Up @@ -154,26 +140,25 @@ async def input_generator(trace: Trace) -> AsyncGenerator[str, Trace]:
default_factory=dict, description="The metadata of the interaction."
)

_input_value_generator_provider: ValueGeneratorProvider[
[TraceType], InputType, TraceType
] = PrivateAttr()
_output_value_provider: ValueProvider[[InputType, TraceType], OutputType] = (
_input_value_generator_provider: ValueGenerator[..., InputType, TraceType] = (
PrivateAttr()
)
_output_injectable: ValueProvider[..., OutputType] = PrivateAttr()

def _validate_inputs(self) -> None:
try:
self._input_value_generator_provider = ValueGeneratorProvider.from_mapping(
self.inputs, INJECTABLE_TRACE
self._input_value_generator_provider = cast(
ValueGenerator[[TraceType], InputType, TraceType],
ValueGenerator(self.inputs, {"trace"}),
)
except ValueError as e:
raise ValueError(f"Error getting injection settings for inputs: {e}") from e

def _validate_outputs(self) -> None:
try:
if not isinstance(self.outputs, NotProvided):
self._output_value_provider = ValueProvider.from_mapping(
self.outputs, INJECTABLE_INPUT, INJECTABLE_TRACE
self._output_injectable = ValueProvider(
self.outputs, {"inputs", "trace"}
)
except ValueError as e:
raise ValueError(
Expand Down Expand Up @@ -207,7 +192,7 @@ def set_outputs(
async def generate(
self, trace: TraceType
) -> AsyncGenerator[Interaction[InputType, OutputType], TraceType]:
generator = await self._input_value_generator_provider(trace)
generator = await self._input_value_generator_provider(trace=trace)

try:
inputs = await anext(generator)
Expand All @@ -218,7 +203,7 @@ async def generate(
)
# Execute user-provided logic to transform inputs into either raw outputs
# or a fully constructed Interaction instance.
outputs = await self._output_value_provider(inputs, trace)
outputs = await self._output_injectable(inputs=inputs, trace=trace)
# Yield the interaction back to the caller and wait for an updated trace
# that captures the evaluation of this iteration.
trace = yield self._get_interaction(
Expand Down
89 changes: 89 additions & 0 deletions libs/giskard-checks/src/giskard/checks/utils/injectable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import inspect
from collections.abc import AsyncGenerator
from typing import cast

from giskard.checks.core.types import GeneratorType, ProviderType, SyncOrAsyncGenerator
from giskard.checks.utils.generator import a_generator


def _validate_kwargs_keys[R](
value_or_callable: ProviderType[..., R], kwargs_keys: set[str]
) -> tuple[list[str], set[str]]:
if not callable(value_or_callable):
return ([], set())

signature = inspect.signature(value_or_callable)
injected_positional_only_names: list[str] = []
injected_kwarg_names: set[str] = set()
for param in signature.parameters.values():
if param.kind in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
continue
if param.name in kwargs_keys:
if param.kind is inspect.Parameter.POSITIONAL_ONLY:
injected_positional_only_names.append(param.name)
else:
injected_kwarg_names.add(param.name)
else:
default = cast(object, param.default)
if default is inspect.Parameter.empty:
raise TypeError(
f"Parameter '{param.name}' is required but not in the injection requirements."
)

return injected_positional_only_names, injected_kwarg_names


class ValueProvider[**P, R]:
_value_or_callable: ProviderType[..., R]
_injected_positional_only_names: list[str]
_injected_kwarg_names: set[str]

def __init__(self, value_or_callable: ProviderType[..., R], kwargs_keys: set[str]):
self._value_or_callable = value_or_callable
(
self._injected_positional_only_names,
self._injected_kwarg_names,
) = _validate_kwargs_keys(value_or_callable, kwargs_keys)

async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
# This is a static value
if not callable(self._value_or_callable):
return self._value_or_callable

injected_positional_only_args = [
kwargs[name]
for name in self._injected_positional_only_names
if name in kwargs
]
injected_kwargs = {
key: kwargs[key] for key in self._injected_kwarg_names if key in kwargs
}

result = self._value_or_callable(
*injected_positional_only_args, **injected_kwargs
)

# Handle Awaitables
if inspect.isawaitable(result):
return await result

return cast(R, result)


class ValueGenerator[**P, R, S]:
_value_provider: ValueProvider[P, R | SyncOrAsyncGenerator[R, S]]

def __init__(
self, value_or_callable: GeneratorType[..., R, S], kwargs_keys: set[str]
):
self._value_provider = cast(
ValueProvider[P, R | SyncOrAsyncGenerator[R, S]],
ValueProvider(value_or_callable, kwargs_keys),
)

async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, S]:
value_or_generator = await self._value_provider(*args, **kwargs)
return a_generator(value_or_generator)
156 changes: 0 additions & 156 deletions libs/giskard-checks/src/giskard/checks/utils/parameter_injection.py

This file was deleted.

Loading