Source code for a_sync.a_sync.modifiers.semaphores

# mypy: disable-error-code=valid-type
# mypy: disable-error-code=misc
import asyncio
import functools

from a_sync import exceptions, primitives
from a_sync._typing import *

# We keep this here for now so we don't break downstream deps. Eventually will be removed.
from a_sync.primitives import ThreadsafeSemaphore, DummySemaphore


@overload
def apply_semaphore(  # type: ignore [misc]
    semaphore: SemaphoreSpec,
) -> AsyncDecorator[P, T]:
    """Create a decorator to apply a semaphore to a coroutine function.

    This overload is used when the semaphore is provided as a single argument,
    returning a decorator that can be applied to a coroutine function.

    Args:
        semaphore (Union[int, asyncio.Semaphore, primitives.Semaphore]):
            The semaphore to apply, which can be an integer, an `asyncio.Semaphore`, or a `primitives.Semaphore`.

    Examples:
        Using as a decorator with an integer semaphore:
        >>> @apply_semaphore(2)
        ... async def limited_concurrent_function():
        ...     pass

        Using as a decorator with an `asyncio.Semaphore`:
        >>> sem = asyncio.Semaphore(2)
        >>> @apply_semaphore(sem)
        ... async def another_function():
        ...     pass

        Using as a decorator with a `primitives.Semaphore`:
        >>> sem = primitives.ThreadsafeSemaphore(2)
        >>> @apply_semaphore(sem)
        ... async def yet_another_function():
        ...     pass

    See Also:
        - :class:`asyncio.Semaphore`
        - :class:`primitives.Semaphore`

    Note:
        `primitives.Semaphore` is a subclass of `asyncio.Semaphore`. Therefore, when the documentation refers to `asyncio.Semaphore`, it also includes `primitives.Semaphore` and any other subclasses.
    """


@overload
def apply_semaphore(
    coro_fn: CoroFn[P, T],
    semaphore: SemaphoreSpec,
) -> CoroFn[P, T]:
    """Apply a semaphore directly to a coroutine function.

    This overload is used when both the coroutine function and semaphore are provided,
    directly applying the semaphore to the coroutine function.

    Args:
        coro_fn (Callable): The coroutine function to which the semaphore will be applied.
        semaphore (Union[int, asyncio.Semaphore, primitives.Semaphore]):
            The semaphore to apply, which can be an integer, an `asyncio.Semaphore`, or a `primitives.Semaphore`.

    Examples:
        Applying directly to a function with an integer semaphore:
        >>> async def my_coroutine():
        ...     pass
        >>> my_coroutine = apply_semaphore(my_coroutine, 3)

        Applying directly with an `asyncio.Semaphore`:
        >>> sem = asyncio.Semaphore(3)
        >>> my_coroutine = apply_semaphore(my_coroutine, sem)

        Applying directly with a `primitives.Semaphore`:
        >>> sem = primitives.ThreadsafeSemaphore(3)
        >>> my_coroutine = apply_semaphore(my_coroutine, sem)

    See Also:
        - :class:`asyncio.Semaphore`
        - :class:`primitives.Semaphore`

    Note:
        `primitives.Semaphore` is a subclass of `asyncio.Semaphore`. Therefore, when the documentation refers to `asyncio.Semaphore`, it also includes `primitives.Semaphore` and any other subclasses.
    """


[docs] def apply_semaphore( coro_fn: Optional[Union[CoroFn[P, T], SemaphoreSpec]] = None, semaphore: SemaphoreSpec = None, ) -> AsyncDecoratorOrCoroFn[P, T]: """Apply a semaphore to a coroutine function or return a decorator. This function can be used to apply a semaphore to a coroutine function either by passing the coroutine function and semaphore as arguments or by using the semaphore as a decorator. It raises exceptions if the inputs are not valid. Args: coro_fn (Optional[Callable]): The coroutine function to which the semaphore will be applied, or None if the semaphore is to be used as a decorator. semaphore (Union[int, asyncio.Semaphore, primitives.Semaphore]): The semaphore to apply, which can be an integer, an `asyncio.Semaphore`, or a `primitives.Semaphore`. Raises: ValueError: If `coro_fn` is an integer or `asyncio.Semaphore` and `semaphore` is not None. exceptions.FunctionNotAsync: If the provided function is not a coroutine. TypeError: If the semaphore is not an integer, an `asyncio.Semaphore`, or a `primitives.Semaphore`. Examples: Using as a decorator: >>> @apply_semaphore(2) ... async def limited_concurrent_function(): ... pass Applying directly to a function: >>> async def my_coroutine(): ... pass >>> my_coroutine = apply_semaphore(my_coroutine, 3) Handling invalid inputs: >>> try: ... apply_semaphore(3, 2) ... except ValueError as e: ... print(e) See Also: - :class:`asyncio.Semaphore` - :class:`primitives.Semaphore` Note: `primitives.Semaphore` is a subclass of `asyncio.Semaphore`. Therefore, when the documentation refers to `asyncio.Semaphore`, it also includes `primitives.Semaphore` and any other subclasses. """ # Parse Inputs if isinstance(coro_fn, (int, asyncio.Semaphore)): if semaphore is not None: raise ValueError("You can only pass in one arg.") semaphore = coro_fn coro_fn = None elif not asyncio.iscoroutinefunction(coro_fn): raise exceptions.FunctionNotAsync(coro_fn) # Create the semaphore if necessary if isinstance(semaphore, int): semaphore = primitives.ThreadsafeSemaphore(semaphore) elif not isinstance(semaphore, asyncio.Semaphore): raise TypeError( f"'semaphore' must either be an integer or a Semaphore object. You passed {semaphore}" ) # Create and return the decorator if isinstance(semaphore, primitives.Semaphore): # NOTE: Our `Semaphore` primitive can be used as a decorator. # While you can use it the `async with` way like any other semaphore and we could make this code section cleaner, # applying it as a decorator adds some useful info to its debug logs so we do that here if we can. semaphore_decorator = semaphore else: def semaphore_decorator(coro_fn: CoroFn[P, T]) -> CoroFn[P, T]: @functools.wraps(coro_fn) async def semaphore_wrap(*args, **kwargs) -> T: async with semaphore: # type: ignore [union-attr] return await coro_fn(*args, **kwargs) return semaphore_wrap return semaphore_decorator if coro_fn is None else semaphore_decorator(coro_fn)
dummy_semaphore = primitives.DummySemaphore() """A dummy semaphore that does not enforce any concurrency limits."""