Skip to content
Merged
10 changes: 4 additions & 6 deletions libs/giskard-checks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ from giskard.checks import Scenario

chat = ChatInteraction(session_id="session_123", messages=["hi", "hello"])
check = AdvancedSecurityCheck(name="security_test", threshold=0.7)
scenario = Scenario(name="custom_test", sequence=[chat, check])
scenario = Scenario.from_sequence(chat, check, name="custom_test")

serialized = scenario.model_dump()
restored = Scenario.model_validate(serialized)
Expand Down Expand Up @@ -552,12 +552,10 @@ from giskard.checks import (
Equals # Inherits from `Check`
)

scenario = Scenario(
scenario = Scenario.from_sequence(
Interact(inputs="Hello", outputs=lambda inputs: "Hi"),
Equals(expected_value="Hi", key="trace.last.outputs"),
name="programmatic_scenario",
sequence=[
Interact(inputs="Hello", outputs=lambda inputs: "Hi"),
Equals(expected="Hi", key="trace.last.outputs"),
]
)

result = await scenario.run()
Expand Down
2 changes: 2 additions & 0 deletions libs/giskard-checks/src/giskard/checks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Metric,
Scenario,
ScenarioResult,
Step,
SuiteResult,
TestCase,
TestCaseResult,
Expand Down Expand Up @@ -73,6 +74,7 @@
"Metric",
"Scenario",
"ScenarioResult",
"Step",
"SuiteResult",
"TestCase",
"TestCaseResult",
Expand Down
3 changes: 2 additions & 1 deletion libs/giskard-checks/src/giskard/checks/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
SuiteResult,
TestCaseResult,
)
from .scenario import Scenario
from .scenario import Scenario, Step
from .testcase import TestCase

__all__ = [
"Scenario",
"Step",
"Trace",
"InteractionSpec",
"Interact",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,15 @@ async def input_generator(trace: Trace) -> AsyncGenerator[str, Trace]:
PrivateAttr()
)

@model_validator(mode="after")
def _validate_injection_mappings(
self,
) -> "Interact[InputType, OutputType, TraceType]":
def _validate_inputs(self) -> None:
try:
self._input_value_generator_provider = ValueGeneratorProvider.from_mapping(
self.inputs, INJECTABLE_TRACE
)
except ValueError as e:
raise ValueError(f"Error getting injection settings for inputs: {e}") from e

def _validate_outputs(self) -> None:
try:
if not isinstance(self.outputs, NotProvided):
self._output_value_provider = ValueProvider.from_mapping(
Expand All @@ -182,6 +180,27 @@ def _validate_injection_mappings(
f"Error getting injection settings for outputs: {e}"
) from e

@model_validator(mode="after")
def _validate_injection_mappings(
self,
) -> "Interact[InputType, OutputType, TraceType]":
self._validate_inputs()
self._validate_outputs()

return self

def set_outputs(
self,
outputs: (
ProviderType[[InputType], OutputType]
| ProviderType[[InputType, TraceType], OutputType]
| NotProvided
),
) -> "Interact[InputType, OutputType, TraceType]":
"""Update the outputs of the interact and recompute the injection mappings. Returns self for chaining."""
self.outputs = outputs
self._validate_outputs()

return self

@override
Expand Down
162 changes: 118 additions & 44 deletions libs/giskard-checks/src/giskard/checks/core/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,54 @@
from .types import GeneratorType, ProviderType


class Step[InputType, OutputType, TraceType: Trace](BaseModel): # pyright: ignore[reportMissingTypeArgument]
"""A scenario step: a sequence of interactions followed by checks.

Each step corresponds to one TestCase at runtime: interactions are applied
to the trace, then checks validate the resulting trace state.
"""

interacts: list[InteractionSpec[InputType, OutputType, TraceType]] = Field(
default_factory=list,
description="Interaction specs to execute in this step.",
)
checks: list[Check[InputType, OutputType, TraceType]] = Field(
default_factory=list,
description="Checks to run after the interactions.",
)


class Scenario[InputType, OutputType, TraceType: Trace](BaseModel): # pyright: ignore[reportMissingTypeArgument]
"""A scenario composed of an ordered sequence of components InteractionSpecs
or Checks with a shared trace.
"""A scenario composed of steps, each containing interacts and checks.

A scenario executes components sequentially, maintaining a shared trace that
A scenario executes steps sequentially, maintaining a shared trace that
accumulates all interactions. Execution stops immediately if any check fails.

Components are processed in order:
- **InteractionSpec** components: Add interactions to the trace
- **Check** components: Validate the current trace state
Each step groups:
- **Interacts** (InteractionSpec): Add interactions to the trace
- **Checks**: Validate the current trace state

Use the fluent API to build a scenario, then call ``run()``:

from giskard.checks import Scenario, Equals
scenario = (
Scenario("multi_step_test")
.interact("Hello", lambda inputs: "Hi")
.check(Equals(expected="Hi", key="trace.last.outputs"))
.check(Equals(expected_value="Hi", key="trace.last.outputs"))
)
result = await scenario.run()

For advanced usage you can instantiate with a pre-filled sequence:
For advanced usage you can instantiate with pre-filled steps or use
``from_sequence()`` for a flat list of components:

from giskard.checks import Scenario, Interact, Equals
scenario = Scenario(
name="multi_step_test",
sequence=[
Interact(inputs="Hello", outputs="Hi"),
Equals(expected="Hi", key="trace.last.outputs"),
steps=[
Step(
interacts=[Interact(inputs="Hello", outputs="Hi")],
checks=[Equals(expected_value="Hi", key="trace.last.outputs")],
),
],
)
result = await scenario.run()
Expand All @@ -47,26 +66,25 @@ class Scenario[InputType, OutputType, TraceType: Trace](BaseModel): # pyright:
----------
name : str
Scenario identifier.
sequence : list[InteractionSpec | Check]
Sequential steps to execute. Each component can be an InteractionSpec or
a Check (which validates the current trace).
steps : list[Step]
Steps to execute. Each step groups interacts and checks.
trace_type : type[TraceType] | None
Optional custom trace type to use. If not provided, the trace type will be
inferred from the sequence of components. Useful when using custom trace
subclasses with additional computed fields or methods.
inferred from the components. Useful when using custom trace subclasses
with additional computed fields or methods.
"""

name: str = Field(
default="Unnamed Scenario",
description="Scenario name",
)
sequence: list[
InteractionSpec[InputType, OutputType, TraceType]
| Check[InputType, OutputType, TraceType]
] = Field(default_factory=list, description="Sequential components to execute")
steps: list[Step[InputType, OutputType, TraceType]] = Field(
default_factory=list,
description="Steps to execute. Each step groups interacts and checks.",
)
trace_type: type[TraceType] | None = Field(
default=None,
description="Type of trace to use for the scenario. If not provided, the trace type will be inferred from the sequence of components.",
description="Type of trace to use for the scenario. If not provided, the trace type will be inferred from the components.",
)
annotations: dict[str, Any] = Field(
default_factory=dict,
Expand All @@ -92,6 +110,56 @@ def __init__(
kwargs["name"] = name
super().__init__(**kwargs)

@classmethod
def from_sequence(
cls,
*components: (
InteractionSpec[InputType, OutputType, TraceType]
| Check[InputType, OutputType, TraceType]
),
name: str = "Unnamed Scenario",
trace_type: type[TraceType] | None = None,
annotations: dict[str, Any] | None = None,
target: (
ProviderType[[InputType], OutputType]
| ProviderType[[InputType, TraceType], OutputType]
| NotProvided
) = NOT_PROVIDED,
) -> Self:
"""Create a scenario from a flat sequence of components.

Components are grouped into steps: a new step is created whenever an
InteractionSpec follows a Check.
"""
without_steps = cls(
name=name,
trace_type=trace_type,
annotations=annotations or {},
target=target,
)

return without_steps.extend(*components)

def _append_step(self) -> Step[InputType, OutputType, TraceType]:
"""Append a new step."""
step = Step(interacts=[], checks=[])
self.steps.append(step)
return step

def _last_step(self) -> Step[InputType, OutputType, TraceType]:
"""Return the last step. Create one if none exists."""
if not self.steps:
return self._append_step()

return self.steps[-1]

def _ensure_step_for_interactions(self) -> Step[InputType, OutputType, TraceType]:
"""Return the last step for appending interactions. Create a new step if the last step has checks or if no step exists."""
step = self._last_step()
if step.checks:
return self._append_step()
return step

def interact(
self,
inputs: (
Expand All @@ -106,11 +174,12 @@ def interact(
) = NOT_PROVIDED,
metadata: dict[str, object] | None = None,
) -> Self:
"""Add an interaction to the scenario sequence.
"""Add an interaction to the scenario.

Creates an `Interact` with the provided inputs and outputs and adds
it to the scenario sequence. Supports static values, callables, and generators
just like `Interact`.
it to the current step. Supports static values, callables, and generators
just like `Interact`. If the current step already has checks, a new step
is created first.

Parameters
----------
Expand All @@ -131,32 +200,35 @@ def interact(
outputs=outputs,
metadata=metadata or {},
)
self.sequence.append(interaction)
return self
return self.add_interaction(interaction)

def check(self, check: Check[InputType, OutputType, TraceType]) -> Self:
"""Add a check to the scenario sequence."""
self.sequence.append(check)
"""Add a check to the scenario."""
step = self._last_step()
step.checks.append(check)
return self

def checks(self, *checks: Check[InputType, OutputType, TraceType]) -> Self:
"""Add multiple checks to the scenario sequence."""
self.sequence.extend(checks)
"""Add multiple checks to the scenario."""
step = self._last_step()
step.checks.extend(checks)
return self

def add_interaction(
self,
interaction: InteractionSpec[InputType, OutputType, TraceType],
) -> Self:
"""Add a custom InteractionSpec to the scenario sequence."""
self.sequence.append(interaction)
"""Add a custom InteractionSpec to the scenario."""
step = self._ensure_step_for_interactions()
step.interacts.append(interaction)
return self

def add_interactions(
self, *interactions: InteractionSpec[InputType, OutputType, TraceType]
) -> Self:
"""Add multiple InteractionSpec objects to the scenario sequence."""
self.sequence.extend(interactions)
"""Add multiple InteractionSpec objects to the scenario."""
step = self._ensure_step_for_interactions()
step.interacts.extend(interactions)
return self

def append(
Expand All @@ -166,9 +238,10 @@ def append(
| Check[InputType, OutputType, TraceType]
),
) -> Self:
"""Append any component to the scenario sequence."""
self.sequence.append(component)
return self
"""Append any component to the scenario."""
if isinstance(component, Check):
return self.check(component)
return self.add_interaction(component)

def extend(
self,
Expand All @@ -177,8 +250,9 @@ def extend(
| Check[InputType, OutputType, TraceType]
),
) -> Self:
"""Extend the scenario sequence with multiple components of any type."""
self.sequence.extend(components)
"""Extend the scenario with multiple components of any type."""
for component in components:
self.append(component)
return self

def with_annotations(self, annotations: dict[str, Any]) -> Self:
Expand Down Expand Up @@ -210,16 +284,16 @@ async def run(
) = NOT_PROVIDED,
return_exception: bool = False,
) -> ScenarioResult[InputType, OutputType]:
"""Execute the scenario components sequentially with shared trace.
"""Execute the scenario steps sequentially with shared trace.

Each component is executed in order:
- Interaction components update the shared trace
- Check components validate the current trace and stop execution on failure
Each step is executed in order:
- Interaction specs update the shared trace
- Checks validate the current trace and stop execution on failure

Returns
-------
ScenarioResult
Results from executing the scenario components.
Results from executing the scenario.
"""
# Lazy import to avoid circular dependency
from ..scenarios.runner import get_runner
Expand Down
Loading