# Source code for a_sync.primitives.queue

"""
This module provides various queue implementations for managing asynchronous tasks, including standard FIFO queues,
priority queues, and processing queues. These queues support advanced features like waiting for multiple items,
handling priority tasks, and processing tasks with multiple workers.
"""

import asyncio
import functools
import heapq
import logging
import sys
import weakref

from a_sync import _smart
from a_sync.asyncio.create_task import create_task
from a_sync._typing import *

logger = logging.getLogger(__name__)

# ``asyncio.Queue`` only supports subscription (``asyncio.Queue[T]``) from Python 3.9
# onwards; older interpreters get their generic parameters from ``typing.Generic``.
if sys.version_info >= (3, 9):
    class _Queue(asyncio.Queue[T]):
        __slots__ = "_maxsize", "_getters", "_putters", "_unfinished_tasks", "_finished"
else:
    class _Queue(asyncio.Queue, Generic[T]):
        __slots__ = "_maxsize", "_loop", "_getters", "_putters", "_unfinished_tasks", "_finished"
class Queue(_Queue[T]):
    """An ``asyncio.Queue`` with typed accessors plus helpers that take several items at once.

    Single-item ``get``/``put`` behave exactly like ``asyncio.Queue``; the class exists
    for type-hint support and adds ``get_all``/``get_multi`` style helpers.
    """

    async def get(self) -> T:
        """Remove and return an item, waiting until one is available."""
        return await _Queue.get(self)

    def get_nowait(self) -> T:
        """Remove and return an item immediately, or raise ``asyncio.QueueEmpty``."""
        return _Queue.get_nowait(self)

    async def put(self, item: T) -> None:
        """Put ``item`` into the queue, waiting for a free slot if the queue is full."""
        # the parent coroutine must be awaited, otherwise the item is never enqueued
        return await _Queue.put(self, item)

    def put_nowait(self, item: T) -> None:
        """Put ``item`` into the queue immediately, or raise ``asyncio.QueueFull``."""
        return _Queue.put_nowait(self, item)

    async def get_all(self) -> List[T]:
        """Return every item currently queued, waiting for one if the queue is empty.

        Always returns at least one item.
        """
        try:
            return self.get_all_nowait()
        except asyncio.QueueEmpty:
            return [await self.get()]

    def get_all_nowait(self) -> List[T]:
        """Return every item currently queued (at least one), or raise ``asyncio.QueueEmpty``."""
        values: List[T] = []
        while True:
            try:
                values.append(self.get_nowait())
            except asyncio.QueueEmpty as e:
                if not values:
                    raise asyncio.QueueEmpty from e
                return values

    async def get_multi(self, i: int, can_return_less: bool = False) -> List[T]:
        """Wait for and return ``i`` items.

        If ``can_return_less`` is True, return as soon as at least one item is
        available (up to ``i`` items) instead of waiting for all ``i``.

        NOTE: if the call is cancelled while waiting, items already removed from
        the queue are not put back.

        Raises:
            TypeError, ValueError: see ``_validate_args``.
        """
        _validate_args(i, can_return_less)
        items: List[T] = []
        while True:
            # take whatever is available right now without blocking
            while len(items) < i:
                try:
                    items.append(self.get_nowait())
                except asyncio.QueueEmpty:
                    break
            if len(items) == i or (can_return_less and items):
                return items
            # not enough items yet: block until the next one arrives
            items.append(await self.get())

    def get_multi_nowait(self, i: int, can_return_less: bool = False) -> List[T]:
        """
        Just like `asyncio.Queue.get_nowait`, but will return `i` items instead of 1.
        Set `can_return_less` to True if you want to receive up to `i` items.

        If fewer than ``i`` items are available (or none, even with
        ``can_return_less``), the items taken so far are re-enqueued with
        ``put_nowait`` and ``asyncio.QueueEmpty`` is raised.
        """
        _validate_args(i, can_return_less)
        items = []
        for _ in range(i):
            try:
                items.append(self.get_nowait())
            except asyncio.QueueEmpty:
                if items and can_return_less:
                    return items
                # put these back in the queue since we didn't return them
                for value in items:
                    self.put_nowait(value)
                raise asyncio.QueueEmpty from None
        return items
class ProcessingQueue(_Queue[Tuple[P, "asyncio.Future[V]"]], Generic[P, V]):
    """Queue that runs ``func`` on enqueued arguments using ``num_workers`` worker tasks.

    ``put``/``put_nowait``/``__call__`` enqueue a call to ``func``. When the queue is
    created with ``return_data=True`` they return a future that receives ``func``'s
    result or exception. With ``return_data=False`` they return ``None`` and worker
    failures are only logged.
    """

    _closed: bool = False
    __slots__ = "func", "num_workers", "_worker_coro"

    def __init__(
        self,
        func: Callable[P, Awaitable[V]],
        num_workers: int,
        *,
        return_data: bool = True,
        name: str = "",
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """
        Args:
            func: coroutine function called as ``func(*args, **kwargs)`` for each queued item.
            num_workers: number of worker tasks, started on first ``put``/``put_nowait``.
            return_data: if False, no futures are created and results are discarded.
            name: optional name used in ``__repr__``/``__str__`` and worker task names.
            loop: only accepted on Python < 3.10; raises NotImplementedError otherwise.
        """
        # NOTE: no maxsize is passed, so the underlying asyncio.Queue is unbounded
        if sys.version_info < (3, 10):
            super().__init__(loop=loop)
        elif loop:
            raise NotImplementedError(f"You cannot pass a value for `loop` in python {sys.version_info}")
        else:
            super().__init__()
        self.func = func
        self.num_workers = num_workers
        self._name = name
        self._no_futs = not return_data

        @functools.wraps(func)
        async def _worker_coro() -> NoReturn:
            # we use this little helper so we can have context of `func` in any err logs
            return await self.__worker_coro()

        self._worker_coro = _worker_coro

    # NOTE: asyncio defines both this and __str__
    def __repr__(self) -> str:
        return self._format_repr(f"<{type(self).__name__} at {hex(id(self))}")

    # NOTE: asyncio defines both this and __repr__
    def __str__(self) -> str:
        return self._format_repr(f"<{type(self).__name__}")

    def _format_repr(self, repr_string: str) -> str:
        """Append the name/func/worker/pending details shared by ``__repr__`` and ``__str__``."""
        if self._name:
            repr_string += f" name={self._name}"
        repr_string += f" func={self.func} num_workers={self.num_workers}"
        if self._unfinished_tasks:
            repr_string += f" pending={self._unfinished_tasks}"
        return f"{repr_string}>"

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":
        """Shorthand for ``put_nowait(*args, **kwargs)``."""
        return self.put_nowait(*args, **kwargs)

    def __del__(self) -> None:
        """Report unfinished work to the event loop's exception handler unless closed."""
        if self._closed:
            return
        # getattr: __init__ may have raised before asyncio.Queue.__init__ set this attribute
        if getattr(self, "_unfinished_tasks", 0) > 0:
            context = {
                'message': f'{self} was destroyed but has work pending!',
            }
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                # no event loop available to report to; fall back to logging
                logger.error("%s", context['message'])
                return
            loop.call_exception_handler(context)

    @property
    def name(self) -> str:
        """The explicit name if one was given, otherwise ``repr(self)``."""
        return self._name or repr(self)

    def close(self) -> None:
        """Mark the queue closed: further puts raise RuntimeError and ``__del__`` stays quiet."""
        self._closed = True

    async def put(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":
        """Enqueue ``func(*args, **kwargs)``; returns its future (or None if ``return_data=False``).

        The queue is always created unbounded (see ``__init__``), so there is never a
        free slot to wait for. ``asyncio.Queue.put`` is deliberately not used: it
        re-dispatches to ``self.put_nowait(item)``, which for this class expects call
        arguments rather than a ready-made queue entry.
        """
        return self.put_nowait(*args, **kwargs)

    def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":
        """Enqueue ``func(*args, **kwargs)``; returns its future (or None if ``return_data=False``).

        Only a ``weakref.proxy`` to the future is queued, so if the caller drops the
        returned future before a worker reaches the item, the item is skipped.
        """
        self._ensure_workers()
        if self._no_futs:
            return super().put_nowait((args, kwargs))
        fut = self._create_future()
        super().put_nowait((args, kwargs, weakref.proxy(fut)))
        return fut

    def _create_future(self) -> "asyncio.Future[V]":
        """Create the future handed back to callers of ``put``/``put_nowait``."""
        return asyncio.get_event_loop().create_future()

    def _ensure_workers(self) -> None:
        """Start the workers if needed; raise if the queue is closed or a worker has died."""
        if self._closed:
            raise RuntimeError(f"{type(self).__name__} is closed: ", self) from None
        if self._workers.done():
            worker_subtasks: List["asyncio.Task[NoReturn]"] = self._workers._workers
            for worker in worker_subtasks:
                if worker.done():  # its only done if its broken
                    self._reraise_clean(worker.exception())
            # this should never be reached, but just in case
            self._reraise_clean(self._workers.exception())

    @staticmethod
    def _reraise_clean(exc: BaseException) -> NoReturn:
        """Raise a fresh copy of ``exc`` (same type and args), or ``exc`` itself if it cannot be rebuilt."""
        try:
            # re-raise with clean traceback
            raise type(exc)(*exc.args).with_traceback(exc.__traceback__)  # type: ignore [union-attr]
        except TypeError:
            raise exc.with_traceback(exc.__traceback__)

    @functools.cached_property
    def _workers(self) -> "asyncio.Task[NoReturn]":
        """Start ``num_workers`` worker tasks and a task gathering them (created once, lazily)."""
        logger.debug("starting worker task for %s", self)
        workers = [
            create_task(
                coro=self._worker_coro(),
                name=f"{self.name} [Task-{i}]",
                log_destroy_pending=False,
            )
            for i in range(self.num_workers)
        ]
        task = create_task(asyncio.gather(*workers), name=f"{self.name} worker main Task", log_destroy_pending=False)
        # keep the individual workers reachable so _ensure_workers can surface their errors
        task._workers = workers
        return task

    async def __worker_coro(self) -> NoReturn:
        """Worker loop: take items off the queue forever and run ``func`` on them."""
        args: P.args
        kwargs: P.kwargs
        if self._no_futs:
            while True:
                try:
                    args, kwargs = await self.get()
                    await self.func(*args, **kwargs)
                except Exception as e:
                    logger.error("%s in worker for %s!", type(e).__name__, self)
                    logger.exception(e)
                self.task_done()
        else:
            while True:
                try:
                    args, kwargs, fut = await self.get()
                    await self.__process_item(args, kwargs, fut)
                    self.task_done()
                except Exception as e:
                    logger.error("%s for %s is broken!!!", type(self).__name__, self.func)
                    logger.exception(e)
                    raise

    async def __process_item(self, args, kwargs, fut) -> None:
        """Call ``func`` for one queued item and deliver its result or exception to ``fut``."""
        if fut is None:
            # the weakref was already cleaned up, we don't need to process this item
            return
        try:
            # put_nowait queues a weakref.proxy; any access to a dead proxy raises ReferenceError
            fut.done()
        except ReferenceError:
            # the caller dropped the future, so nobody can observe this result
            return
        try:
            result = await self.func(*args, **kwargs)
        except Exception as e:
            try:
                fut.set_exception(e)
            except asyncio.exceptions.InvalidStateError:
                logger.error("cannot set exception for %s %s: %s", self.func.__name__, fut, e)
            except ReferenceError:
                logger.debug("discarding exception for %s: its future was garbage collected", self)
        else:
            try:
                fut.set_result(result)
            except asyncio.exceptions.InvalidStateError:
                logger.error("cannot set result for %s %s: %s", self.func.__name__, fut, result)
            except ReferenceError:
                logger.debug("discarding result for %s: its future was garbage collected", self)
[docs] def _validate_args(i: int, can_return_less: bool) -> None: """ Validates the arguments for methods that retrieve multiple items from the queue. Args: i (int): The number of items to retrieve. can_return_less (bool): Whether the method is allowed to return fewer than `i` items. Raises: TypeError: If `i` is not an integer or `can_return_less` is not a boolean. ValueError: If `i` is not greater than 1. """ if not isinstance(i, int): raise TypeError(f"`i` must be an integer greater than 1. You passed {i}") if not isinstance(can_return_less, bool): raise TypeError(f"`can_return_less` must be boolean. You passed {can_return_less}") if i <= 1: raise ValueError(f"`i` must be an integer greater than 1. You passed {i}")
class _SmartFutureRef(weakref.ref, Generic[T]):
    """A ``weakref.ref`` to a SmartFuture that defines ``<`` so it can sit in a heap."""

    def __lt__(self, other: "_SmartFutureRef[T]") -> bool:
        """
        Order two weak references for the priority heap.

        A reference whose target has been garbage collected sorts first, so dead
        entries are popped early. When both targets are alive, the comparison is
        delegated to the SmartFutures themselves, which is intended to rank the
        future with more waiters first.

        Args:
            other (_SmartFutureRef[T]): The reference to compare against.

        Returns:
            bool: True if this reference should be processed before ``other``.
        """
        target = self()
        if target is None:
            return True
        other_target = other()
        return other_target is not None and target < other_target
class _PriorityQueueMixin(Generic[T]):
    """Mixin replacing ``asyncio.Queue``'s storage hooks with a ``heapq`` min-heap."""

    def _init(self, maxsize):
        # ``maxsize`` is accepted to match asyncio.Queue's ``_init`` hook and is not used here
        self._queue: List[T] = []

    def _put(self, item, heappush=heapq.heappush):
        """Push ``item`` onto the heap."""
        heap = self._queue
        heappush(heap, item)

    def _get(self, heappop=heapq.heappop):
        """Pop and return the smallest item on the heap."""
        heap = self._queue
        return heappop(heap)
class PriorityProcessingQueue(_PriorityQueueMixin[T], ProcessingQueue[T, V]):
    """A ProcessingQueue whose items are processed in ascending ``priority`` order.

    Items with equal priority are processed in insertion order.
    """

    # NOTE: WIP

    # monotonically increasing tie-breaker, so heap entries with equal priorities never
    # fall back to comparing ``args``/``kwargs`` (dicts are not orderable)
    _seq: int = 0

    async def put(self, priority: Any, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":
        """Enqueue ``func(*args, **kwargs)`` with ``priority`` (lower runs first).

        The queue is created unbounded by ProcessingQueue, so there is nothing to wait
        for. ``asyncio.Queue.put`` cannot be used: it re-dispatches to
        ``self.put_nowait(item)``, which here expects ``priority, *args`` rather than a
        ready-made heap entry.
        """
        return self.put_nowait(priority, *args, **kwargs)

    def put_nowait(self, priority: Any, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":
        """Enqueue ``func(*args, **kwargs)`` with ``priority`` (lower runs first).

        Returns the result future, or None when the queue was built with ``return_data=False``.
        """
        self._ensure_workers()
        fut = None if self._no_futs else self._create_future()
        self._seq += 1
        # bypass ProcessingQueue.put_nowait, which would wrap this entry as call arguments
        _Queue.put_nowait(self, (priority, self._seq, args, kwargs, fut))
        return fut

    def _get(self, heappop=heapq.heappop):
        """Pop the lowest-priority entry and reshape it for the ProcessingQueue worker."""
        priority, seq, args, kwargs, fut = heappop(self._queue)
        if self._no_futs:
            # with return_data=False the worker unpacks exactly (args, kwargs)
            return args, kwargs
        return args, kwargs, fut
class _VariablePriorityQueueMixin(_PriorityQueueMixin[T]):
    """Heap mixin that re-heapifies before every pop, so queued items may change priority."""

    def _get(self, heapify=heapq.heapify, heappop=heapq.heappop):
        "Restore the heap invariant (priorities may have changed since insertion) and pop the smallest item"
        heap = self._queue
        heapify(heap)
        return heappop(heap)

    def _get_key(self, *args, **kwargs) -> _smart._Key:
        """Build a lookup key for a call: the positional args plus the kwargs sorted by name."""
        # kwarg names are unique strings, so sorting the items only ever compares names
        return args, tuple(sorted(kwargs.items()))
class VariablePriorityQueue(_VariablePriorityQueueMixin[T], asyncio.PriorityQueue):
    """An ``asyncio.PriorityQueue`` that re-sorts its heap on every get.

    Because ordering is re-evaluated at pop time, item priorities may be updated
    (or computed) on the fly while the items wait in the queue.
    """
# NOTE: WIP
class SmartProcessingQueue(_VariablePriorityQueueMixin[T], ProcessingQueue[Concatenate[T, P], V]):
    """A ProcessingQueue that deduplicates identical calls and runs the most-demanded jobs first.

    While its SmartFuture is alive, each distinct ``(args, kwargs)`` combination maps to
    that single future. Queued jobs are ordered by comparing their SmartFutures (see
    ``_SmartFutureRef.__lt__``), which is intended to favour jobs with the most waiters.
    """

    _no_futs = False
    _futs: "weakref.WeakValueDictionary[_smart._Key, _smart.SmartFuture[T]]"

    def __init__(
        self,
        func: Callable[Concatenate[T, P], Awaitable[V]],
        num_workers: int,
        *,
        name: str = "",
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Same as ProcessingQueue, except results are always returned (``return_data=True``)."""
        super().__init__(func, num_workers, return_data=True, name=name, loop=loop)
        # weak values: an entry disappears once nobody holds its SmartFuture anymore
        self._futs: Dict[_smart._Key[T], _smart.SmartFuture[T]] = weakref.WeakValueDictionary()

    async def put(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]:
        """Enqueue ``func(*args, **kwargs)`` (or join an identical live call) and return its future.

        The queue is created unbounded by ProcessingQueue, so there is nothing to wait
        for. ``asyncio.Queue.put`` cannot be used: it re-dispatches to
        ``self.put_nowait(item)``, which would treat the queue entry as call arguments.
        """
        return self.put_nowait(*args, **kwargs)

    def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]:
        """Enqueue ``func(*args, **kwargs)`` (or join an identical live call) and return its future."""
        self._ensure_workers()
        key = self._get_key(*args, **kwargs)
        fut = self._futs.get(key)
        if fut is not None:
            # an identical call already has a live future: share it instead of re-queueing
            return fut
        fut = self._create_future(key)
        self._futs[key] = fut
        # only a weak reference goes into the heap, so abandoned jobs can be skipped
        _Queue.put_nowait(self, (_SmartFutureRef(fut), args, kwargs))
        return fut

    def _create_future(self, key: _smart._Key) -> "asyncio.Future[V]":
        """Create the SmartFuture for ``key``, bound to this queue."""
        return _smart.create_future(queue=self, key=key, loop=self._loop)

    def _get(self):
        """Pop the highest-priority job; its future is None if it was garbage collected."""
        fut, args, kwargs = super()._get()
        return args, kwargs, fut()

    # NOTE(review): this method is never run. ProcessingQueue starts its workers through
    # ``self.__worker_coro()`` written inside ProcessingQueue, which name-mangles to
    # ``_ProcessingQueue__worker_coro``; this definition mangles to
    # ``_SmartProcessingQueue__worker_coro`` and therefore does not override it.
    async def __worker_coro(self) -> NoReturn:
        args: P.args
        kwargs: P.kwargs
        fut: _smart.SmartFuture[V]
        while True:
            try:
                args, kwargs, fut = await self.get()
                if fut is None:
                    # the weakref was already cleaned up, we don't need to process this item
                    self.task_done()
                    continue
                logger.debug("processing %s", fut)
                try:
                    result = await self.func(*args, **kwargs)
                except Exception as e:
                    logger.debug("%s: %s", type(e).__name__, e)
                    try:
                        fut.set_exception(e)
                    except asyncio.exceptions.InvalidStateError:
                        logger.error("cannot set exception for %s %s: %s", self.func.__name__, fut, e)
                else:
                    try:
                        fut.set_result(result)
                    except asyncio.exceptions.InvalidStateError:
                        logger.error("cannot set result for %s %s: %s", self.func.__name__, fut, result)
                self.task_done()
            except Exception as e:
                logger.error("%s for %s is broken!!!", type(self).__name__, self.func)
                logger.exception(e)
                raise