import asyncio
import contextlib
import functools
import inspect
import logging
import weakref
from a_sync import exceptions
from a_sync.asyncio.create_task import create_task
from a_sync._typing import *
from a_sync.a_sync import _kwargs
from a_sync.a_sync.base import ASyncGenericBase
from a_sync.a_sync.function import ASyncFunction
from a_sync.a_sync.method import ASyncBoundMethod, ASyncMethodDescriptor, ASyncMethodDescriptorSyncDefault
from a_sync.a_sync.property import _ASyncPropertyDescriptorBase
from a_sync.asyncio.as_completed import as_completed
from a_sync.asyncio.gather import Excluder, gather
from a_sync.iter import ASyncIterator, ASyncGeneratorFunction, ASyncSorter
from a_sync.primitives.queue import Queue, ProcessingQueue
from a_sync.primitives.locks.event import Event
from a_sync.utils.iterators import as_yielded, exhaust_iterator
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type alias for the callables a `TaskMapping` wraps: the key (`K`) comes first,
# followed by the parameters `P`, and the call returns an awaitable of `V`.
# NOTE(review): `Callable`, `Concatenate`, `Awaitable`, `K`, `P` and `V` are presumably
# supplied by the `a_sync._typing` star import — confirm.
MappingFn = Callable[Concatenate[K, P], Awaitable[V]]
[docs]
class TaskMapping(DefaultDict[K, "asyncio.Task[V]"], AsyncIterable[Tuple[K, V]]):
    """
    A mapping from keys to asyncio Tasks that asynchronously generates and manages tasks based on input iterables.

    Looking up a key that is not in the mapping yet starts a task (or queues the key when
    `concurrency` is set) by calling the wrapped function with that key, see `__getitem__`.
    Items cannot be assigned manually. The mapping can be awaited (see `gather`) or
    asynchronously iterated to receive `(key, result)` pairs as the tasks complete.
    """
    concurrency: Optional[int] = None
    "The max number of tasks that will run at one time."
    _destroyed: bool = False
    "Boolean indicating whether this mapping has been consumed and is no longer usable for aggregations."
    _init_loader: Optional["asyncio.Task[None]"] = None
    "An asyncio Task used to preload values from the iterables."
    _init_loader_next: Optional[Callable[[], Awaitable[Tuple[Tuple[K, "asyncio.Task[V]"]]]]] = None
    "A coro function that blocks until the _init_loader starts a new task(s), and then returns a `Tuple[Tuple[K, asyncio.Task[V]]]` with all of the new tasks and the keys that started them."
    _name: Optional[str] = None
    "Optional name for tasks created by this mapping."
    _next: Optional[Event] = None
    "An asyncio Event that indicates the next result is ready"
    # NOTE: this class-level dict is shared by every instance that was created without
    # kwargs. It is only ever read; `__init__` rebinds a per-instance dict instead of mutating it.
    _wrapped_func_kwargs: Dict[str, Any] = {}
    "Additional keyword arguments passed to `_wrapped_func`."
    __iterables__: Tuple[AnyIterableOrAwaitableIterable[K], ...] = ()
    "The original iterables, if any, used to initialize this mapping."
    __init_loader_coro: Optional[Awaitable[None]] = None
    "An optional asyncio Coroutine to be run by the `_init_loader`"
    __slots__ = "_wrapped_func", "__wrapped__", "__dict__", "__weakref__"
    # NOTE: maybe since we use so many classvars here we are better off getting rid of slots

    def __init__(
        self,
        wrapped_func: MappingFn[K, P, V] = None,
        *iterables: AnyIterableOrAwaitableIterable[K],
        name: str = '',
        concurrency: Optional[int] = None,
        **wrapped_func_kwargs: P.kwargs,
    ) -> None:
        """
        Args:
            wrapped_func: A function that takes a key (and optional parameters) and returns an Awaitable.
            *iterables: Any number of iterables whose elements will be used as keys for task generation.
            name: An optional name for the tasks created by this mapping.
            concurrency: The max number of tasks that will run at one time. A falsy value
                (`None` or `0`) leaves the class default (`None`) in place.
            **wrapped_func_kwargs: Keyword arguments that will be passed to `wrapped_func`.
        """
        if concurrency:
            self.concurrency = concurrency
        self.__wrapped__ = wrapped_func
        "The original callable used to initialize this mapping without any modifications."
        if iterables:
            self.__iterables__ = iterables
        wrapped_func = _unwrap(wrapped_func)
        self._wrapped_func = wrapped_func
        "The function used to create tasks for each key."
        if isinstance(wrapped_func, ASyncMethodDescriptor):
            # descriptors get an explicit `sync=False` unless the caller already passed a flag
            if _kwargs.get_flag_name(wrapped_func_kwargs) is None:
                wrapped_func_kwargs["sync"] = False
        if wrapped_func_kwargs:
            self._wrapped_func_kwargs = wrapped_func_kwargs
        if name:
            self._name = name
        if iterables:
            self._next = Event(name=f"{self} `_next`")

            @functools.wraps(wrapped_func)
            async def _wrapped_set_next(*args: P.args, __a_sync_recursion: int = 0, **kwargs: P.kwargs) -> V:
                # Wraps `wrapped_func` so that `self._next` is pulsed every time a call finishes,
                # which is what wakes up `__aiter__` while it waits for the next result.
                try:
                    return await wrapped_func(*args, **kwargs)
                except exceptions.SyncModeInAsyncContextError as e:
                    raise Exception(e, self.__wrapped__) from e
                except TypeError as e:
                    # Callables without a `__name__` (e.g. `functools.partial`) cannot be matched
                    # against the error message, so their TypeError is re-raised untouched.
                    func_name = getattr(wrapped_func, "__name__", None)
                    is_arg_clash = (
                        func_name is not None
                        and str(e).startswith(func_name)
                        and "got multiple values for argument" in str(e)
                    )
                    if __a_sync_recursion > 2 or not is_arg_clash:
                        raise
                    # NOTE: args ordering is clashing with provided kwargs. We can handle this in a hacky way.
                    # TODO: perform this check earlier and pre-prepare the args/kwargs ordering
                    new_args = list(args)
                    new_kwargs = dict(kwargs)
                    try:
                        # move the leading kwargs that shadow positional params into position
                        for i, arg in enumerate(inspect.getfullargspec(self.__wrapped__).args):
                            if arg in kwargs:
                                new_args.insert(i, new_kwargs.pop(arg))
                            else:
                                break
                        return await _wrapped_set_next(*new_args, **new_kwargs, __a_sync_recursion=__a_sync_recursion+1)
                    except TypeError as e2:
                        raise e.with_traceback(e.__traceback__) if str(e2) == "unsupported callable" else e2.with_traceback(e2.__traceback__)
                finally:
                    # pulse the event: wake the current waiters, then reset it for the next result
                    self._next.set()
                    self._next.clear()

            self._wrapped_func = _wrapped_set_next
            init_loader_queue: Queue[Tuple[K, "asyncio.Future[V]"]] = Queue()
            self.__init_loader_coro = exhaust_iterator(self._tasks_for_iterables(*iterables), queue=init_loader_queue)
            with contextlib.suppress(_NoRunningLoop):
                # its okay if we get this exception, we can start the task as soon as the loop starts
                self._init_loader
            self._init_loader_next = init_loader_queue.get_all

    def __repr__(self) -> str:
        return f"<{type(self).__name__} for {self._wrapped_func} kwargs={self._wrapped_func_kwargs} tasks={len(self)} at {hex(id(self))}>"

    def __hash__(self) -> int:
        # identity-based hash, so the mapping itself can be used where a hashable is needed
        return id(self)

    def __setitem__(self, item: Any, value: Any) -> None:
        raise NotImplementedError("You cannot manually set items in a TaskMapping")

    def __getitem__(self, item: K) -> "asyncio.Task[V]":
        """
        Return the task for `item`, starting one if the key is not in the mapping yet.

        With `concurrency` set, the key is put on the processing queue and the object
        returned by `put_nowait` is stored; otherwise a task is created right away.
        """
        try:
            return super().__getitem__(item)
        except KeyError:
            if self.concurrency:
                # NOTE: we use a queue instead of a Semaphore to reduce memory use for use cases involving many many tasks
                fut = self._queue.put_nowait(item)
            else:
                coro = self._wrapped_func(item, **self._wrapped_func_kwargs)
                # the name must be a plain string (a stray trailing comma used to make it a tuple)
                name = f"{self._name}[{item}]" if self._name else f"{item}"
                fut = create_task(coro=coro, name=name)
            super().__setitem__(item, fut)
            return fut

    def __await__(self) -> Generator[Any, None, Dict[K, V]]:
        """Wait for all tasks to complete and return a dictionary of the results."""
        return self.gather(sync=False).__await__()

    async def __aiter__(self, pop: bool = False) -> AsyncIterator[Tuple[K, V]]:
        """
        aiterate thru all key-task pairs, yielding the key-result pair as each task completes

        Args:
            pop: Whether to remove each task from the mapping as its result is yielded.
                When True the mapping is marked as consumed and cleared when iteration ends.
        """
        self._if_pop_check_destroyed(pop)
        # if you inited the TaskMapping with some iterators, we will load those
        yielded = set()
        try:
            if self._init_loader is None:
                # if you didn't init the TaskMapping with iterators and you didn't start any tasks manually, we should fail
                self._raise_if_empty()
            else:
                while not self._init_loader.done():
                    await self._wait_for_next_key()
                    while unyielded := [key for key in self if key not in yielded]:
                        if ready := {key: task for key in unyielded if (task:=self[key]).done()}:
                            if pop:
                                for key, task in ready.items():
                                    yield key, await self.pop(key)
                                    yielded.add(key)
                            else:
                                for key, task in ready.items():
                                    yield key, await task
                                    yielded.add(key)
                        else:
                            # nothing is ready yet, sleep until the wrapped function signals a finished call
                            await self._next.wait()
                # loader is already done by this point, but we need to check for exceptions
                await self._init_loader
            # if there are any tasks that still need to complete, we need to await them and yield them
            if unyielded := {key: self[key] for key in self if key not in yielded}:
                if pop:
                    async for key, value in as_completed(unyielded, aiter=True):
                        self.pop(key)
                        yield key, value
                else:
                    async for key, value in as_completed(unyielded, aiter=True):
                        yield key, value
        finally:
            await self._if_pop_clear(pop)

    def __delitem__(self, item: K) -> None:
        """Remove `item` from the mapping, cancelling its task if it has not finished."""
        # `dict.__getitem__` is used so that a missing key raises instead of starting a new task
        task_or_fut = dict.__getitem__(self, item)
        if not task_or_fut.done():
            task_or_fut.cancel()
        super().__delitem__(item)

    def keys(self, pop: bool = False) -> "TaskMappingKeys[K, V]":
        """Return a `TaskMappingKeys` view of this mapping."""
        return TaskMappingKeys(super().keys(), self, pop=pop)

    def values(self, pop: bool = False) -> "TaskMappingValues[K, V]":
        """Return a `TaskMappingValues` view of this mapping."""
        return TaskMappingValues(super().values(), self, pop=pop)

    def items(self, pop: bool = False) -> "TaskMappingItems[K, V]":
        """Return a `TaskMappingItems` view of this mapping."""
        return TaskMappingItems(super().items(), self, pop=pop)

    async def close(self) -> None:
        """Mark the mapping as consumed, cancel the loader and any remaining tasks, and empty it."""
        await self._if_pop_clear(True)

    @ASyncGeneratorFunction
    async def map(self, *iterables: AnyIterableOrAwaitableIterable[K], pop: bool = True, yields: Literal['keys', 'both'] = 'both') -> AsyncIterator[Tuple[K, V]]:
        """
        Asynchronously map iterables to tasks and yield their results.

        Args:
            *iterables: Iterables to map over.
            pop: Whether to remove tasks from the internal storage once they are completed.
            yields: Whether to yield 'keys' or 'both' (key-value pairs).

        Yields:
            Depending on `yields`, either keys or tuples of key-value pairs
            representing the results of completed tasks.

        Raises:
            ValueError: If `iterables` are passed but the mapping was already initialized with iterables.
        """
        self._if_pop_check_destroyed(pop)
        # make sure the init loader is started if needed
        init_loader = self._init_loader
        if iterables and init_loader:
            raise ValueError("You cannot pass `iterables` to map if the TaskMapping was initialized with an (a)iterable.")
        try:
            if iterables:
                self._raise_if_not_empty()
                try:
                    async for _ in self._tasks_for_iterables(*iterables):
                        async for key, value in self.yield_completed(pop=pop):
                            yield _yield(key, value, yields)
                except _EmptySequenceError:
                    if len(iterables) > 1:
                        # TODO gotta handle this situation
                        raise exceptions.EmptySequenceError("bob needs to code something so you can do this, go tell him") from None
                    # just pass thru
            elif init_loader:
                # check for exceptions if you passed an iterable(s) into the class init
                await init_loader
            else:
                self._raise_if_empty("You must either initialize your TaskMapping with an iterable(s) or provide them during your call to map")
            if self:
                if pop:
                    async for key, value in as_completed(self, aiter=True):
                        self.pop(key)
                        yield _yield(key, value, yields)
                else:
                    async for key, value in as_completed(self, aiter=True):
                        yield _yield(key, value, yields)
        finally:
            await self._if_pop_clear(pop)

    @ASyncMethodDescriptorSyncDefault
    async def all(self, pop: bool = True) -> bool:
        """Return False as soon as a falsy result is seen, else True. An `_EmptySequenceError` counts as True."""
        try:
            async for key, result in self.__aiter__(pop=pop):
                if not bool(result):
                    return False
            return True
        except _EmptySequenceError:
            return True
        finally:
            # an early return leaves `__aiter__` unfinished, so clean up explicitly
            await self._if_pop_clear(pop)

    @ASyncMethodDescriptorSyncDefault
    async def any(self, pop: bool = True) -> bool:
        """Return True as soon as a truthy result is seen, else False. An `_EmptySequenceError` counts as False."""
        try:
            async for key, result in self.__aiter__(pop=pop):
                if bool(result):
                    return True
            return False
        except _EmptySequenceError:
            return False
        finally:
            # an early return leaves `__aiter__` unfinished, so clean up explicitly
            await self._if_pop_clear(pop)

    @ASyncMethodDescriptorSyncDefault
    async def max(self, pop: bool = True) -> V:
        """
        Return the largest task result.

        Raises:
            exceptions.EmptySequenceError: If no result other than `None` was produced.
        """
        largest = None
        try:
            async for key, result in self.__aiter__(pop=pop):
                if largest is None or result > largest:
                    largest = result
        except _EmptySequenceError:
            raise exceptions.EmptySequenceError("max() arg is an empty sequence") from None
        if largest is None:
            raise exceptions.EmptySequenceError("max() arg is an empty sequence") from None
        return largest

    @ASyncMethodDescriptorSyncDefault
    async def min(self, pop: bool = True) -> V:
        """
        Return the smallest task result.

        Raises:
            exceptions.EmptySequenceError: If no result other than `None` was produced.
        """
        smallest = None
        try:
            async for key, result in self.__aiter__(pop=pop):
                if smallest is None or result < smallest:
                    smallest = result
        except _EmptySequenceError:
            raise exceptions.EmptySequenceError("min() arg is an empty sequence") from None
        if smallest is None:
            raise exceptions.EmptySequenceError("min() arg is an empty sequence") from None
        return smallest

    @ASyncMethodDescriptorSyncDefault
    async def sum(self, pop: bool = False) -> V:
        """Return the sum of all task results, starting from 0. An `_EmptySequenceError` yields 0."""
        retval = 0
        try:
            async for key, result in self.__aiter__(pop=pop):
                retval += result
        except _EmptySequenceError:
            return 0
        return retval

    @ASyncIterator.wrap
    async def yield_completed(self, pop: bool = True) -> AsyncIterator[Tuple[K, V]]:
        """
        Asynchronously yield tuples of key-value pairs representing the results of any completed tasks.

        Args:
            pop: Whether to remove tasks from the internal storage once they are completed.

        Yields:
            Tuples of key-value pairs representing the results of completed tasks.
        """
        # iterate over a snapshot so that popping does not mutate the dict being iterated
        if pop:
            for k, task in dict(self).items():
                if task.done():
                    yield k, await self.pop(k)
        else:
            for k, task in dict(self).items():
                if task.done():
                    yield k, await task

    @ASyncMethodDescriptorSyncDefault
    async def gather(
        self,
        return_exceptions: bool = False,
        exclude_if: Excluder[V] = None,
        tqdm: bool = False,
        **tqdm_kwargs: Any,
    ) -> Dict[K, V]:
        """
        Wait for all tasks to complete and return a dictionary of the results.

        The init loader, if any, is awaited first, and the mapping must not be empty
        afterwards. All arguments are forwarded to the module-level `gather` helper.
        """
        if self._init_loader:
            await self._init_loader
        self._raise_if_empty()
        return await gather(self, return_exceptions=return_exceptions, exclude_if=exclude_if, tqdm=tqdm, **tqdm_kwargs)

    @overload
    def pop(self, item: K, cancel: bool = False) -> "Union[asyncio.Task[V], asyncio.Future[V]]":...
    @overload
    def pop(self, item: K, default: K, cancel: bool = False) -> "Union[asyncio.Task[V], asyncio.Future[V]]":...
    def pop(self, *args: K, cancel: bool = False) -> "Union[asyncio.Task[V], asyncio.Future[V]]":
        """
        Remove and return the task stored for a key.

        Args:
            *args: The key, optionally followed by a default, exactly as for `dict.pop`.
            cancel: Whether to call `cancel()` on the popped object before returning it.
        """
        fut_or_task = super().pop(*args)
        if cancel:
            fut_or_task.cancel()
        return fut_or_task

    def clear(self, cancel: bool = False) -> None:
        """
        Pop every task out of the mapping.

        Args:
            cancel: Whether to also cancel the still-running init loader and each popped task.
        """
        if cancel and self._init_loader and not self._init_loader.done():
            # `stack_info=True` records which caller triggered the cancellation without
            # emitting an ERROR-level traceback for what is a normal code path
            logger.debug("cancelling %s", self._init_loader, stack_info=True)
            self._init_loader.cancel()
        if keys := tuple(self.keys()):
            logger.debug("popping remaining %s tasks", self)
            for k in keys:
                self.pop(k, cancel=cancel)

    @functools.cached_property
    def _init_loader(self) -> Optional["asyncio.Task[None]"]:
        """
        The task that runs the init loader coroutine, or None if there is no such coroutine.

        Raises:
            _NoRunningLoop: If the task cannot be created because no event loop is running.
                The value is not cached in that case, so the next access tries again.
        """
        if self.__init_loader_coro:
            logger.debug("starting %s init loader", self)
            name=f"{type(self).__name__} init loader loading {self.__iterables__} for {self}"
            try:
                task = create_task(coro=self.__init_loader_coro, name=name)
            except RuntimeError as e:
                raise _NoRunningLoop if str(e) == "no running event loop" else e
            task.add_done_callback(self.__cleanup)
            return task

    @functools.cached_property
    def _queue(self) -> ProcessingQueue:
        """The `ProcessingQueue` that runs the wrapped function, created with `self.concurrency`."""
        fn = functools.partial(self._wrapped_func, **self._wrapped_func_kwargs)
        return ProcessingQueue(fn, self.concurrency, name=self._name)

    def _raise_if_empty(self, msg: str = '') -> None:
        """Raise `exceptions.MappingIsEmptyError` if the mapping holds no tasks."""
        if not self:
            raise exceptions.MappingIsEmptyError(self, msg)

    def _raise_if_not_empty(self) -> None:
        """Raise `exceptions.MappingNotEmptyError` if the mapping already holds tasks."""
        if self:
            raise exceptions.MappingNotEmptyError(self)

    @ASyncGeneratorFunction
    async def _tasks_for_iterables(self, *iterables: AnyIterableOrAwaitableIterable[K]) -> AsyncIterator[Tuple[K, "asyncio.Task[V]"]]:
        """Ensure tasks are running for each key in the provided iterables."""
        # Split the inputs in a single pass. Classifying each object directly avoids comparing
        # the iterables to one another with `==`, which is slow for big containers and can
        # raise for objects with a non-boolean `__eq__`.
        containers = []
        remaining = []
        for iterable in iterables:
            if isinstance(iterable, Iterable) and not isinstance(iterable, AsyncIterable):
                containers.append(iterable)
            else:
                remaining.append(iterable)
        # if we have any regular containers we can yield their contents right away
        for iterable in containers:
            async for key in _yield_keys(iterable):
                yield key, self[key]
        if remaining:
            try:
                async for key in as_yielded(*[_yield_keys(iterable) for iterable in remaining]): # type: ignore [attr-defined]
                    yield key, self[key]  # ensure task is running
            except _EmptySequenceError:
                if len(iterables) == 1:
                    raise
                raise RuntimeError("DEV: figure out how to handle this situation") from None

    def _if_pop_check_destroyed(self, pop: bool) -> None:
        """When `pop` is True, raise if the mapping was already consumed, else mark it as consumed."""
        if pop:
            if self._destroyed:
                raise RuntimeError(f"{self} has already been consumed")
            self._destroyed = True

    async def _if_pop_clear(self, pop: bool) -> None:
        """When `pop` is True, mark the mapping as consumed, close the queue and cancel everything left."""
        if pop:
            self._destroyed = True
            # _queue is a cached_property, we don't want to create it if it doesn't exist
            if self.concurrency and '_queue' in self.__dict__:
                self._queue.close()
                del self._queue
            self.clear(cancel=True)
            # we need to let the loop run once so the tasks can fully cancel
            await asyncio.sleep(0)

    async def _wait_for_next_key(self) -> None:
        """Block until the init loader starts at least one new task, or until the loader itself finishes."""
        # NOTE if `_init_loader` has an exception it will return first, otherwise `_init_loader_next` will return always
        done, _ = await asyncio.wait(
            [create_task(self._init_loader_next(), log_destroy_pending=False), self._init_loader],
            return_when=asyncio.FIRST_COMPLETED
        )
        for task in done:
            # check for exceptions
            await task

    def __cleanup(self, t: "asyncio.Task[None]") -> None:
        """Done-callback for the init loader task."""
        # clear the slot and let the bound Queue die
        del self.__init_loader_coro
class _NoRunningLoop(Exception):
...
@overload
def _yield(key: K, value: V, yields: Literal['keys']) -> K:...
@overload
def _yield(key: K, value: V, yields: Literal['both']) -> Tuple[K, V]:...
def _yield(key: K, value: V, yields: Literal['keys', 'both']) -> Union[K, Tuple[K, V]]:
"""
Yield either the key, value, or both based on the 'yields' parameter.
Args:
key: The key of the task.
value: The result of the task.
yields: Determines what to yield; 'keys' for keys, 'both' for key-value pairs.
Returns:
The key, the value, or a tuple of both based on the 'yields' parameter.
"""
if yields == 'both':
return key, value
elif yields == 'keys':
return key
else:
raise ValueError(f"`yields` must be 'keys' or 'both'. You passed {yields}")
class _EmptySequenceError(ValueError):
...
async def _yield_keys(iterable: AnyIterableOrAwaitableIterable[K]) -> AsyncIterator[K]:
"""
Asynchronously yield keys from the provided iterable.
Args:
iterable: An iterable that can be either synchronous or asynchronous.
Yields:
Keys extracted from the iterable.
"""
if not iterable:
raise _EmptySequenceError(iterable)
elif isinstance(iterable, AsyncIterable):
async for key in iterable:
yield key
elif isinstance(iterable, Iterable):
for key in iterable:
yield key
elif inspect.isawaitable(iterable):
async for key in _yield_keys(await iterable):
yield key
else:
raise TypeError(iterable)
# Cache of already-unwrapped callables. Keys are held weakly, so an entry disappears
# together with the callable it belongs to.
__unwrapped = weakref.WeakKeyDictionary()


def _unwrap(wrapped_func: Union[AnyFn[P, T], "ASyncMethodDescriptor[P, T]", _ASyncPropertyDescriptorBase[I, T]]) -> Callable[P, Awaitable[T]]:
    """
    Return the callable that should actually be invoked for `wrapped_func`.

    - `ASyncBoundMethod` and `ASyncMethodDescriptor` objects are returned unchanged.
    - property descriptors are replaced by their `get` attribute.
    - `ASyncFunction` objects are replaced by `_modified_fn` when `_async_def` is truthy,
      else by `_asyncified`.
    - anything else is returned unchanged.

    Results are memoized in a `WeakKeyDictionary`. Objects that cannot be used as a key
    there (not weak-referenceable or not hashable, e.g. `None`) are simply not cached.
    """
    try:
        if unwrapped := __unwrapped.get(wrapped_func):
            return unwrapped
        cacheable = True
    except TypeError:
        # `WeakKeyDictionary.get` builds a weak reference to the key, which raises for
        # objects that do not support it; skip the cache instead of failing
        cacheable = False
    if isinstance(wrapped_func, (ASyncBoundMethod, ASyncMethodDescriptor)):
        unwrapped = wrapped_func
    elif isinstance(wrapped_func, _ASyncPropertyDescriptorBase):
        unwrapped = wrapped_func.get
    elif isinstance(wrapped_func, ASyncFunction):
        # this speeds things up a bit by bypassing some logic
        # TODO implement it like this elsewhere if profilers suggest
        unwrapped = wrapped_func._modified_fn if wrapped_func._async_def else wrapped_func._asyncified
    else:
        unwrapped = wrapped_func
    if cacheable:
        __unwrapped[wrapped_func] = unwrapped
    return unwrapped
_get_key: Callable[[Tuple[K, V]], K] = lambda k_and_v: k_and_v[0]
_get_value: Callable[[Tuple[K, V]], V] = lambda k_and_v: k_and_v[1]
class _TaskMappingView(ASyncGenericBase, Iterable[T], Generic[T, K, V]):
    """
    Shared base for the keys / values / items views of a `TaskMapping`.

    The view wraps the regular dict view (for synchronous iteration and `len`) and keeps a
    weak proxy to the mapping it belongs to. Subclasses provide `_get_from_item`, which
    picks the part of a `(key, value)` pair that the view exposes.
    """
    __slots__ = "__view__", "__mapping__"
    _get_from_item: Callable[[Tuple[K, V]], T]
    _pop: bool = False

    def __init__(self, view: Iterable[T], task_mapping: TaskMapping[K, V], pop: bool = False) -> None:
        self.__view__ = view
        self.__mapping__: TaskMapping = weakref.proxy(task_mapping)
        "actually a weakref.ProxyType[TaskMapping] but then type hints weren't working"
        # the class-level default is only overridden when popping was requested
        if pop:
            self._pop = True

    def __len__(self) -> int:
        return len(self.__view__)

    def __iter__(self) -> Iterator[T]:
        return iter(self.__view__)

    def __await__(self) -> Generator[Any, None, List[T]]:
        return self._await().__await__()

    async def _await(self) -> List[T]:
        """Collect everything the view yields asynchronously into a list."""
        collected = []
        async for result in self:
            collected.append(result)
        return collected

    async def aiterbykeys(self, reverse: bool = False) -> ASyncIterator[T]:
        """Yield the view's elements sorted by key, in descending order if `reverse` is True."""
        pairs = self.__mapping__.items(pop=self._pop)
        async for pair in ASyncSorter(pairs, key=_get_key, reverse=reverse):
            yield self._get_from_item(pair)

    async def aiterbyvalues(self, reverse: bool = False) -> ASyncIterator[T]:
        """Yield the view's elements sorted by value, in descending order if `reverse` is True."""
        pairs = self.__mapping__.items(pop=self._pop)
        async for pair in ASyncSorter(pairs, key=_get_value, reverse=reverse):
            yield self._get_from_item(pair)
[docs]
class TaskMappingKeys(_TaskMappingView[K, K, V], Generic[K, V]):
    """Keys view of a `TaskMapping`; asynchronous iteration yields keys as their tasks get started."""

    def _get_from_item(self, item):
        return _get_key(item)

    async def __aiter__(self) -> AsyncIterator[K]:
        """
        Yield the keys already in the mapping, then the keys added by the init loader (if any).

        When the view was created with `pop=True`, each key is popped from the mapping as it
        is yielded and the mapping is cleared once iteration is over.
        """
        # strongref
        mapping = self.__mapping__
        mapping._if_pop_check_destroyed(self._pop)
        seen = set()
        for key in self.__load_existing():
            seen.add(key)
            # there is no chance of duplicate keys here
            yield key
        if mapping._init_loader is None:
            # nothing else will ever be loaded
            await mapping._if_pop_clear(self._pop)
            return
        async for key in self.__load_init_loader(seen):
            seen.add(key)
            yield key
        if not self._pop:
            for key in self.__load_existing():
                if key not in seen:
                    yield key
            return
        # don't need to check `seen` since we've been popping them as we go
        for key in self.__load_existing():
            yield key
        await mapping._if_pop_clear(True)

    def __load_existing(self) -> Iterator[K]:
        """Yield a snapshot of the keys currently in the mapping, popping each one first in pop mode."""
        # strongref
        mapping = self.__mapping__
        for key in tuple(mapping):
            if self._pop:
                mapping.pop(key)
            yield key

    async def __load_init_loader(self, yielded: Set[K]) -> AsyncIterator[K]:
        """Yield keys not in `yielded` each time the init loader reports progress, until it is done."""
        # strongref
        mapping = self.__mapping__
        while not mapping._init_loader.done():
            await mapping._wait_for_next_key()
            fresh = [k for k in mapping if k not in yielded]
            for key in fresh:
                if self._pop:
                    mapping.pop(key)
                yield key
        # check for any exceptions
        await mapping._init_loader
[docs]
class TaskMappingItems(_TaskMappingView[Tuple[K, V], K, V], Generic[K, V]):
    """Items view of a `TaskMapping`; asynchronous iteration yields `(key, result)` pairs."""

    def _get_from_item(self, item):
        return item

    async def __aiter__(self) -> AsyncIterator[Tuple[K, V]]:
        """Await the task of every key from `mapping.keys()`, popping it from the mapping first in pop mode."""
        # strongref
        mapping = self.__mapping__
        mapping._if_pop_check_destroyed(self._pop)
        async for key in mapping.keys():
            task = mapping.pop(key) if self._pop else mapping[key]
            yield key, await task
[docs]
class TaskMappingValues(_TaskMappingView[V, K, V], Generic[K, V]):
    """Values view of a `TaskMapping`; asynchronous iteration yields task results."""

    def _get_from_item(self, item):
        return _get_value(item)

    async def __aiter__(self) -> AsyncIterator[V]:
        """Await the task of every key from `mapping.keys()`, popping it from the mapping first in pop mode."""
        # strongref
        mapping = self.__mapping__
        mapping._if_pop_check_destroyed(self._pop)
        async for key in mapping.keys():
            task = mapping.pop(key) if self._pop else mapping[key]
            yield await task
# Public API of this module. `create_task` is imported at the top of the file and
# re-exported here alongside the mapping and its three view classes.
__all__ = ["create_task", "TaskMapping", "TaskMappingKeys", "TaskMappingValues", "TaskMappingItems"]