Source code for a_sync.executor

"""
With these executors, you can simply run sync functions in your executor with `await executor.run(fn, *args)`.

`executor.submit(fn, *args)` works the same as the concurrent.futures implementation, but returns an `asyncio.Future` instead of a `concurrent.futures.Future` (see the example below).

This module provides several executor classes:
- _AsyncExecutorMixin: A mixin providing asynchronous run and submit methods.
- AsyncProcessPoolExecutor: An async process pool executor.
- AsyncThreadPoolExecutor: An async thread pool executor.
- PruningThreadPoolExecutor: A thread pool executor that prunes inactive threads after a timeout.
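
Example:
    A minimal usage sketch (the function and values here are illustrative only)::

        import asyncio

        from a_sync.executor import AsyncThreadPoolExecutor

        executor = AsyncThreadPoolExecutor(max_workers=4)

        def parse_hex(value: str, *, base: int = 16) -> int:
            # Stand-in for real blocking work.
            return int(value, base)

        async def main() -> None:
            # Runs in the pool without blocking the event loop; kwargs work too.
            result = await executor.run(parse_hex, "ff", base=16)
            assert result == 255

        asyncio.run(main())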
"""

import asyncio
import concurrent.futures as cf
import multiprocessing.context
import queue
import threading
import weakref
from concurrent.futures import _base, thread
from functools import cached_property

from a_sync._typing import *
from a_sync.primitives._debug import _DebugDaemonMixin


TEN_MINUTES = 60 * 10

Initializer = Callable[..., object]

class _AsyncExecutorMixin(cf.Executor, _DebugDaemonMixin):
    """
    A mixin for Executors to provide asynchronous run and submit methods.
    """
    _max_workers: int
    _workers: str  # Name of the attribute holding this executor's worker pool, e.g. "threads" or "processes".
    __slots__ = "_max_workers", "_initializer", "_initargs", "_broken", "_shutdown_lock"

    async def run(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
        """
        A shorthand for `await asyncio.get_event_loop().run_in_executor(this_executor, fn, *args)`.
        Doesn't `await this_executor.run(fn, *args)` look so much better?
        
        Oh, and you can also use kwargs!

        Args:
            fn (Callable[P, T]): The function to run.
            *args: Positional arguments for the function.
            **kwargs: Keyword arguments for the function.

        Returns:
            T: The result of the function. In sync mode (`max_workers == 0`), `fn` is executed directly in the calling thread rather than submitted to the executor.
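
        Example:
            A sketch of the shorthand (``executor`` and the arguments are illustrative)::

                value = await executor.run(int, "ff", base=16)  # -> 255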
        """
        return fn(*args, **kwargs) if self.sync_mode else await self.submit(fn, *args, **kwargs)

    def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[T]":  # type: ignore [override]
        """
        Submits a job to the executor and returns an asyncio.Future that can be awaited for the result without blocking.

        Args:
            fn (Callable[P, T]): The function to submit.
            *args: Positional arguments for the function.
            **kwargs: Keyword arguments for the function.

        Returns:
            asyncio.Future[T]: The future representing the result of the function.
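
        Example:
            A sketch (assuming an already-constructed ``executor``)::

                fut = executor.submit(len, "abc")
                # `fut` is an asyncio.Future, so it can be awaited directly
                # or passed to helpers like asyncio.gather.
                result = await fut  # -> 3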
        """
        if self.sync_mode:
            fut = asyncio.get_event_loop().create_future()
            try:
                fut.set_result(fn(*args, **kwargs))
            except Exception as e:
                fut.set_exception(e)
        else:
            fut = asyncio.futures.wrap_future(super().submit(fn, *args, **kwargs))  # type: ignore [assignment]
            self._start_debug_daemon(fut, fn, *args, **kwargs)
        return fut

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} object at {hex(id(self))} [{self.worker_count_current}/{self._max_workers} {self._workers}]>"

    def __len__(self) -> int:
        # NOTE: should this be queue length instead? probably
        return self.worker_count_current

    @cached_property
    def sync_mode(self) -> bool:
        """
        Indicates if the executor is in synchronous mode (max_workers == 0).

        Returns:
            bool: True if in synchronous mode, False otherwise.
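
        Example:
            A sketch of sync mode (illustrative)::

                executor = AsyncThreadPoolExecutor(max_workers=0)
                # In sync mode, `run` and `submit` execute `fn` directly
                # in the caller's thread.
                assert executor.sync_mode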
        """
        return self._max_workers == 0

    @property
    def worker_count_current(self) -> int:
        """
        Returns the current number of workers.

        Returns:
            int: The current number of workers.
        """
        return len(getattr(self, f"_{self._workers}"))

    async def _debug_daemon(self, fut: asyncio.Future, fn, *args, **kwargs) -> None:
        """
        Runs until the future completes, logging a debug message every 15 seconds while the work item is still pending.

        Args:
            fut (asyncio.Future): The future being debugged.
            fn: The function being executed.
            *args: Positional arguments for the function.
            **kwargs: Keyword arguments for the function.
        """
        while not fut.done():
            await asyncio.sleep(15)
            if not fut.done():
                self.logger.debug(f'{self} processing {fn}{args}{kwargs}')
    
# Process

class AsyncProcessPoolExecutor(_AsyncExecutorMixin, cf.ProcessPoolExecutor):
    """
    An async process pool executor that allows use of kwargs.
    """
    _workers = "processes"
    __slots__ = (
        "_mp_context",
        "_processes",
        "_pending_work_items",
        "_call_queue",
        "_result_queue",
        "_queue_management_thread",
        "_queue_count",
        "_shutdown_thread",
        "_work_ids",
        "_queue_management_thread_wakeup",
    )

    def __init__(
        self,
        max_workers: Optional[int] = None,
        mp_context: Optional[multiprocessing.context.BaseContext] = None,
        initializer: Optional[Initializer] = None,
        initargs: Tuple[Any, ...] = (),
    ) -> None:
        """
        Initializes the AsyncProcessPoolExecutor.

        Args:
            max_workers (Optional[int], optional): The maximum number of workers. Defaults to None.
            mp_context (Optional[multiprocessing.context.BaseContext], optional): The multiprocessing context. Defaults to None.
            initializer (Optional[Initializer], optional): An initializer callable. Defaults to None.
            initargs (Tuple[Any, ...], optional): Arguments for the initializer. Defaults to ().
        """
        if max_workers == 0:
            super().__init__(1, mp_context, initializer, initargs)
            self._max_workers = 0
        else:
            super().__init__(max_workers, mp_context, initializer, initargs)
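
# A minimal usage sketch for AsyncProcessPoolExecutor (illustrative only; note
# that `fn` and its arguments must be picklable to cross the process boundary):
async def _example_process_pool_usage() -> int:
    process_executor = AsyncProcessPoolExecutor(max_workers=2)
    # Keyword arguments are forwarded to `fn`, unlike loop.run_in_executor.
    return await process_executor.run(pow, 2, exp=10)  # -> 1024
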
# Thread

class AsyncThreadPoolExecutor(_AsyncExecutorMixin, cf.ThreadPoolExecutor):
    """
    An async thread pool executor that allows use of kwargs.
    """
    _workers = "threads"
    __slots__ = "_work_queue", "_idle_semaphore", "_threads", "_shutdown", "_thread_name_prefix"

    def __init__(
        self,
        max_workers: Optional[int] = None,
        thread_name_prefix: str = '',
        initializer: Optional[Initializer] = None,
        initargs: Tuple[Any, ...] = (),
    ) -> None:
        """
        Initializes the AsyncThreadPoolExecutor.

        Args:
            max_workers (Optional[int], optional): The maximum number of workers. Defaults to None.
            thread_name_prefix (str, optional): Prefix for thread names. Defaults to ''.
            initializer (Optional[Initializer], optional): An initializer callable. Defaults to None.
            initargs (Tuple[Any, ...], optional): Arguments for the initializer. Defaults to ().
        """
        if max_workers == 0:
            super().__init__(1, thread_name_prefix, initializer, initargs)
            self._max_workers = 0
        else:
            super().__init__(max_workers, thread_name_prefix, initializer, initargs)
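
# A minimal usage sketch for AsyncThreadPoolExecutor (illustrative only): fan
# several blocking calls out across the pool and gather them as awaitables.
async def _example_thread_pool_usage() -> None:
    import time

    thread_executor = AsyncThreadPoolExecutor(max_workers=8)
    # `submit` returns asyncio.Futures, so asyncio.gather works directly.
    futs = [thread_executor.submit(time.sleep, 0.1) for _ in range(4)]
    await asyncio.gather(*futs)
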
# For backward-compatibility
ProcessPoolExecutor = AsyncProcessPoolExecutor
ThreadPoolExecutor = AsyncThreadPoolExecutor


# Pruning thread pool

def _worker(executor_reference, work_queue, initializer, initargs, timeout):  # NOTE: NEW 'timeout'
    """
    Worker function for the PruningThreadPoolExecutor.

    Args:
        executor_reference: A weak reference to the executor.
        work_queue: The work queue.
        initializer: The initializer function.
        initargs: Arguments for the initializer.
        timeout: Timeout duration for pruning inactive threads.
    """
    if initializer is not None:
        try:
            initializer(*initargs)
        except BaseException:
            _base.LOGGER.critical('Exception in initializer:', exc_info=True)
            executor = executor_reference()
            if executor is not None:
                executor._initializer_failed()
            return
    try:
        while True:
            try:  # NOTE: NEW
                work_item = work_queue.get(block=True, timeout=timeout)  # NOTE: NEW
            except queue.Empty:  # NOTE: NEW
                # It's been `timeout` seconds and there are no new work items,  # NOTE: NEW
                # so this thread kills itself.  # NOTE: NEW
                executor = executor_reference()  # NOTE: NEW
                with executor._adjusting_lock:  # NOTE: NEW
                    # NOTE: We keep a minimum of one thread active to prevent locks
                    if len(executor) > 1:  # NOTE: NEW
                        t = threading.current_thread()  # NOTE: NEW
                        executor._threads.remove(t)  # NOTE: NEW
                        thread._threads_queues.pop(t)  # NOTE: NEW
                        # Let the executor know we have one less idle thread available
                        executor._idle_semaphore.acquire(blocking=False)  # NOTE: NEW
                        return  # NOTE: NEW
                continue
            if work_item is not None:
                work_item.run()
                # Delete references to object. See issue16284
                del work_item

                # attempt to increment idle count
                executor = executor_reference()
                if executor is not None:
                    executor._idle_semaphore.release()
                del executor
                continue
            executor = executor_reference()
            # Exit if:
            #   - The interpreter is shutting down OR
            #   - The executor that owns the worker has been collected OR
            #   - The executor that owns the worker has been shutdown.
            if thread._shutdown or executor is None or executor._shutdown:
                # Flag the executor as shutting down as early as possible if it
                # is not gc-ed yet.
                if executor is not None:
                    executor._shutdown = True
                # Notify other workers
                work_queue.put(None)
                return
            del executor
    except BaseException:
        _base.LOGGER.critical('Exception in worker', exc_info=True)

class PruningThreadPoolExecutor(AsyncThreadPoolExecutor):
    """
    This `AsyncThreadPoolExecutor` implementation prunes inactive threads after
    `timeout` seconds without a work item. Pruned threads will be automatically
    recreated as needed for future workloads. Up to `max_workers` threads can be
    active at any one time.
    """
    __slots__ = "_timeout", "_adjusting_lock"

    def __init__(
        self,
        max_workers=None,
        thread_name_prefix='',
        initializer=None,
        initargs=(),
        timeout=TEN_MINUTES,
    ):
        """
        Initializes the PruningThreadPoolExecutor.

        Args:
            max_workers (Optional[int], optional): The maximum number of workers. Defaults to None.
            thread_name_prefix (str, optional): Prefix for thread names. Defaults to ''.
            initializer (Optional[Initializer], optional): An initializer callable. Defaults to None.
            initargs (Tuple[Any, ...], optional): Arguments for the initializer. Defaults to ().
            timeout (int, optional): Timeout duration for pruning inactive threads. Defaults to TEN_MINUTES.
        """
        self._timeout = timeout
        self._adjusting_lock = threading.Lock()
        super().__init__(max_workers, thread_name_prefix, initializer, initargs)

    def __len__(self) -> int:
        return len(self._threads)

    def _adjust_thread_count(self):
        """
        Adjusts the number of threads based on workload and idle threads.
        """
        with self._adjusting_lock:
            # if idle threads are available, don't spin up new threads
            if self._idle_semaphore.acquire(timeout=0):
                return

            # When the executor gets lost, the weakref callback will wake up
            # the worker threads.
            def weakref_cb(_, q=self._work_queue):
                q.put(None)

            num_threads = len(self._threads)
            if num_threads < self._max_workers:
                thread_name = '%s_%d' % (self._thread_name_prefix or self, num_threads)
                t = threading.Thread(
                    name=thread_name,
                    target=_worker,
                    args=(
                        weakref.ref(self, weakref_cb),
                        self._work_queue,
                        self._initializer,
                        self._initargs,
                        self._timeout,
                    ),
                )
                t.daemon = True
                t.start()
                self._threads.add(t)
                thread._threads_queues[t] = self._work_queue
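
# A minimal usage sketch for PruningThreadPoolExecutor (illustrative values):
# threads are created on demand and pruned after `timeout` idle seconds, while
# at least one thread always stays alive.
async def _example_pruning_usage() -> None:
    pruning_executor = PruningThreadPoolExecutor(max_workers=32, timeout=60)
    await pruning_executor.run(print, "ran in a pooled thread")
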

executor = PruningThreadPoolExecutor(128)

__all__ = [
    "AsyncThreadPoolExecutor",
    "AsyncProcessPoolExecutor",
    "PruningThreadPoolExecutor",
]