Source code for a_sync.primitives.locks.counter

"""
This module provides two specialized async flow management classes, CounterLock and CounterLockCluster.

These primitives manage :class:`asyncio.Task` objects that must wait for an internal counter to reach a specific value.
"""

import asyncio
from collections import defaultdict
from time import time
from typing import DefaultDict, Iterable, Optional

from a_sync.primitives._debug import _DebugDaemonMixin
from a_sync.primitives.locks.event import Event


class CounterLock(_DebugDaemonMixin):
    """
    An async primitive that blocks until the internal counter has reached a specific value.

    A coroutine can `await counter.wait_for(3)` and it will block until the internal counter >= 3.
    If some other task executes `counter.value = 5` or `counter.set(5)`, the first coroutine
    will unblock as 5 >= 3.

    The internal counter can only increase.
    """

    __slots__ = "is_ready", "_name", "_value", "_events"

    def __init__(self, start_value: int = 0, name: Optional[str] = None):
        """
        Initializes the CounterLock with a starting value and an optional name.

        Args:
            start_value: The initial value of the counter.
            name (optional): An optional name for the counter, used in debug logs.
        """
        self._name = name
        """An optional name for the counter, used in debug logs."""

        self._value = start_value
        """The current value of the counter."""

        self._events: DefaultDict[int, Event] = defaultdict(Event)
        """A defaultdict that maps each awaited value to an :class:`~a_sync.primitives.locks.event.Event` that manages the waiters for that value."""

        self.is_ready = lambda v: self._value >= v
        """A lambda function that indicates whether a given value has already been reached or surpassed."""

    async def wait_for(self, value: int) -> bool:
        """
        Waits until the counter reaches or exceeds the specified value.

        Args:
            value: The value to wait for.

        Returns:
            True when the counter reaches or exceeds the specified value.
        """
        if not self.is_ready(value):
            # Only spin up the debug daemon when a task actually has to block.
            self._ensure_debug_daemon()
            # Accessing a missing key lazily creates the Event for this threshold.
            await self._events[value].wait()
        return True

    def set(self, value: int) -> None:
        """
        Sets the counter to the specified value.

        Args:
            value: The value to set the counter to. Must be >= the current value.

        Raises:
            ValueError: If the new value is less than the current value.
        """
        self.value = value

    def __repr__(self) -> str:
        # Iterate existing keys only, so the defaultdict does not create new entries here.
        waiters = {v: len(self._events[v]._waiters) for v in sorted(self._events)}
        return f"<CounterLock name={self._name} value={self._value} waiters={waiters}>"

    @property
    def value(self) -> int:
        """
        Gets the current value of the counter.

        Returns:
            The current value of the counter.
        """
        return self._value

    @value.setter
    def value(self, value: int) -> None:
        """
        Sets the counter to a new value, waking up any waiters if the value increases
        to or beyond the value they are awaiting.

        Args:
            value: The new value of the counter.

        Raises:
            ValueError: If the new value is less than the current value.
        """
        if value > self._value:
            self._value = value
            # Snapshot the keys first: popping while iterating the dict directly would fail.
            ready = [
                self._events.pop(key)
                for key in list(self._events.keys())
                if key <= self._value
            ]
            for event in ready:
                event.set()
        elif value < self._value:
            raise ValueError("You cannot decrease the value.")

    async def _debug_daemon(self) -> None:
        """
        Periodically logs debug information about the counter state and waiters.

        Logs once immediately, then every 300 seconds, for as long as any waiter
        Event is still registered. Each message reports the elapsed wall-clock
        time in minutes since the daemon started.
        """
        start = time()
        while self._events:
            # Subtract first, then convert seconds to minutes.
            elapsed_minutes = round((time() - start) / 60, 2)
            self.logger.debug("%s is still locked after %sm", self, elapsed_minutes)
            await asyncio.sleep(300)
class CounterLockCluster:
    """
    An asyncio primitive that groups 2 or more CounterLock objects.

    `wait_for(i)` blocks until every member CounterLock has a value >= i.
    """

    __slots__ = ("locks",)

    def __init__(self, counter_locks: Iterable[CounterLock]) -> None:
        """
        Builds the cluster from any iterable of CounterLock objects.

        Args:
            counter_locks: The CounterLock objects to manage. The iterable is
                consumed eagerly and stored as a list.
        """
        self.locks = [*counter_locks]

    async def wait_for(self, value: int) -> bool:
        """
        Blocks until each CounterLock in the cluster has reached `value` or beyond.

        Args:
            value: The threshold every member lock must reach.

        Returns:
            True once all member locks have reached or exceeded `value`.
        """
        # All per-lock coroutines are created when the generator is unpacked
        # into gather(), then awaited concurrently.
        pending = (member.wait_for(value) for member in self.locks)
        await asyncio.gather(*pending)
        return True