Source code for a_sync.primitives.locks.counter

"""
This module provides two specialized async flow management classes, :class:`CounterLock` and :class:`CounterLockCluster`.

These primitives manage synchronization of tasks that must wait for an internal counter to reach a specific value.
"""

import asyncio
from collections import defaultdict
from time import time
from typing import DefaultDict, Iterable, Optional

from a_sync.primitives._debug import _DebugDaemonMixin
from a_sync.primitives.locks.event import Event


class CounterLock(_DebugDaemonMixin):
    """
    An async primitive that uses an internal counter to manage task synchronization.

    A coroutine can `await counter.wait_for(3)` and it will wait until the internal counter >= 3.
    If some other task executes `counter.value = 5` or `counter.set(5)`, the first coroutine
    will proceed as 5 >= 3.

    The internal counter can never be decreased. Assigning a larger value wakes every waiter
    whose target has been reached. Assigning the current value again is a silent no-op.
    Assigning a smaller value raises :class:`ValueError`.

    See Also:
        :class:`CounterLockCluster` for managing multiple :class:`CounterLock` instances.
    """

    __slots__ = "is_ready", "_name", "_value", "_events"

    def __init__(self, start_value: int = 0, name: Optional[str] = None):
        """
        Initializes the :class:`CounterLock` with a starting value and an optional name.

        Args:
            start_value: The initial value of the counter.
            name: An optional name for the counter, used in debug logs.

        Examples:
            >>> counter = CounterLock(start_value=0, name="example_counter")
            >>> counter.value
            0
        """
        self._name = name
        """An optional name for the counter, used in debug logs."""

        self._value = start_value
        """The current value of the counter."""

        self._events: DefaultDict[int, Event] = defaultdict(Event)
        """A defaultdict that maps each awaited value to an :class:`Event` that manages the waiters for that value."""

        self.is_ready = lambda v: self._value >= v
        """A lambda function that indicates whether the current counter value is greater than or equal to a given value."""

    async def wait_for(self, value: int) -> bool:
        """
        Waits until the counter reaches or exceeds the specified value.

        If the counter is not already ready, this method ensures the debug daemon is
        running before it blocks.

        Args:
            value: The value to wait for.

        Returns:
            Always ``True``, once the counter value is >= ``value``.

        Examples:
            >>> counter = CounterLock(start_value=0)
            >>> await counter.wait_for(5)  # This will block until counter.value >= 5

        See Also:
            :meth:`CounterLock.set` to set the counter value.
        """
        # Fast path: already satisfied, so no Event is allocated for this value.
        if not self.is_ready(value):
            self._ensure_debug_daemon()
            # The defaultdict lazily creates one shared Event per awaited value;
            # all waiters for the same value block on the same Event.
            await self._events[value].wait()
        return True

    def set(self, value: int) -> None:
        """
        Sets the counter to the specified value.

        This delegates to the `value` property setter. A larger value wakes matching
        waiters, the current value is a no-op, and a smaller value is rejected.

        Args:
            value: The value to set the counter to. Must not be less than the current value.

        Raises:
            ValueError: If the new value is less than the current value.

        Examples:
            >>> counter = CounterLock(start_value=0)
            >>> counter.set(5)
            >>> counter.value
            5

        See Also:
            :meth:`CounterLock.value` for direct value assignment.
        """
        self.value = value

    def __repr__(self) -> str:
        """
        Returns a string representation of the :class:`CounterLock` instance.

        The representation includes the name, current value, and the number of
        waiters for each awaited value, in ascending order of awaited value.

        Examples:
            >>> counter = CounterLock(start_value=0, name="example_counter")
            >>> repr(counter)
            '<CounterLock name=example_counter value=0 waiters={}>'
        """
        waiters = {v: len(event._waiters) for v, event in sorted(self._events.items())}
        return f"<CounterLock name={self._name} value={self._value} waiters={waiters}>"

    @property
    def value(self) -> int:
        """
        Gets the current value of the counter.

        Examples:
            >>> counter = CounterLock(start_value=0)
            >>> counter.value
            0
        """
        return self._value

    @value.setter
    def value(self, value: int) -> None:
        """
        Sets the counter to a new value, waking up any waiters if the value increases
        beyond the value they are awaiting.

        Setting the counter to its current value does nothing.

        Args:
            value: The new value of the counter.

        Raises:
            ValueError: If the new value is less than the current value.

        Examples:
            >>> counter = CounterLock(start_value=0)
            >>> counter.value = 5
            >>> counter.value
            5
            >>> counter.value = 3
            Traceback (most recent call last):
            ...
            ValueError: You cannot decrease the value.
        """
        if value > self._value:
            self._value = value
            # Snapshot the keys with list() because entries are popped while iterating.
            ready = [
                self._events.pop(key)
                for key in list(self._events.keys())
                if key <= self._value
            ]
            for event in ready:
                event.set()
        elif value < self._value:
            raise ValueError("You cannot decrease the value.")

    async def _debug_daemon(self) -> None:
        """
        Periodically logs debug information about the counter state and waiters.

        Logs once immediately and then every 300 seconds for as long as any awaited
        value still has a pending Event. The elapsed time is reported in minutes.
        """
        start = time()
        while self._events:
            # Parenthesize the subtraction so the elapsed seconds, not the start
            # timestamp, are converted to minutes.
            self.logger.debug(
                "%s is still locked after %sm", self, round((time() - start) / 60, 2)
            )
            await asyncio.sleep(300)
class CounterLockCluster:
    """
    An asyncio primitive that groups several :class:`CounterLock` objects.

    `wait_for(i)` resolves only after every member :class:`CounterLock` holds a value >= i.

    See Also:
        :class:`CounterLock` for managing individual counters.
    """

    __slots__ = ("locks",)

    def __init__(self, counter_locks: Iterable[CounterLock]) -> None:
        """
        Initializes the :class:`CounterLockCluster` from a collection of :class:`CounterLock` objects.

        Args:
            counter_locks: The :class:`CounterLock` objects to manage.

        Examples:
            >>> lock1 = CounterLock(start_value=0)
            >>> lock2 = CounterLock(start_value=0)
            >>> cluster = CounterLockCluster([lock1, lock2])
        """
        # Copy into a list so the locks can be traversed on every wait_for call,
        # even if a one-shot iterator was supplied.
        self.locks = list(counter_locks)

    async def wait_for(self, value: int) -> bool:
        """
        Waits until every :class:`CounterLock` in the cluster reaches or exceeds `value`.

        Args:
            value: The value to wait for.

        Returns:
            Always ``True``, once all member locks are at or above ``value``.

        Examples:
            >>> lock1 = CounterLock(start_value=0)
            >>> lock2 = CounterLock(start_value=0)
            >>> cluster = CounterLockCluster([lock1, lock2])
            >>> await cluster.wait_for(5)  # This will block until all locks have value >= 5
        """
        pending = (member.wait_for(value) for member in self.locks)
        await asyncio.gather(*pending)
        return True