Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 31 additions & 34 deletions libs/giskard-core/src/giskard/core/rate_limiter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import operator
import os
import threading
import uuid
import warnings
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -35,11 +34,9 @@ class RateLimiterRegistry:
throttling is consistent across serialization boundaries.
"""

_lock: threading.Lock
_instances: dict[str, WeakSet["BaseRateLimiter[Any]"]]

def __init__(self):
self._lock = threading.Lock()
self._instances = {}

def register_instance(self, rate_limiter: "BaseRateLimiter") -> None:
Expand All @@ -53,35 +50,36 @@ def register_instance(self, rate_limiter: "BaseRateLimiter") -> None:
rate_limiter : BaseRateLimiter
The rate limiter instance to register.
"""
with self._lock:
instances = self._instances.get(rate_limiter.id)
if instances is None:
instances = WeakSet["BaseRateLimiter"]()
self._instances[rate_limiter.id] = instances

all_instances = list(instances)
matching_instances = [
instance for instance in all_instances if instance == rate_limiter
]

match = matching_instances[0] if matching_instances else None
rate_limiter.initialize_state(match)

if (
not GISKARD_DISABLE_DUPLICATE_RATE_LIMITERS_WARNINGS
and not match
and all_instances
):
warnings.warn(
instances = self._instances.get(rate_limiter.id)
if instances is None:
instances = WeakSet["BaseRateLimiter"]()
self._instances[rate_limiter.id] = instances

all_instances = list(instances)
matching_instances = [
instance for instance in all_instances if instance == rate_limiter
]

match = matching_instances[0] if matching_instances else None
rate_limiter.initialize_state(match)

if not match and all_instances:
if not GISKARD_DISABLE_DUPLICATE_RATE_LIMITERS_WARNINGS:
raise ValueError(
(
f"Rate limiter with id '{rate_limiter.id}' already registered, "
f"this will make BaseRateLimiter.from_id('{rate_limiter.id}') unreliable. "
"Set GISKARD_DISABLE_DUPLICATE_RATE_LIMITERS_WARNINGS=1 to disable this warning"
),
RuntimeWarning,
f"Rate limiter with id '{rate_limiter.id}' already registered. "
"Set GISKARD_DISABLE_DUPLICATE_RATE_LIMITERS_WARNINGS=1 to downgrade this error to a warning"
)
)
warnings.warn(
(
f"Rate limiter with id '{rate_limiter.id}' already registered, "
f"this will make BaseRateLimiter.from_id('{rate_limiter.id}') unreliable. "
),
RuntimeWarning,
)

instances.add(rate_limiter)
instances.add(rate_limiter)

def get_instance(self, id: str) -> "BaseRateLimiter[Any]":
"""Retrieve a registered rate limiter by id.
Expand All @@ -101,12 +99,11 @@ def get_instance(self, id: str) -> "BaseRateLimiter[Any]":
ValueError
If no rate limiter with the given id is registered.
"""
with self._lock:
instances = self._instances.get(id)
if not instances:
raise ValueError(f"Rate limiter with id '{id}' not found")
instances = self._instances.get(id)
if not instances:
raise ValueError(f"Rate limiter with id '{id}' not found")

return next(iter(instances))
return next(iter(instances))


@discriminated_base
Expand Down
15 changes: 6 additions & 9 deletions libs/giskard-core/src/giskard/core/rate_limiter/min_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,15 @@


class _MinIntervalRateLimiterState:
"""Internal state for MinIntervalRateLimiter: semaphore, lock, and next allowed request time."""
"""Internal state for MinIntervalRateLimiter: semaphore and next allowed request time."""

semaphore: asyncio.Semaphore | None
lock: asyncio.Lock
next_request_time: float

def __init__(self, max_concurrent: int | None):
self.semaphore = (
asyncio.Semaphore(max_concurrent) if max_concurrent is not None else None
)
self.lock = asyncio.Lock()
self.next_request_time = time.monotonic()


Expand All @@ -45,12 +43,11 @@ async def throttle(self) -> AsyncGenerator[float]:
"""Wait for rate limit, then yields the time waited in seconds."""
start_time = time.monotonic()
async with self._state.semaphore or nullcontext():
async with self._state.lock:
current_time = time.monotonic()
wait_time = self._state.next_request_time - current_time
self._state.next_request_time = (
max(self._state.next_request_time, current_time) + self.min_interval
)
current_time = time.monotonic()
wait_time = self._state.next_request_time - current_time
self._state.next_request_time = (
max(self._state.next_request_time, current_time) + self.min_interval
)

if wait_time > 0:
await asyncio.sleep(wait_time)
Expand Down
47 changes: 18 additions & 29 deletions libs/giskard-core/tests/test_rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pytest
from giskard.core import BaseRateLimiter, MinIntervalRateLimiter
from pydantic import ValidationError

JITTER_TIME = 0.02 # 20ms jitter

Expand All @@ -27,34 +28,23 @@ async def throttle(self) -> AsyncGenerator[float]:


class TestRateLimiterRegistry:
def test_warns_when_creating_rate_limiter_with_duplicate_id(self):
def test_raises_when_creating_rate_limiter_with_duplicate_id(self):
rl_id = _uid()
with pytest.warns(
RuntimeWarning,
match=f"Rate limiter with id '{rl_id}' already registered",
):
_rate_limiter_a = MinIntervalRateLimiter.from_rpm(60, id=rl_id)
_rate_limiter_b = MinIntervalRateLimiter.from_rpm(120, id=rl_id)
_rate_limiter_a = MinIntervalRateLimiter.from_rpm(60, id=rl_id)
with pytest.raises(ValidationError, match="already registered"):
MinIntervalRateLimiter.from_rpm(120, id=rl_id)

rl_id = _uid()
with pytest.warns(
RuntimeWarning,
match=f"Rate limiter with id '{rl_id}' already registered",
):
_rate_limiter_a = MinIntervalRateLimiter.from_rpm(rpm=60, id=rl_id)
_rate_limiter_b = MinIntervalRateLimiter.from_rpm(
rpm=60, max_concurrent=1, id=rl_id
)
_rate_limiter_a = MinIntervalRateLimiter.from_rpm(rpm=60, id=rl_id)
with pytest.raises(ValidationError, match="already registered"):
MinIntervalRateLimiter.from_rpm(rpm=60, max_concurrent=1, id=rl_id)

rl_id = _uid()
with pytest.warns(
RuntimeWarning,
match=f"Rate limiter with id '{rl_id}' already registered",
):
_rate_limiter_a = MinIntervalRateLimiter.from_rpm(rpm=60, id=rl_id)
_rate_limiter_b = CustomRateLimiter(id=rl_id)
_rate_limiter_a = MinIntervalRateLimiter.from_rpm(rpm=60, id=rl_id)
with pytest.raises(ValidationError, match="already registered"):
CustomRateLimiter(id=rl_id)

def test_does_not_warn_when_disabled_and_creating_rate_limiter_with_duplicate_id(
def test_does_not_raise_when_disabled_and_creating_rate_limiter_with_duplicate_id(
self,
):
with patch(
Expand All @@ -66,23 +56,22 @@ def test_does_not_warn_when_disabled_and_creating_rate_limiter_with_duplicate_id
rl_id = _uid()
_rate_limiter_a = MinIntervalRateLimiter.from_rpm(60, id=rl_id)
_rate_limiter_b = MinIntervalRateLimiter.from_rpm(120, id=rl_id)
assert not any("already registered" in str(w.message) for w in record)

with warnings.catch_warnings(record=True) as record:
warnings.simplefilter("always")
rl_id = _uid()
_rate_limiter_a = MinIntervalRateLimiter.from_rpm(rpm=60, id=rl_id)
_rate_limiter_b = MinIntervalRateLimiter.from_rpm(
rpm=60, max_concurrent=1, id=rl_id
)
assert not any("already registered" in str(w.message) for w in record)

with warnings.catch_warnings(record=True) as record:
warnings.simplefilter("always")
rl_id = _uid()
_rate_limiter_a = MinIntervalRateLimiter.from_rpm(rpm=60, id=rl_id)
_rate_limiter_b = CustomRateLimiter(id=rl_id)
assert not any("already registered" in str(w.message) for w in record)

duplicate_warnings = [
w for w in record if "already registered" in str(w.message)
]
assert len(duplicate_warnings) == 3
assert all(w.category is RuntimeWarning for w in duplicate_warnings)

def test_same_rate_limiter_with_same_id_should_not_raise_error(
self,
Expand Down