Source code for y.classes.common


import abc
import asyncio
from contextlib import suppress
from decimal import Decimal
from functools import cached_property
from logging import getLogger
from typing import (TYPE_CHECKING, Any, Awaitable, Generator, Literal, NoReturn,
                    Optional, Tuple, Union, final)

import a_sync
from a_sync.a_sync import HiddenMethodDescriptor, HiddenMethod
from a_sync.a_sync.method import ASyncBoundMethod
from brownie import Contract, chain, web3
from brownie.convert.datatypes import HexString
from brownie.exceptions import ContractNotFound
from eth_retry import auto_retry
from typing_extensions import Self

from y import ENVIRONMENT_VARIABLES as ENVS
from y import convert
from y._decorators import stuck_coro_debugger
from y.classes.singleton import ChecksumASyncSingletonMeta
from y.constants import EEE_ADDRESS
from y.contracts import (Contract, build_name, contract_creation_block_async,
                         has_method, probe)
from y.datatypes import Address, AnyAddressType, Block, Pool, UsdPrice
from y.exceptions import (ContractNotVerified, MessedUpBrownieContract,
                          NonStandardERC20)
from y.networks import Network
from y.utils import _erc20, logging, raw_calls

if TYPE_CHECKING:
    from y.utils.events import Events

# Module-level logger; `_Loader` uses it for loader-task creation and failure messages.
logger = getLogger(__name__)

def hex_to_string(h: HexString) -> str:
    """Decode a null-padded ``bytes32``-style value into a UTF-8 string.

    Some tokens (e.g. MKR) return ``symbol``/``name`` as ``bytes32`` instead of
    ``string``. The payload is right-padded with zero bytes; this drops that
    padding and decodes what remains.

    Args:
        h: The raw value. Only its ``hex()`` method is used.

    Returns:
        The decoded text (an empty string if every byte was zero).

    Raises:
        UnicodeDecodeError: if the unpadded bytes are not valid UTF-8.
    """
    hex_digits = h.hex().rstrip("0")
    # Stripping '0' characters can also eat the low nibble of the last
    # meaningful byte (e.g. 0x50 -> "5"); append one '0' to restore it.
    if len(hex_digits) % 2 == 1:
        hex_digits = f"{hex_digits}0"
    raw = bytes.fromhex(hex_digits)
    return raw.decode("utf-8")
class ContractBase(a_sync.ASyncGenericBase, metaclass=ChecksumASyncSingletonMeta):
    """Base class for objects that wrap a single on-chain contract address.

    The address is normalized with `convert.to_address` on construction. Instance
    creation goes through `ChecksumASyncSingletonMeta` (presumably one instance per
    checksummed address/mode -- see that metaclass).
    """

    # defaults are stored as class vars to keep instance dicts smaller
    asynchronous: bool = False
    _deploy_block: Optional[int] = None

    __slots__ = "address",

    def __init__(
        self,
        address: AnyAddressType,
        asynchronous: bool = False,
        _deploy_block: Optional[int] = None,
    ) -> None:
        """
        Args:
            address: The contract address, in any form accepted by `convert.to_address`.
            asynchronous: Stored on the instance when truthy; consumed by the a_sync base class.
            _deploy_block: A deploy block that is already known, so `deploy_block` need not look it up.
        """
        self.address = convert.to_address(address)
        # Only write instance attributes when they differ from the class-level defaults.
        if asynchronous:
            self.asynchronous = asynchronous
        # `is not None` so that a known deploy block of 0 is still recorded.
        if _deploy_block is not None:
            self._deploy_block = _deploy_block
        super().__init__()

    def __str__(self) -> str:
        """Return the contract address."""
        return f'{self.address}'

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} '{self.address}'>"

    def __eq__(self, __o: object) -> bool:
        """Compare by address against ContractBase/Contract objects or address-like values.

        Anything that cannot be converted to an address compares unequal.
        """
        if isinstance(__o, (ContractBase, Contract)):
            return __o.address == self.address
        # Skip checksumming if applicable, it's computationally expensive.
        # NOTE: We assume a mixed-case address is already checksummed.
        elif isinstance(__o, str) and __o != __o.lower() and __o != __o.upper():
            return __o == self.address
        try:
            return convert.to_address(__o) == self.address
        except Exception:
            return False

    def __hash__(self) -> int:
        return hash(self.address)

    @property
    def contract(self) -> Contract:
        """Call `Contract(self.address)` on every access (any caching is up to `Contract`)."""
        return Contract(self.address)

    @cached_property
    def _is_cached(self) -> Optional[bool]:
        """Whether building `self.contract` succeeds.

        Returns:
            True if it succeeds, False if the contract is not verified, and None if
            the contract can't be found or brownie produced a broken object.
        """
        try:
            self.contract
            return True
        except ContractNotVerified:
            return False
        except (ContractNotFound, MessedUpBrownieContract):
            return None

    @a_sync.aka.cached_property
    @stuck_coro_debugger
    async def build_name(self) -> str:
        """The contract's build name, as returned by `y.contracts.build_name`."""
        return await build_name(self.address, sync=False)

    __build_name__: HiddenMethodDescriptor[Self, str]

    @stuck_coro_debugger
    async def deploy_block(self, when_no_history_return_0: bool = False) -> int:
        """Return the block this contract was deployed at, looking it up once and memoizing it.

        Args:
            when_no_history_return_0: Forwarded to `contract_creation_block_async`.
        """
        if self._deploy_block is None:
            self._deploy_block = await contract_creation_block_async(
                self.address,
                when_no_history_return_0=when_no_history_return_0,
            )
        return self._deploy_block

    deploy_block: ASyncBoundMethod[Self, Any, int]

    async def has_method(self, method: str, return_response: bool = False) -> Union[bool, Any]:
        """Check whether the contract implements `method` (delegates to `y.contracts.has_method`).

        Args:
            method: The method signature to probe.
            return_response: Forwarded to `y.contracts.has_method`.
        """
        return await has_method(self.address, method, return_response=return_response, sync=False)
class ERC20(ContractBase):
    """A `ContractBase` for an ERC20 token.

    Token metadata (`symbol`, `name`, `decimals`, `scale`) is exposed through a_sync
    properties. `symbol`, `name`, and `decimals` are read through `y._db.utils.token`.
    """

    def __repr__(self) -> str:
        """Include the symbol if it is cached or can be fetched synchronously right now."""
        cls = type(self).__name__
        with suppress(AttributeError):
            if ERC20.symbol.has_cache_value(self):
                symbol = ERC20.symbol.get_cache_value(self)
                return f"<{cls} {symbol} '{self.address}'>"
            elif not asyncio.get_event_loop().is_running() and not self.asynchronous:
                try:
                    return f"<{cls} {self.__symbol__(sync=True)} '{self.address}'>"
                except NonStandardERC20:
                    return f"<{cls} SYMBOL_INVALID '{self.address}'>"
        return f"<{cls} SYMBOL_NOT_LOADED '{self.address}'>"

    @a_sync.aka.cached_property
    @stuck_coro_debugger
    async def symbol(self) -> str:
        """The token's symbol.

        For `EEE_ADDRESS`, returns the current chain's native coin symbol (default "ETH").
        Otherwise the value is read from the token db; on a miss it is fetched
        on-chain via `_symbol` and written back to the db.

        Raises:
            NonStandardERC20: if the symbol cannot be fetched on-chain.
        """
        if self.address == EEE_ADDRESS:
            return {
                Network.Mainnet: "ETH",
                Network.Fantom: "FTM",
                Network.Polygon: "MATIC",
                Network.Arbitrum: "ETH",
                Network.Optimism: "ETH",
                Network.Base: "ETH",
            }.get(chain.id, "ETH")
        import y._db.utils.token as db
        if symbol := await db.get_symbol(self.address):
            return symbol
        symbol = await self._symbol()
        db.set_symbol(self.address, symbol)
        return symbol

    @a_sync.aka.property
    @stuck_coro_debugger
    async def name(self) -> str:
        """The token's name.

        Returns "Ethereum" for `EEE_ADDRESS`. Otherwise the value is read from the token db;
        on a miss it is fetched on-chain via `_name` and written back to the db.

        Raises:
            NonStandardERC20: if the name cannot be fetched on-chain.
        """
        if self.address == EEE_ADDRESS:
            return "Ethereum"
        import y._db.utils.token as db
        name = await db.get_name(self.address)
        if name:
            return name
        name = await self._name()
        db.set_name(self.address, name)
        return name

    @a_sync.aka.cached_property
    @stuck_coro_debugger
    async def decimals(self) -> int:
        """The token's decimals: 18 for `EEE_ADDRESS`, otherwise `db.get_decimals`."""
        if self.address == EEE_ADDRESS:
            return 18
        import y._db.utils.token as db
        return await db.get_decimals(self.address)

    @a_sync.a_sync  # Override the leading underscore so a_sync lib doesn't bypass this fn
    async def _decimals(self, block: Optional[Block] = None) -> int:
        '''used to fetch decimals at specific block'''
        if self.address == EEE_ADDRESS:
            return 18
        return await _erc20.decimals(self.address, block=block, sync=False)

    @a_sync.aka.cached_property
    @stuck_coro_debugger
    async def scale(self) -> int:
        """``10 ** decimals`` -- divide a raw amount by this to get a human-readable amount."""
        return 10 ** await self.__decimals__(asynchronous=True)

    @a_sync.a_sync  # Override the leading underscore so a_sync lib doesn't bypass this fn
    async def _scale(self, block: Optional[Block] = None) -> int:
        """``10 ** decimals``, with decimals read at `block` via `_decimals`."""
        return 10 ** await self._decimals(block, sync=False)

    async def total_supply(self, block: Optional[Block] = None) -> int:
        """
        Get the total supply of the token.

        Args:
            block (optional): The block number to query. Defaults to latest block.

        Returns:
            The total supply of the token.
        """
        return await _erc20.totalSupply(self.address, block=block, sync=False)

    async def total_supply_readable(self, block: Optional[Block] = None) -> float:
        """The total supply at `block` divided by `scale`."""
        total_supply, scale = await asyncio.gather(
            self.total_supply(block=block, sync=False),
            # call the hidden method explicitly, as `balance_of_readable` does
            self.__scale__(asynchronous=True),
        )
        return total_supply / scale

    async def balance_of(self, address: AnyAddressType, block: Optional[Block] = None) -> int:
        """The raw token balance of `address` at `block`."""
        return await raw_calls.balanceOf(self.address, address, block=block, sync=False)

    async def balance_of_readable(self, address: AnyAddressType, block: Optional[Block] = None) -> float:
        """The token balance of `address` at `block` divided by `scale`."""
        balance, scale = await asyncio.gather(
            self.balance_of(address, block=block, asynchronous=True),
            self.__scale__(asynchronous=True),
        )
        return balance / scale

    async def price(
        self,
        block: Optional[Block] = None,
        return_None_on_failure: bool = False,
        skip_cache: bool = ENVS.SKIP_CACHE,
        ignore_pools: Tuple[Pool, ...] = (),
    ) -> Optional[UsdPrice]:
        """The token's USD price at `block`, via `y.prices.magic.get_price`.

        Args:
            block: The block to price at.
            return_None_on_failure: Forwarded to `get_price` as `fail_to_None`.
            skip_cache: Forwarded to `get_price`.
            ignore_pools: Forwarded to `get_price`.
        """
        from y.prices.magic import get_price
        return await get_price(
            self.address,
            block=block,
            fail_to_None=return_None_on_failure,
            skip_cache=skip_cache,
            ignore_pools=ignore_pools,
            sync=False,
        )

    async def _symbol(self) -> str:
        """Fetch the symbol on-chain, trying common string getters then a bytes32 getter."""
        return await self._probe_text("symbol", ["symbol()(string)", "SYMBOL()(string)", "getSymbol()(string)"])

    async def _name(self) -> str:
        """Fetch the name on-chain, trying common string getters then a bytes32 getter."""
        return await self._probe_text("name", ["name()(string)", "NAME()(string)", "getName()(string)"])

    async def _probe_text(self, fn_name: str, string_signatures: "list[str]") -> str:
        """Probe string-returning getters, falling back to ``{fn_name}()(bytes32)``.

        Raises:
            NonStandardERC20: if no probe yields a non-empty value.
        """
        value = await probe(self.address, string_signatures)
        if value is None:
            # Sometimes the above will fail if the method returns bytes32, as with MKR. Try that.
            value = await probe(self.address, [f"{fn_name}()(bytes32)"])
            if value:
                value = hex_to_string(value)
        if value:
            return value
        # we've failed to fetch
        self.__raise_exception(fn_name)

    def __raise_exception(self, fn_name: str) -> NoReturn:
        raise NonStandardERC20(f'''
        Unable to fetch `{fn_name}` for {self.address} on {Network.printable()}
        If the contract is verified, please check to see if it has a strangely named
        `{fn_name}` method and create an issue on https://github.com/BobTheBuidler/ypricemagic
        with the contract address and correct method name so we can keep things going smoothly :)''') from None

    # These dundermethods are created by a_sync for the async_properties on this class
    __symbol__: HiddenMethodDescriptor[Self, str]
    __name__: HiddenMethodDescriptor[Self, str]
    __decimals__: HiddenMethodDescriptor[Self, int]
    __scale__: HiddenMethodDescriptor[Self, int]
@final
class WeiBalance(a_sync.ASyncGenericBase):
    """A raw (smallest-unit) amount of a token, optionally pinned to a block.

    Supports comparison with ints and same-token WeiBalances, plus arithmetic between
    compatible WeiBalances (same token, block, `_skip_cache`, and `_ignore_pools`).
    It also provides `readable`, `price`, and `value_usd` as a_sync properties.
    """

    # defaults are stored as class vars to keep instance dicts smaller
    block: Optional[Block] = None
    asynchronous: bool = False
    _skip_cache: bool = ENVS.SKIP_CACHE
    _ignore_pools: Tuple[Pool, ...] = ()

    def __init__(
        self,
        balance: int,
        token: AnyAddressType,
        block: Optional[Block] = None,
        *,
        skip_cache: bool = ENVS.SKIP_CACHE,
        ignore_pools: Tuple[Pool, ...] = (),
        asynchronous: bool = False,
    ) -> None:
        """
        Args:
            balance: The raw amount; stored as a Decimal.
            token: The token, as anything whose `str()` is its address.
            block: The block the balance refers to.
            skip_cache: Forwarded to pricing.
            ignore_pools: Forwarded to pricing.
            asynchronous: a_sync mode for this instance and its `ERC20` token.
        """
        # Only write instance attributes when they differ from the class-level defaults.
        if asynchronous != self.asynchronous:
            self.asynchronous = asynchronous
        self.balance = Decimal(balance)
        self.token = ERC20(str(token), asynchronous=self.asynchronous)
        if block != self.block:
            self.block = block
        super().__init__()
        if skip_cache != self._skip_cache:
            self._skip_cache = skip_cache
        if ignore_pools != self._ignore_pools:
            self._ignore_pools = ignore_pools

    def __str__(self) -> str:
        return str(self.balance)

    def __repr__(self) -> str:
        return f"<WeiBalance token={self.token} balance={self.balance} block={self.block}>"

    def __hash__(self) -> int:
        return hash((self.balance, self.token, self.block, self._skip_cache, self._ignore_pools))

    def __bool__(self) -> bool:
        return bool(self.balance)

    def __eq__(self, __o: object) -> bool:
        """Equal to an int with the same raw balance, or to a WeiBalance with all fields equal."""
        if isinstance(__o, int):
            return __o == self.balance
        elif isinstance(__o, WeiBalance):
            return (
                self.balance == __o.balance
                and self.token == __o.token
                and self.block == __o.block
                and self._skip_cache == __o._skip_cache
                and self._ignore_pools == __o._ignore_pools
            )
        return False

    def __lt__(self, __o: object) -> bool:
        """Compare raw balances. A WeiBalance operand must have the same token (else ValueError)."""
        if isinstance(__o, int):
            return self.balance < __o
        elif isinstance(__o, WeiBalance):
            if self.token != __o.token:
                raise ValueError(f"'<' only supported between {self.__class__.__name__} instances when denominated in the same token.") from None
            return self.balance < __o.balance
        raise TypeError(f"'<' not supported between instances of '{self.__class__.__name__}' and '{__o.__class__.__name__}'") from None

    def __ge__(self, __o: object) -> bool:
        """Compare raw balances. A WeiBalance operand must have the same token (else ValueError)."""
        if isinstance(__o, int):
            return self.balance >= __o
        elif isinstance(__o, WeiBalance):
            if self.token != __o.token:
                raise ValueError(f"'>=' only supported between {self.__class__.__name__} instances when denominated in the same token.") from None
            return self.balance >= __o.balance
        raise TypeError(f"'>=' not supported between instances of '{self.__class__.__name__}' and '{__o.__class__.__name__}'") from None

    def _check_compatible(self, __o: "WeiBalance", op: str) -> None:
        """Raise ValueError unless `__o` matches this balance's token, block, and pricing options.

        Raises AttributeError if `__o` lacks those attributes; callers convert it to TypeError.
        """
        if self.token != __o.token:
            raise ValueError(f"{op} not supported between instances of '{self.__class__.__name__}' and '{__o.__class__.__name__}'") from None
        if self.block != __o.block:
            raise ValueError(f"{op} not supported between balances at different block heights") from None
        if self._skip_cache != __o._skip_cache:
            raise ValueError(f"{op} not supported between balances with different `_skip_cache` values") from None
        if self._ignore_pools != __o._ignore_pools:
            raise ValueError(f"{op} not supported between balances with different `_ignore_pools` values") from None

    def _with_balance(self, balance: Decimal) -> "WeiBalance":
        """Build a WeiBalance of `balance` with this instance's token, block, options, and mode."""
        return WeiBalance(
            balance,
            self.token,
            self.block,
            skip_cache=self._skip_cache,
            ignore_pools=self._ignore_pools,
            asynchronous=self.asynchronous,
        )

    def __radd__(self, __o: Union["WeiBalance", Literal[0]]) -> "WeiBalance":
        # `0 + self` returns self, so `sum()` works with its default int start value.
        if __o == 0:
            return self
        try:
            self._check_compatible(__o, "addition")
            return self._with_balance(self.balance + __o.balance)
        except AttributeError:
            raise TypeError(f"right addition not supported between instances of '{type(self).__name__}' and '{type(__o).__name__}'") from None

    def __add__(self, __o: "WeiBalance") -> "WeiBalance":
        try:
            self._check_compatible(__o, "addition")
            return self._with_balance(self.balance + __o.balance)
        except AttributeError:
            raise TypeError(f"addition not supported between instances of '{type(self).__name__}' and '{type(__o).__name__}'") from None

    def __sub__(self, __o: "WeiBalance") -> "WeiBalance":
        try:
            self._check_compatible(__o, "subtraction")
            return self._with_balance(self.balance - __o.balance)
        except AttributeError:
            raise TypeError(f"subtraction not supported between instances of '{type(self).__name__}' and '{type(__o).__name__}'") from None

    def __mul__(self, __o: Union[int, float, Decimal]) -> "WeiBalance":
        if not isinstance(__o, (int, float, Decimal)):
            raise TypeError(f"multiplication not supported between instances of '{type(self).__name__}' and '{type(__o).__name__}'") from None
        return self._with_balance(self.balance * Decimal(__o))

    def __truediv__(self, __o: Union[int, float, Decimal]) -> "WeiBalance":
        if not isinstance(__o, (int, float, Decimal)):
            raise TypeError(f"division not supported between instances of '{type(self).__name__}' and '{type(__o).__name__}'") from None
        return self._with_balance(self.balance / Decimal(__o))

    @a_sync.aka.property
    async def readable(self) -> Decimal:
        """The balance divided by the token's `scale`; zero balances skip the scale lookup."""
        if self.balance == 0:
            return Decimal(0)
        scale = await self.token.__scale__
        readable = self.balance / scale
        self._logger.debug("balance: %s decimals: %s readable: %s", self.balance, str(scale).count("0"), readable)
        return readable

    __readable__: HiddenMethodDescriptor[Self, Decimal]

    @a_sync.aka.property
    async def price(self) -> Decimal:
        """The token's USD price at `self.block`, using this balance's pricing options."""
        price = Decimal(await self.token.price(block=self.block, skip_cache=self._skip_cache, ignore_pools=self._ignore_pools, sync=False))
        self._logger.debug("balance: %s price: %s", self, price)
        return price

    __price__: HiddenMethodDescriptor[Self, Decimal]

    @a_sync.aka.property
    async def value_usd(self) -> Decimal:
        """`readable * price`; zero balances skip both lookups."""
        if self.balance == 0:
            return Decimal(0)
        balance, price = await asyncio.gather(self.__readable__, self.__price__)
        value = balance * price
        self._logger.debug("balance: %s price: %s value: %s", balance, price, value)
        return value

    __value_usd__: HiddenMethodDescriptor[Self, Decimal]

    @cached_property
    def _logger(self) -> logging.logging.Logger:
        return logging.get_price_logger(self.token.address, self.block, extra=self.__class__.__name__)
class _Loader(ContractBase):
    """Used for use cases where you need to load data thru present time before proceeding, and then continue loading data in the background."""

    # "__tb" was previously assigned without a slot; private names here are mangled like attributes.
    __slots__ = "_loaded", "_init_block", "__exc", "__tb", "__task",

    def __init__(self, address: Address, asynchronous: bool = False):
        super().__init__(address, asynchronous=asynchronous)
        # the block height when this loader was created; "loaded" means loaded thru here
        self._init_block = auto_retry(web3.eth.get_block_number)()
        self._loaded = None
        self.__exc = None
        self.__tb = None
        self.__task = None

    def __await__(self) -> Generator[Any, None, Literal[True]]:
        """Returns `True` once the `_Loader` has loaded all relevant data thru the current block"""
        return self.loaded.__await__()

    @abc.abstractmethod
    async def _load(self) -> NoReturn:
        """
        `self._load` is the coro that will run in the daemon task associated with this _Loader.
        Your implementation MUST set Event `self._loaded` once data has been loaded thru the current block, or it will hang indefinitely.
        """

    @property
    def loaded(self) -> Awaitable[Literal[True]]:
        """Returns `True` once the `_Loader` has loaded all relevant data thru the current block"""
        self._task  # ensure task is running and not err'd
        if self._loaded is None:
            self._loaded = a_sync.Event(name=self)
        return self._loaded.wait()

    @property
    def _task(self) -> "asyncio.Task[NoReturn]":
        """The task that runs `self._load()` for this `_Loader`, created lazily on first access.

        Raises:
            A fresh copy of the exception that previously killed the load task, if any.
        """
        if self.__exc:
            # create a new duplicate exc instead of building a massive traceback on the original
            raise type(self.__exc)(*self.__exc.args).with_traceback(self.__tb)
        if self.__task is None:
            logger.debug("creating loader task for %s", self)
            self.__task = asyncio.create_task(coro=self.__load(), name=f"{self}.__load()")
            self.__task.add_done_callback(self._done_callback)
        return self.__task

    def _done_callback(self, task: "asyncio.Task[Any]") -> None:
        """called on `self._task` when it completes, if applicable"""
        # Task.exception() raises CancelledError on a cancelled task, so check that first.
        if task.cancelled():
            logger.debug("loader task for %s was cancelled", self)
            return
        if e := task.exception():
            # No exception is being handled here, so logger.exception() would log no traceback;
            # pass the exception explicitly instead.
            logger.error("exception while loading %s: %s", self, e, exc_info=e)
            self.__task = None

    async def __load(self) -> NoReturn:
        """loads the loader and catches any exceptions"""
        try:
            await self._load()
        except Exception as e:
            import traceback
            # remembered so `_task` can re-raise a copy on later access
            self.__exc = e
            self.__tb = e.__traceback__
            # no need to hold vars in memory
            traceback.clear_frames(self.__tb)
            raise


class _EventsLoader(_Loader):
    """Used for use cases where you need to load event data thru present time before proceeding, and then continue loading data in the background."""

    @property
    @abc.abstractmethod
    def _events(self) -> "Events":
        """The `Events` object whose data this loader consumes."""
        ...

    @property
    def loaded(self) -> Awaitable[Literal[True]]:
        """A future that resolves once `self._events` has been loaded thru `self._init_block`."""
        self._task  # ensure task is running and not err'd
        if self._loaded is None:
            self._loaded = asyncio.ensure_future(self._events._lock.wait_for(self._init_block))
        return self._loaded

    async def _load(self) -> NoReturn:
        """Drive `self._events.events(self._init_block)`, discarding each item it yields."""
        # TODO: extend this for constant loading
        async for _ in self._events.events(self._init_block):
            pass