# Source code for y.prices.stable_swap.curve

import asyncio
import logging
from collections import defaultdict
from enum import IntEnum
from functools import cached_property
from typing import Dict, List, Optional, Tuple, TypeVar

import a_sync
import brownie
from a_sync.a_sync import HiddenMethodDescriptor
from brownie import ZERO_ADDRESS
from brownie.convert.datatypes import EthAddress
from brownie.exceptions import ContractNotFound, EventLookupError
from brownie.network.event import _EventItem
from typing_extensions import Self
from web3.exceptions import ContractLogicError

from y import ENVIRONMENT_VARIABLES as ENVS
from y import convert
from y.classes.common import ERC20, WeiBalance, _EventsLoader, _Loader
from y.constants import CHAINID
from y.contracts import Contract, contract_creation_block_async
from y.datatypes import (
    Address,
    AddressOrContract,
    AnyAddressType,
    Block,
    Pool,
    UsdPrice,
    UsdValue,
)
from y.exceptions import (
    ContractNotVerified,
    MessedUpBrownieContract,
    PriceError,
    UnsupportedNetwork,
    call_reverted,
)
from y.interfaces.curve.CurveRegistry import CURVE_REGISTRY_ABI
from y.networks import Network
from y.utils import a_sync_ttl_cache
from y.utils.events import ProcessedEvents
from y.utils.multicall import multicall_same_func_same_contract_different_inputs
from y.utils.raw_calls import raw_call

T = TypeVar("T")

logger = logging.getLogger(__name__)

# Curve registry / address provider docs:
# https://curve.readthedocs.io/registry-address-provider.html
ADDRESS_PROVIDER = "0x0000000022D53366457F9d5E68Ec105046FC4383"
"Curve's address provider contract on all chains."

# pool address -> lp token address, restricted to the active chain (empty elsewhere).
DED_POOLS = {
    Network.Mainnet: {
        # pool: lp token
        "0x8301AE4fc9c624d1D396cbDAa1ed877821D7C511": "0xEd4064f376cB8d68F770FB1Ff088a3d0F3FF5c4d",
    },
}.get(CHAINID, {})
"The on chain registry no longer returns the lp token address for these dead pools, so we need to provide it manually."


class Ids(IntEnum):
    """Numeric identifiers used by Curve's address provider.

    The address provider's ``NewAddressIdentifier`` / ``AddressModified``
    events carry one of these values in their ``id`` field.
    """

    Main_Registry = 0
    PoolInfo_Getters = 1
    Exchanges = 2
    Metapool_Factory = 3
    Fee_Distributor = 4
    CryptoSwap_Registry = 5
    CryptoPool_Factory = 6
    # On Mainnet, id 7 is listed as "metafactory".
    # On Polygon, id 7 is listed as "cryptopool factory".
    # On other chains, "cryptopool factory" is id 6.
    # On Polygon, id 6 is "crypto factory".
    # I've only seen this on Mainnet and Polygon so far, for now will treat `7` == `6`.
    Cryptopool_Factory = 7
    crvUSD_Plain_Pools_deprecated_1 = 8
    crvUSD_Plain_Pools_deprecated_2 = 9
    crvUSD_Plain_Pools = 10
    Curve_Tricrypto_Factory = 11
    CurveStableswapFactoryNG = 12
# Type of the loader object handed to ``CurveEvents.__init__``.
_LT = TypeVar("_LT", bound=_Loader)


class _CurveEventsLoader(_EventsLoader):
    """Common base for the Curve loaders that are driven by a :class:`CurveEvents` instance."""

    # Subclasses assign a CurveEvents subclass here in their ``__init__``.
    _events: "CurveEvents"
class CurveEvents(ProcessedEvents[_EventItem]):
    """Base event processor for a Curve loader.

    Events are fetched for ``base.address``; ``base`` itself is kept so that
    subclasses can write processed results back onto the owning loader.
    """

    __slots__ = ("_base",)

    def __init__(self, base: _LT):
        owner_address = base.address
        super().__init__(addresses=owner_address)
        # the loader that owns this processor
        self._base = base
class AddressProviderEvents(CurveEvents):
    """Turns Curve address provider events into entries of ``AddressProvider.identifiers``."""

    @property
    def provider(self) -> "AddressProvider":
        """The :class:`AddressProvider` that owns these events."""
        return self._base

    @staticmethod
    def _identifier(raw_id: int) -> int:
        """Map a raw provider id to an :class:`Ids` member when one exists.

        Ids not (yet) present in :class:`Ids` are returned as plain ints instead
        of raising ``ValueError``, so a newly registered identifier on chain
        cannot abort event processing. Since :class:`Ids` is an ``IntEnum``,
        ``identifiers[Ids.X]`` and ``identifiers[int(Ids.X)]`` are the same key.
        """
        try:
            return Ids(raw_id)
        except ValueError:
            logger.debug("unknown curve address provider id %s", raw_id)
            return int(raw_id)

    def _process_event(self, event: _EventItem) -> _EventItem:
        """Record the nonzero address announced by ``event`` under its id.

        ``NewAddressIdentifier`` uses the ``addr`` field and ``AddressModified``
        uses ``new_address``; either way the address is appended, so the most
        recently processed address for an id is last in its list.

        Returns:
            The event, unchanged.
        """
        if event.name == "NewAddressIdentifier":
            addr = event["addr"]
        elif event.name == "AddressModified":
            addr = event["new_address"]
        else:
            addr = ZERO_ADDRESS
        if addr != ZERO_ADDRESS:
            key = self._identifier(event["id"])
            self.provider.identifiers[key].append(addr)
        logger.debug("%s loaded event %s at block %s", self, event, event.block_number)
        return event
class RegistryEvents(CurveEvents):
    """Turns Curve registry events into entries of the module-level ``curve`` registry."""

    __slots__ = ("_tasks",)

    def __init__(self, base: _LT):
        super().__init__(base)
        # lp-token lookups spawned by ``PoolAdded`` events; awaited in ``_set_lock``
        self._tasks: List["asyncio.Task[None]"] = []

    @property
    def registry(self) -> "Registry":
        """The :class:`Registry` that owns these events."""
        return self._base

    def _process_event(self, event: _EventItem) -> _EventItem:
        """Track pools added to / removed from the registry.

        On ``PoolAdded`` the pool is added to ``curve.registries[event.address]``
        and a task is scheduled to resolve its lp token; on ``PoolRemoved`` the
        pool is discarded from that set.

        Returns:
            The event, unchanged.
        """
        if event.name == "PoolAdded":
            try:
                pool = event["pool"]
            except EventLookupError:
                # some registry event ABIs name this field ``newPool``
                pool = event["newPool"]
            self._tasks.append(
                asyncio.create_task(
                    coro=self._add_pool(pool),
                    name=f"Registry._add_pool for pool {pool}",
                )
            )
            curve.registries[event.address].add(pool)
        elif event.name == "PoolRemoved":
            curve.registries[event.address].discard(event["pool"])
        logger.debug("%s loaded event %s at block %s", self, event, event.block_number)
        return event

    async def _add_pool(self, pool: Address) -> None:
        """Resolve ``pool``'s lp token and record it in ``curve.token_to_pool``."""
        if pool in DED_POOLS:
            # The on chain registry no longer returns the lp token address for these dead pools, so we need to provide it manually.
            lp_token = DED_POOLS[pool]
        else:
            lp_token = await self.registry.contract.get_lp_token.coroutine(pool)
        curve.token_to_pool[lp_token] = pool

    async def _set_lock(self, block: int) -> None:
        """Wait for pending lp-token lookups, then mark events loaded through ``block``."""
        await asyncio.gather(*self._tasks)
        self._tasks.clear()
        self._lock.set(block)
class AddressProvider(_CurveEventsLoader):
    """Loader for Curve's address provider and the factories it points to."""

    __slots__ = (
        "identifiers",
        "_events",
    )

    def __init__(self, address: Address, *, asynchronous: bool = False):
        super().__init__(address, asynchronous=asynchronous)
        # id -> nonzero addresses reported for that id, appended as events are processed
        self.identifiers = defaultdict(list)
        self._events = AddressProviderEvents(self)

    async def get_registry(self) -> EthAddress:
        """Return the address reported by the provider contract's ``get_registry`` view."""
        contract = await Contract.coroutine(self.address)
        # call through ``.coroutine()`` like every other contract call in this module
        return await contract.get_registry.coroutine()

    async def _load_factories(self) -> None:
        """Register pools from every known factory, then signal ``curve._done``.

        Metapool-style factories are read directly via :meth:`Factory.read_pools`;
        cryptopool factories are loaded by constructing and awaiting
        :class:`Factory` loaders.
        """
        # factory events are quite useless, so we use a different method
        if debug_logs := logger.isEnabledFor(logging.DEBUG):
            logger._log(logging.DEBUG, "loading pools from metapool factories", ())

        # TODO: remove this once curve adds to address provider
        if CHAINID == Network.Mainnet:
            self.identifiers[Ids.CurveStableswapFactoryNG] = [
                "0x6A8cbed756804B16E05E741eDaBd5cB544AE21bf"
            ]

        metapool_factories = [
            Factory(factory, asynchronous=self.asynchronous)
            for i in [
                Ids.Metapool_Factory,
                Ids.crvUSD_Plain_Pools,
                Ids.Curve_Tricrypto_Factory,
                Ids.CurveStableswapFactoryNG,
            ]
            for factory in self.identifiers[i]
        ]
        if metapool_factories:
            async for factory, pool_list in a_sync.map(
                Factory.read_pools, metapool_factories
            ):
                for pool in pool_list:
                    # for metapool factories pool is the same as lp token
                    curve.token_to_pool[pool] = pool
                    # key by address, matching Factory._load / Factory.__load_pool
                    curve.factories[factory.address].add(pool)

        # if there are factories that haven't yet been added to the on-chain address provider,
        # please refer to commit 3f70c4246615017d87602e03272b3ed18d594d3c to see how to add them manually
        if (
            identifiers := self.identifiers[Ids.CryptoPool_Factory]
            + self.identifiers[Ids.Cryptopool_Factory]
        ):
            if debug_logs:
                logger._log(
                    logging.DEBUG, "loading pools from cryptopool factories", ()
                )
            await a_sync.map(Factory, identifiers, asynchronous=self.asynchronous)

        if not curve._done.is_set():
            logger.info(
                "loaded %s pools from %s registries and %s factories",
                len(curve.token_to_pool),
                len(curve.registries),
                len(curve.factories),
            )
            curve._done.set()
class Registry(_CurveEventsLoader):
    """Loader for one Curve registry contract, populated by :class:`RegistryEvents`."""

    __slots__ = "_events"

    def __init__(
        self,
        address: Address,
        curve: "CurveRegistry",
        *,
        asynchronous: bool = False,
    ):
        # NOTE: ``curve`` is accepted from callers but not used by this constructor.
        super().__init__(address, asynchronous=asynchronous)
        processor = RegistryEvents(self)
        self._events = processor
class Factory(_Loader):
    """Loader for a Curve pool factory; registers its pools on the module-level ``curve``."""

    async def get_pool(self, i: int) -> EthAddress:
        """Return ``pool_list(i)`` from the factory contract."""
        factory_contract = await Contract.coroutine(self.address)
        return await factory_contract.pool_list.coroutine(i)

    async def pool_count(self, block: Optional[int] = None) -> int:
        """Return the factory's ``pool_count()`` at ``block`` (latest when None)."""
        factory_contract = await Contract.coroutine(self.address)
        return await factory_contract.pool_count.coroutine(block_identifier=block)

    async def read_pools(self) -> List[EthAddress]:
        """Return every pool address listed by the factory, in index order."""
        try:
            # load the contract async first so later lookups hit the cache
            await Contract.coroutine(self.address)
        except ContractNotVerified:
            if CHAINID != Network.xDai:
                # This happens sometimes, not sure why as the contract is verified.
                brownie.Contract.from_explorer(self.address)
            else:
                Contract.from_abi("Vyper_contract", self.address, CURVE_REGISTRY_ABI)
        total = await self.pool_count()
        return await a_sync.map(self.get_pool, range(total)).values(pop=True)

    async def _load(self) -> None:
        """Register every pool of this factory not yet present in ``curve.factories``."""
        listed_pools = await self.read_pools(sync=False)
        debug_logs = logger.isEnabledFor(logging.DEBUG)
        pending = []
        for pool in listed_pools:
            if pool not in curve.factories[self.address]:
                pending.append(self.__load_pool(pool, debug_logs))
        await asyncio.gather(*pending)
        self._loaded.set()
        if debug_logs:
            logger._log(logging.DEBUG, "loaded %s pools for %s", (len(listed_pools), self))

    async def __load_pool(self, pool: Address, debug_logs: bool) -> None:
        """Resolve ``pool``'s lp token via the factory and record the pool on ``curve``."""
        factory = await Contract.coroutine(self.address)
        # for curve v5 pools, pool and lp token are separate
        if hasattr(factory, "get_token"):
            lp_getter = factory.get_token
        elif hasattr(factory, "get_lp_token"):
            lp_getter = factory.get_lp_token
        else:
            raise NotImplementedError(
                f"New factory {factory.address} is not yet supported. Please notify a ypricemagic maintainer."
            )
        lp_token = await lp_getter.coroutine(pool)
        curve.token_to_pool[lp_token] = pool
        curve.factories[factory.address].add(pool)
        if debug_logs:
            logger._log(logging.DEBUG, "loaded %s pool %s", (self, pool))
class CurvePool(ERC20):
    """
    Represents a Curve pool.

    This class provides methods to interact with Curve pools, including fetching pool
    information, calculating token exchanges, and retrieving liquidity data.

    Note:
        This class inherits from :class:`~y.classes.common.ERC20` for convenience,
        but a Curve pool is not always an ERC20 token. This inheritance is used to
        leverage existing functionality for token interactions.

    Methods:
        - :meth:`coins`: Fetches the coins in the pool.
        - :meth:`get_dy`: Calculates the amount of output tokens for a given input amount, simulating a token exchange.
        - :meth:`get_balances`: Retrieves the balances of tokens in the pool.
        - :meth:`get_tvl`: Retrieves the total value locked in the pool.

    Examples:
        >>> pool = CurvePool("0x1234567890abcdef1234567890abcdef12345678")
        >>> await pool.coins
        [<ERC20 TKN1 '0x...'>, <ERC20 TKN2 '0x...'>]
    """

    @a_sync.aka.cached_property
    async def factory(self) -> Contract:
        """The factory contract that lists this pool, or None if no loaded factory does.

        Resolved through :meth:`CurveRegistry.get_factory` on the module-level ``curve``.
        """
        return await curve.get_factory(self, sync=False)

    __factory__: HiddenMethodDescriptor[Self, Contract]

    @a_sync.aka.cached_property
    async def coins(self) -> List[ERC20]:
        """
        Get coins of pool.

        Coins come from the pool's factory if it has one, otherwise from the main
        registry. If that source returns only zero addresses, ``coins(i)`` is
        multicalled on the pool itself for ``i`` in ``0..7``.

        Returns:
            A list of :class:`~y.classes.common.ERC20` tokens representing the coins in the pool.

        Examples:
            >>> pool = CurvePool("0x1234567890abcdef1234567890abcdef12345678")
            >>> await pool.coins
            [<ERC20 TKN1 '0x...'>, <ERC20 TKN2 '0x...'>]
        """
        factory = await self.__factory__
        if factory:
            coins = await factory.get_coins.coroutine(self.address)
        else:
            registry = await curve.__registry__
            coins = await registry.get_coins.coroutine(self.address)

        # pool not in registry
        if set(coins) == {ZERO_ADDRESS}:
            coins = await multicall_same_func_same_contract_different_inputs(
                self.address,
                "coins(uint256)(address)",
                inputs=list(range(8)),
                return_None_on_failure=True,
                sync=False,
            )

        return [
            ERC20(coin, asynchronous=self.asynchronous)
            for coin in coins
            if coin not in {None, ZERO_ADDRESS}
        ]

    __coins__: HiddenMethodDescriptor[Self, List[ERC20]]

    @a_sync.a_sync(ram_cache_maxsize=256)
    async def get_coin_index(self, coin: AnyAddressType) -> int:
        """Return the position of ``coin`` in :attr:`coins`.

        Raises:
            IndexError: if ``coin`` is not one of the pool's coins.
        """
        return [i for i, _coin in enumerate(await self.__coins__) if _coin == coin][0]

    @a_sync.aka.cached_property
    async def num_coins(self) -> int:
        """Number of coins in the pool."""
        return len(await self.__coins__)

    __num_coins__: HiddenMethodDescriptor[Self, int]

    async def get_dy(
        self,
        coin_ix_in: int,
        coin_ix_out: int,
        block: Optional[Block] = None,
        ignore_pools: Tuple[Pool, ...] = (),
        skip_cache: bool = ENVS.SKIP_CACHE,
    ) -> Optional[WeiBalance]:
        """Quote the pool's ``get_dy`` for swapping ``token_in.scale`` units of coin ``coin_ix_in``.

        Returns:
            A :class:`WeiBalance` of coin ``coin_ix_out`` (with this pool added to
            ``ignore_pools``), or None if the call reverts.
        """
        tokens = await self.__coins__
        token_in: ERC20 = tokens[coin_ix_in]
        token_out: ERC20 = tokens[coin_ix_out]
        amount_in = await token_in.__scale__
        contract = await Contract.coroutine(self.address)
        try:
            amount_out = await contract.get_dy.coroutine(
                coin_ix_in, coin_ix_out, amount_in, block_identifier=block
            )
            return WeiBalance(
                amount_out,
                token_out,
                block=block,
                ignore_pools=(*ignore_pools, self),
                skip_cache=skip_cache,
            )
        except Exception as e:
            if call_reverted(e):
                return None
            raise

    @a_sync.aka.cached_property
    async def coins_decimals(self) -> List[int]:
        """Nonzero decimals of each coin, from the factory/registry or each coin's ``decimals``."""
        factory = await self.__factory__
        source = factory or await curve.__registry__
        coins_decimals = await source.get_decimals.coroutine(self.address)

        # pool not in registry
        if not any(coins_decimals):
            coins_decimals = await a_sync.map(
                ERC20.decimals, await self.__coins__
            ).values(pop=True)

        return [dec for dec in coins_decimals if dec != 0]

    __coins_decimals__: HiddenMethodDescriptor[Self, List[int]]

    @a_sync.aka.cached_property
    async def get_underlying_coins(self) -> List[ERC20]:
        """Underlying coins of the pool, falling back to :attr:`coins` when none are reported."""
        factory = await self.__factory__

        if factory:
            # new factory reverts for non-meta pools.
            # ``is_meta`` is awaited via ``.coroutine`` so the event loop is not blocked.
            if not hasattr(factory, "is_meta") or await factory.is_meta.coroutine(
                self.address
            ):
                coins = await factory.get_underlying_coins.coroutine(self.address)
            else:
                coins = await factory.get_coins.coroutine(self.address)
        else:
            registry = await curve.__registry__
            coins = await registry.get_underlying_coins.coroutine(self.address)

        # pool not in registry, not checking for underlying_coins here
        if set(coins) == {ZERO_ADDRESS}:
            return await self.__coins__

        return [
            ERC20(coin, asynchronous=self.asynchronous)
            for coin in coins
            if coin != ZERO_ADDRESS
        ]

    __get_underlying_coins__: HiddenMethodDescriptor[Self, List[ERC20]]

    @a_sync.a_sync(ram_cache_maxsize=1000)
    async def get_balances(
        self, block: Optional[Block] = None, skip_cache: bool = ENVS.SKIP_CACHE
    ) -> List[WeiBalance]:
        """
        Get {token: balance} of liquidity in the pool.

        Args:
            block: The block number to query. Defaults to the latest block.
            skip_cache: If True, skip using the cache while fetching balance data.

        Returns:
            A list of :class:`~y.classes.common.WeiBalance` objects representing the balances of tokens in the pool.

        Raises:
            ValueError: if every fetched balance is falsy ("could not fetch balances ...").

        Examples:
            >>> pool = CurvePool("0x1234567890abcdef1234567890abcdef12345678")
            >>> await pool.get_balances()
            [<WeiBalance token=<ERC20 TKN1 '0x...'> balance=1000000000000000000 block=None>, ...]
        """
        coins = await self.__coins__

        try:
            factory = await self.__factory__
            source = factory or await curve.__registry__
            balances = await source.get_balances.coroutine(
                self.address, block_identifier=block
            )
        # ContractLogicError in web3>=6.0, ValueError in <6.0
        except (ContractLogicError, ValueError):
            # fallback for historical queries where registry was not yet deployed
            balances = await a_sync.map(
                self._get_balance, range(len(coins)), block=block
            ).values(pop=True)

        if not any(balances):
            raise ValueError(f"could not fetch balances {self.__str__()} at {block}")

        return [
            WeiBalance(balance, coin, block, skip_cache=skip_cache)
            for coin, balance in zip(coins, balances)
            if coin != ZERO_ADDRESS
        ]

    async def _get_balance(
        self, i: int, block: Optional[Block] = None
    ) -> Optional[int]:
        """Return the pool's ``balances(i)``, or None if unverified or the call reverts."""
        try:
            contract = await Contract.coroutine(self.address)
        except ContractNotVerified:
            # TODO: figure out if we need to build this, usually they get verified quickly
            return None
        try:
            return await contract.balances.coroutine(i, block_identifier=block)
        except ContractLogicError as e:
            # happens on web3py>=6.0
            if str(e) == "execution reverted":
                return None
            raise
        except ValueError as e:
            # happens on web3py<6.0
            if str(e) == "No data was returned - the call likely reverted":
                return None
            raise

    async def get_tvl(
        self, block: Optional[Block] = None, skip_cache: bool = ENVS.SKIP_CACHE
    ) -> Optional[UsdValue]:
        """
        Get total value in Curve pool.

        Args:
            block: The block number to query. Defaults to the latest block.
            skip_cache: If True, skip using the cache while fetching TVL data.

        Returns:
            The total value locked in the pool as a :class:`~y.datatypes.UsdValue`, or None if the TVL cannot be determined.

        Examples:
            >>> pool = CurvePool("0x1234567890abcdef1234567890abcdef12345678")
            >>> await pool.get_tvl()
            UsdValue(1234567.89)
        """
        try:
            price = await WeiBalance.value_usd.sum(
                self.get_balances(block=block, skip_cache=skip_cache, sync=False),
                sync=False,
            )
            return UsdValue(price)
        except ValueError as e:
            if str(e).startswith("could not fetch balances "):
                return None
            raise

    @a_sync.a_sync(ram_cache_maxsize=100_000, ram_cache_ttl=60 * 60)
    async def check_liquidity(self, token: Address, block: Optional[Block]) -> int:
        """Return the pool's raw balance of ``token`` at ``block`` (0 if none or not yet deployed).

        ``block=None`` means latest; the deployment check is skipped in that case
        because ``None`` cannot be compared with a block number.
        """
        if block is not None and block < await contract_creation_block_async(
            self.address
        ):
            return 0
        index = await self.get_coin_index(token, sync=False)
        if balance := await self._get_balance(index, block):
            return balance
        return 0
class CurveRegistry(a_sync.ASyncGenericSingleton):
    """Singleton that discovers Curve registries, factories and pools, and prices Curve lp tokens."""

    __slots__ = ("__task",)

    def __init__(self, *, asynchronous: bool = False) -> None:
        """Create the address provider loader.

        Raises:
            UnsupportedNetwork: if the address provider contract is not usable on this chain.
        """
        super().__init__()
        self.asynchronous = asynchronous
        try:
            self.address_provider = AddressProvider(
                ADDRESS_PROVIDER, asynchronous=self.asynchronous
            )
            # touch ``contract`` so an unsupported network fails here, at construction
            self.address_provider.contract
        except (ContractNotFound, ContractNotVerified) as e:
            raise UnsupportedNetwork("curve is not supported on this network") from e
        except MessedUpBrownieContract as e:
            if CHAINID == Network.Cronos:
                raise UnsupportedNetwork(
                    "curve is not supported on this network"
                ) from e
            else:
                raise

        self.registries = defaultdict(set)  # registry -> pools
        self.factories = defaultdict(set)  # factory -> pools
        self.pools = set()
        self.token_to_pool = dict()  # lp_token -> pool

    def __repr__(self) -> str:
        return "<CurveRegistry>"

    @property
    def identifiers(self) -> Dict[Ids, List[EthAddress]]:
        """The address provider's id -> addresses mapping (a ``defaultdict(list)``)."""
        return self.address_provider.identifiers

    @a_sync.aka.cached_property
    async def registry(self) -> Contract:
        """The main registry: latest address from provider events, else ``get_registry()``."""
        try:
            return await Contract.coroutine(self.identifiers[Ids.Main_Registry][-1])
        except IndexError:  # if we couldn't get the registry via logs
            return await Contract.coroutine(
                await raw_call(
                    self.address_provider,
                    "get_registry()",
                    output="address",
                    sync=False,
                )
            )

    __registry__: HiddenMethodDescriptor[Self, Contract]

    async def load_all(self) -> None:
        """Wait until the first full load of registries and factories has completed."""
        await self._done.wait()

    async def get_factory(self, pool: AddressOrContract) -> Optional[Contract]:
        """
        Get metapool factory that has spawned a pool.

        Args:
            pool: The address or contract of the pool.

        Returns:
            The :class:`~y.contracts.Contract` representing the factory, or None
            if no loaded factory lists the pool.

        Examples:
            >>> factory = await curve.get_factory("0x1234567890abcdef1234567890abcdef12345678")
            >>> print(factory)
            <Contract '0x...'>
        """
        try:
            factory = next(
                factory
                for factory, factory_pools in self.factories.items()
                if str(pool) in factory_pools
            )
            return await Contract.coroutine(factory)
        except StopIteration:
            return None

    @a_sync_ttl_cache
    async def get_price(
        self,
        token: Address,
        block: Optional[Block] = None,
        skip_cache: bool = ENVS.SKIP_CACHE,
    ) -> Optional[float]:
        """Price a Curve lp token as pool TVL divided by the token's readable total supply."""
        pool: CurvePool = await self.get_pool(token, sync=False)
        if pool is None:
            return None

        tvl = await pool.get_tvl(block=block, skip_cache=skip_cache, sync=False)
        if tvl is None:
            return None

        return tvl / await ERC20(token, asynchronous=True).total_supply_readable(block)

    @a_sync.a_sync(cache_type="memory")
    async def get_pool(self, token: AnyAddressType) -> Optional[CurvePool]:
        """
        Get Curve pool (swap) address by LP token address. Supports factory pools.

        Args:
            token: The address of the LP token.

        Returns:
            The :class:`~y.prices.stable_swap.curve.CurvePool` associated with the
            LP token, or None if the token is not a known Curve lp token.

        Examples:
            >>> pool = await curve.get_pool("0x1234567890abcdef1234567890abcdef12345678")
            >>> print(pool)
            <CurvePool '0x...'>
        """
        await self.load_all()
        token = await convert.to_address_async(token)
        if token in self.token_to_pool and token != ZERO_ADDRESS:
            return CurvePool(self.token_to_pool[token], asynchronous=self.asynchronous)
        return None

    @a_sync.a_sync(cache_type="memory")
    async def get_price_for_underlying(
        self,
        token_in: Address,
        block: Optional[Block] = None,
        ignore_pools: Tuple[Pool, ...] = (),
        skip_cache: bool = ENVS.SKIP_CACHE,
    ) -> Optional[UsdPrice]:
        """Price ``token_in`` by quoting a swap through a 2-coin factory pool containing it.

        When several eligible pools exist, the one with the largest balance of
        ``token_in`` at ``block`` is used. Returns None when no usable pool exists,
        the chosen pool does not have exactly 2 coins, the swap reverts, or the
        output token cannot be priced.
        """
        try:
            cached_pools = (await self.__coin_to_pools__)[token_in]
        except KeyError:
            return None

        # ``cached_pools`` is the list stored inside the cached ``coin_to_pools``
        # mapping. Build a filtered copy instead of calling ``remove`` on it,
        # which would drop the ignored pools for every later caller as well.
        pools = [
            pool
            for pool in cached_pools
            if not any(pool == ignored for ignored in ignore_pools)
        ]

        if pools and block is not None:
            pools = [
                pool
                async for pool, deploy_block in CurvePool.deploy_block.map(
                    pools, when_no_history_return_0=True
                )
                if deploy_block <= block
            ]

        if not pools:
            return None

        # Choose a pool to use for pricing `token_in`.
        elif len(pools) == 1:
            pool = pools[0]
        else:
            # Use the pool with deepest liquidity.
            deepest_pool, deepest_bal = None, 0
            async for pool, depth in CurvePool.check_liquidity.map(
                pools, token=token_in, block=block
            ).map():
                if depth > deepest_bal:
                    deepest_pool = pool
                    deepest_bal = depth
            pool = deepest_pool

        if pool is None:
            return None

        if len(await pool.__coins__) != 2:
            # TODO: handle this sitch if necessary
            return None

        # Get the price for `token_in` using the selected pool.
        # this works for most typical metapools
        token_in_ix = await pool.get_coin_index(token_in, sync=False)
        token_out_ix = 0 if token_in_ix == 1 else 1 if token_in_ix == 0 else None

        dy: Optional[WeiBalance] = await pool.get_dy(
            token_in_ix,
            token_out_ix,
            block=block,
            ignore_pools=ignore_pools,
            skip_cache=skip_cache,
            sync=False,
        )
        if dy is None:
            return None

        try:
            return await dy.__value_usd__
        except PriceError as e:
            logger.debug("%s for %s at block %s", e.__class__.__name__, token_in, block)
            return None

    @a_sync.aka.cached_property
    async def coin_to_pools(self) -> Dict[str, List[CurvePool]]:
        """Map each coin to the factory pools that contain it (computed once, after loading)."""
        mapping = defaultdict(set)
        await self.load_all()
        for pool in {
            CurvePool(pool) for pools in self.factories.values() for pool in pools
        }:
            for coin in await pool.__coins__:
                mapping[coin].add(pool)
        return {coin: list(pools) for coin, pools in mapping.items()}

    __coin_to_pools__: HiddenMethodDescriptor[Self, Dict[str, List[CurvePool]]]

    async def check_liquidity(
        self, token: Address, block: Block, ignore_pools: Tuple[Pool, ...]
    ) -> int:
        """Return the largest ``token`` balance among factory pools holding it, excluding ``ignore_pools``."""
        if pools := [
            pool
            for pool in (await self.__coin_to_pools__).get(token, [])
            if pool not in ignore_pools
        ]:
            return await CurvePool.check_liquidity.max(
                pools, token=token, block=block, sync=False
            )
        return 0

    @cached_property
    def _done(self) -> a_sync.Event:
        """A helper function to ensure the Event is attached to the correct loop.

        Accessing it also starts the background loader task (``_task``).
        """
        self._task
        return a_sync.Event(name="curve")

    @cached_property
    def _task(self) -> asyncio.Task:
        """Background task running :meth:`_load_all`, created on first access."""
        logger.debug("creating loader task for %s", self)
        task = asyncio.create_task(coro=self._load_all(), name=f"{self}._load_all()")

        def done_callback(t: asyncio.Task) -> None:
            # ``t.exception()`` raises CancelledError on a cancelled task
            if t.cancelled():
                return
            if e := t.exception():
                # there is no active exception inside a done callback, so
                # ``logger.exception`` would log "NoneType: None"; pass it explicitly
                logger.error("exception while loading %s: %s", self, e, exc_info=e)
                self.__task = None
                raise e

        task.add_done_callback(done_callback)
        return task

    async def _load_all(self) -> None:
        """Load provider events, then (re)load registries and factories every 600 seconds."""
        await self.address_provider
        logger.debug(
            "curve address provider events loaded, now loading factories and pools"
        )

        # NOTE: Gnosis chain's address provider fails to provide registry via events. Maybe other chains as well.
        if (
            not self.identifiers[Ids.Main_Registry]
            and (registry := await self.address_provider.get_registry(sync=False))
            != ZERO_ADDRESS
        ):
            self.identifiers[Ids.Main_Registry] = [registry]

        while True:
            # Check if any registries were updated, then ensure all old and new are loaded
            if registries := [
                self.identifiers[i][-1]
                for i in [Ids.Main_Registry, Ids.CryptoSwap_Registry]
                if self.identifiers[i]
            ]:
                await a_sync.map(
                    Registry, registries, curve=self, asynchronous=self.asynchronous
                )

            # load metapool and curve v5 factories
            await self.address_provider._load_factories()

            await asyncio.sleep(600)
# Module-wide Curve singleton. On chains where Curve is unsupported it is
# replaced by an empty (falsy) set as a placeholder.
try:
    curve: CurveRegistry = CurveRegistry(asynchronous=True)
except UnsupportedNetwork:
    curve = set()