Source code for a_sync.primitives.locks.prio_semaphore

"""
This module provides priority-based semaphore implementations. These semaphores allow 
waiters to be assigned priorities, ensuring that higher priority waiters are 
processed before lower priority ones.
"""

import asyncio
import heapq
import logging
from collections import deque
from functools import cached_property

from a_sync._typing import *
from a_sync.primitives.locks.semaphore import Semaphore

logger = logging.getLogger(__name__)


class Priority(Protocol):
    """Structural type for values usable as semaphore priorities.

    Any object that implements ``__lt__`` satisfies this protocol.
    """

    def __lt__(self, other) -> bool:
        """Return whether ``self`` orders before ``other``."""
        ...
# Type variable for priority values: any type that satisfies the ``Priority``
# protocol, i.e. supports the ``<`` comparison.
PT = TypeVar("PT", bound=Priority)
# Type variable for the per-priority context-manager class; bound by forward
# reference to ``_AbstractPrioritySemaphoreContextManager``.
CM = TypeVar("CM", bound="_AbstractPrioritySemaphoreContextManager[Priority]")
class _AbstractPrioritySemaphore(Semaphore, Generic[PT, CM]):
    """
    Base class for semaphores that serve their waiters in priority order.

    Each distinct priority gets its own context manager, created on demand by
    :meth:`__getitem__`. The managers are kept in a heap (``_waiters``) ordered
    by the managers' own ``<`` comparison, and :meth:`_wake_up_next` always
    wakes a waiter belonging to the manager at the top of that heap.

    Subclasses must provide ``_top_priority`` (the priority used when the
    caller does not pick one) and ``_context_manager_class`` (the class that
    is instantiated for each priority).

    See Also:
        :class:`PrioritySemaphore` for an implementation using numeric priorities.
    """

    name: Optional[str]
    _value: int
    _waiters: List["_AbstractPrioritySemaphoreContextManager[PT]"]  # type: ignore [assignment]
    _context_managers: Dict[PT, "_AbstractPrioritySemaphoreContextManager[PT]"]
    __slots__ = (
        "name",
        "_value",
        "_waiters",
        "_context_managers",
        "_capacity",
        "_potential_lost_waiters",
    )

    @property
    def _context_manager_class(
        self,
    ) -> Type["_AbstractPrioritySemaphoreContextManager[PT]"]:
        """The class instantiated for each priority; subclasses must override.

        Raises:
            NotImplementedError: If not overridden in a subclass.
        """
        raise NotImplementedError

    @property
    def _top_priority(self) -> PT:
        """The priority used when the caller does not specify one.

        Raises:
            NotImplementedError: If not overridden in a subclass.
        """
        raise NotImplementedError

    def __init__(self, value: int = 1, *, name: Optional[str] = None) -> None:
        """Set up an empty priority semaphore.

        Args:
            value: The initial capacity of the semaphore.
            name: An optional name for the semaphore, used for debugging.

        Examples:
            >>> semaphore = _AbstractPrioritySemaphore(5, name="test_semaphore")
        """
        self._capacity = value
        """The capacity the semaphore was created with."""

        self._context_managers = {}
        """Maps each priority to the context manager that serves it."""

        super().__init__(value, name=name)

        # Assigned after the base initializer so that this list is what
        # ``_waiters`` ends up holding.
        self._waiters = []
        """Heap of context managers, ordered by priority."""

        # NOTE: This should (hopefully) be temporary
        self._potential_lost_waiters: List["asyncio.Future[None]"] = []
        """Every future currently awaited by some waiter, in arrival order."""

    def __repr__(self) -> str:
        """Describe the semaphore: name, capacity, value and waiter counts."""
        return f"<{self.__class__.__name__} name={self.name} capacity={self._capacity} value={self._value} waiters={self._count_waiters()}>"

    async def __aenter__(self) -> None:
        """Acquire the semaphore at the top priority on ``async with`` entry.

        Examples:
            >>> semaphore = _AbstractPrioritySemaphore(5)
            >>> async with semaphore:
            ...     await do_stuff()
        """
        top_manager = self[self._top_priority]
        await top_manager.acquire()

    async def __aexit__(self, *_) -> None:
        """Release the semaphore via the top-priority manager on ``async with`` exit.

        Examples:
            >>> semaphore = _AbstractPrioritySemaphore(5)
            >>> async with semaphore:
            ...     await do_stuff()
        """
        top_manager = self[self._top_priority]
        top_manager.release()

    async def acquire(self) -> Literal[True]:
        """Acquire the semaphore at the top priority.

        Delegates to the ``acquire`` of the top-priority context manager.

        Returns:
            Whatever that manager's ``acquire`` returns (annotated ``True``).

        Examples:
            >>> semaphore = _AbstractPrioritySemaphore(5)
            >>> await semaphore.acquire()
        """
        top_manager = self[self._top_priority]
        return await top_manager.acquire()

    def __getitem__(
        self, priority: Optional[PT]
    ) -> "_AbstractPrioritySemaphoreContextManager[PT]":
        """Return the context manager for ``priority``, creating it if needed.

        A newly created manager is pushed onto the heap and recorded in
        ``_context_managers``.

        Args:
            priority: The priority to look up. ``None`` means the top priority.

        Returns:
            The context manager associated with the given priority.

        Examples:
            >>> semaphore = _AbstractPrioritySemaphore(5)
            >>> context_manager = semaphore[priority]
        """
        if priority is None:
            priority = self._top_priority
        if priority not in self._context_managers:
            new_manager = self._context_manager_class(self, priority, name=self.name)
            heapq.heappush(self._waiters, new_manager)  # type: ignore [misc]
            self._context_managers[priority] = new_manager
        return self._context_managers[priority]

    def locked(self) -> bool:
        """Report whether the semaphore cannot be acquired immediately.

        Returns:
            True if the value is zero, or if any registered context manager
            still has a waiter that has not been cancelled; False otherwise.

        Examples:
            >>> semaphore = _AbstractPrioritySemaphore(5)
            >>> semaphore.locked()
        """
        if self._value == 0:
            return True
        for cm in self._context_managers.values():
            # ``cm._waiters`` may be unset/empty, hence the truthiness check first.
            if cm._waiters and any(not w.cancelled() for w in cm._waiters):
                return True
        return False

    def _count_waiters(self) -> Dict[PT, int]:
        """Count the waiters queued under each priority.

        Only managers currently in the heap are reported.

        Returns:
            A dictionary mapping each priority to its number of waiters,
            with keys inserted in ascending priority order.

        Examples:
            >>> semaphore = _AbstractPrioritySemaphore(5)
            >>> semaphore._count_waiters()
        """
        counts = {}
        for manager in sorted(self._waiters, key=lambda m: m._priority):
            counts[manager._priority] = len(manager.waiters)
        return counts

    def _wake_up_next(self) -> None:
        """Wake one pending waiter, taken from the manager at the top of the heap.

        Managers are popped from the heap one at a time. A manager with no
        waiters, or whose queued futures are all already done, is dropped
        from ``_context_managers`` and the next one is tried. As soon as one
        future is resolved, the manager is pushed back if it still has
        waiters and the method returns.

        If the heap runs dry without waking anyone, the emergency procedure
        walks ``_potential_lost_waiters`` and resolves the first future that
        is not done yet. This is a temporary safety net for lost waiters.

        Examples:
            >>> semaphore = _AbstractPrioritySemaphore(5)
            >>> semaphore._wake_up_next()
        """
        while self._waiters:
            cm = heapq.heappop(self._waiters)
            if len(cm) == 0:
                # Nothing is queued on this manager: discard it and move on.
                logger.debug(
                    "manager %s has no more waiters, popping from %s",
                    cm._repr_no_parent_(),
                    self,
                )
                self._context_managers.pop(cm._priority)
                continue

            logger.debug("waking up next for %s", cm._repr_no_parent_())
            start_len = len(cm)

            if not cm._waiters:
                logger.debug("not manager._waiters")

            while cm._waiters:
                fut = cm._waiters.popleft()
                self._potential_lost_waiters.remove(fut)
                if fut.done():
                    # Already resolved or cancelled; try the next one.
                    continue
                fut.set_result(None)
                logger.debug("woke up %s", fut)
                break
            else:
                # The queue drained without waking anyone: drop this manager
                # and fall through to the next one in the heap.
                self._context_managers.pop(cm._priority)
                continue

            end_len = len(cm)
            assert start_len > end_len, f"start {start_len} end {end_len}"

            if end_len:
                # Waiters remain, so the manager goes back on the heap.
                heapq.heappush(self._waiters, cm)  # type: ignore [misc]
            else:
                # That was the last waiter; the manager is no longer needed.
                self._context_managers.pop(cm._priority)
            return

        # emergency procedure (hopefully temporary):
        while self._potential_lost_waiters:
            lost = self._potential_lost_waiters.pop(0)
            logger.debug("we found a lost waiter %s", lost)
            if lost.done():
                continue
            lost.set_result(None)
            logger.debug("woke up lost waiter %s", lost)
            return

        logger.debug("%s has no waiters to wake", self)
class _AbstractPrioritySemaphoreContextManager(Semaphore, Generic[PT]):
    """
    Per-priority handle onto a parent priority semaphore.

    An instance is bound to one parent and one priority. It keeps its own
    queue of waiting futures, while the counter it checks and decrements is
    the parent's ``_value``; releasing is delegated to the parent.
    """

    _loop: asyncio.AbstractEventLoop
    _waiters: Deque[asyncio.Future]  # type: ignore [assignment]
    __slots__ = "_parent", "_priority"

    @property
    def _priority_name(self) -> str:
        """Label used for the priority in ``repr`` output; subclasses must override.

        Raises:
            NotImplementedError: If not overridden in a subclass.
        """
        raise NotImplementedError

    def __init__(
        self,
        parent: _AbstractPrioritySemaphore,
        priority: PT,
        name: Optional[str] = None,
    ) -> None:
        """Bind this context manager to a parent semaphore and a priority.

        The base initializer is called with a value of 0.

        Args:
            parent: The parent semaphore.
            priority: The priority associated with this context manager.
            name: An optional name for the context manager, used for debugging.

        Examples:
            >>> parent_semaphore = _AbstractPrioritySemaphore(5)
            >>> context_manager = _AbstractPrioritySemaphoreContextManager(parent_semaphore, priority=1)
        """
        self._priority = priority
        """The priority this context manager serves."""
        self._parent = parent
        """The semaphore that owns this context manager."""
        super().__init__(0, name=name)

    def __repr__(self) -> str:
        """Describe the manager, including the full ``repr`` of its parent."""
        return f"<{self.__class__.__name__} parent={self._parent} {self._priority_name}={self._priority} waiters={len(self)}>"

    def _repr_no_parent_(self) -> str:
        """Describe the manager, showing only the parent's name, not its ``repr``."""
        return f"<{self.__class__.__name__} parent_name={self._parent.name} {self._priority_name}={self._priority} waiters={len(self)}>"

    def __lt__(self, other) -> bool:
        """Order context managers by their priority values.

        Args:
            other: The context manager to compare with.

        Returns:
            True if this manager's priority value compares less than the
            other's, which places it ahead of the other in the parent's heap.

        Raises:
            TypeError: If ``other`` is not exactly the same type as ``self``.

        Examples:
            >>> cm1 = _AbstractPrioritySemaphoreContextManager(parent, priority=1)
            >>> cm2 = _AbstractPrioritySemaphoreContextManager(parent, priority=2)
            >>> cm1 < cm2
        """
        if type(other) is type(self):
            return self._priority < other._priority
        raise TypeError(f"{other} is not type {self.__class__.__name__}")

    @cached_property
    def loop(self) -> asyncio.AbstractEventLoop:
        """The event loop used to create waiter futures.

        Uses ``self._loop`` when it is truthy, otherwise
        ``asyncio.get_event_loop()``. The result is cached.
        """
        bound_loop = self._loop
        if bound_loop:
            return bound_loop
        return asyncio.get_event_loop()

    @property
    def waiters(self) -> Deque[asyncio.Future]:
        """The deque of waiter futures, created on first access if it is ``None``."""
        queue = self._waiters
        if queue is None:
            queue = self._waiters = deque()
        return queue

    async def acquire(self) -> Literal[True]:
        """Acquire the parent semaphore on behalf of this priority.

        If the parent's value is above zero on entry, it is decremented and
        True is returned immediately. Otherwise a future is queued on this
        manager, and also recorded in the parent's ``_potential_lost_waiters``,
        then awaited. This repeats until the parent's value is positive.

        Returns:
            Always True.

        Raises:
            BaseException: Whatever interrupts the wait (for example
                cancellation) is re-raised after cleanup.

        Examples:
            >>> context_manager = _AbstractPrioritySemaphoreContextManager(parent, priority=1)
            >>> await context_manager.acquire()
        """
        parent = self._parent
        if parent._value <= 0:
            # Helper inherited from the base class; presumably starts debug
            # logging for blocked waiters (verify against ``Semaphore``).
            self._ensure_debug_daemon()

        while parent._value <= 0:
            fut = self.loop.create_future()
            self.waiters.append(fut)
            parent._potential_lost_waiters.append(fut)
            try:
                await fut
            except BaseException:
                # See the similar code in Queue.get.
                fut.cancel()
                # If the future had already been resolved (so the cancel did
                # not take effect) and a slot is free, pass the wake-up on.
                if parent._value > 0 and not fut.cancelled():
                    parent._wake_up_next()
                raise

        parent._value -= 1
        return True

    def release(self) -> None:
        """Release the semaphore by delegating to the parent's ``release``.

        Examples:
            >>> context_manager = _AbstractPrioritySemaphoreContextManager(parent, priority=1)
            >>> context_manager.release()
        """
        self._parent.release()
class _PrioritySemaphoreContextManager(
    _AbstractPrioritySemaphoreContextManager[Numeric]
):
    """Per-priority context manager whose priorities are ``Numeric`` values."""

    # Label shown before the priority value in ``repr`` output.
    _priority_name = "priority"
class PrioritySemaphore(_AbstractPrioritySemaphore[Numeric, _PrioritySemaphoreContextManager]):  # type: ignore [type-var]
    """Priority semaphore whose priorities are numbers.

    This is the concrete counterpart of :class:`_AbstractPrioritySemaphore`.
    It hands out :class:`_PrioritySemaphoreContextManager` instances and uses
    -1 as its top priority, i.e. the priority applied when the caller does
    not choose one.

    Examples:
        The primary way to use this semaphore is by specifying a priority.

        >>> priority_semaphore = PrioritySemaphore(10)
        >>> async with priority_semaphore[priority]:
        ...     await do_stuff()

        You can also enter and exit this semaphore without specifying a priority,
        and it will use the top priority by default:

        >>> priority_semaphore = PrioritySemaphore(10)
        >>> async with priority_semaphore:
        ...     await do_stuff()

    See Also:
        :class:`_AbstractPrioritySemaphore` for the base class implementation.
    """

    # Class instantiated for each distinct priority.
    _context_manager_class = _PrioritySemaphoreContextManager
    # Priority used when the semaphore is entered or acquired without one.
    _top_priority = -1