# Source code for dao_treasury.streams.llamapay

import asyncio
import datetime as dt
import decimal
from logging import getLogger
from typing import (
    Awaitable,
    Callable,
    Dict,
    Final,
    Iterator,
    List,
    Optional,
    Set,
    final,
)

import dank_mids
import pony.orm
from a_sync import AsyncThreadPoolExecutor, igather
from brownie.network.event import _EventItem
from eth_typing import BlockNumber, ChecksumAddress, HexAddress, HexStr
from tqdm.asyncio import tqdm_asyncio

import y
from y.time import NoBlockFound, UnixTimestamp
from y.utils.events import decode_logs, get_logs_asap

from dao_treasury import constants
from dao_treasury.db import (
    Stream,
    StreamedFunds,
    Address,
    Token,
    must_sort_outbound_txgroup_dbid,
)
from dao_treasury._wallet import TreasuryWallet


logger: Final = getLogger(__name__)

# --- time constants ---------------------------------------------------------
_UTC: Final = dt.timezone.utc
_ONE_DAY: Final = 24 * 60 * 60  # number of seconds in one day

# Executor used by the helpers in this module to run blocking DB / contract
# work off the event loop; the ``1`` is presumably the worker count.
_STREAMS_THREAD: Final = AsyncThreadPoolExecutor(1)

# --- short local aliases for frequently used callables and classes -----------
create_task: Final = asyncio.create_task
sleep: Final = asyncio.sleep

datetime: Final = dt.datetime
timedelta: Final = dt.timedelta
fromtimestamp: Final = dt.datetime.fromtimestamp
now: Final = dt.datetime.now

Decimal: Final = decimal.Decimal

ObjectNotFound: Final = pony.orm.ObjectNotFound
commit: Final = pony.orm.commit
db_session: Final = pony.orm.db_session

Contract: Final = y.Contract
Network: Final = y.Network
get_block_at_timestamp: Final = y.get_block_at_timestamp
get_price: Final = y.get_price

networks: Final = [Network.Mainnet]

# Per-chain addresses of the DAI and YFI stream factories. On a chain with no
# entry the lookup yields None, and that factory is left out of ``factories``.
dai_stream_factory = {
    Network.Mainnet: "0x60c7B0c5B3a4Dc8C690b074727a17fF7aA287Ff2",
}.get(constants.CHAINID)
yfi_stream_factory = {
    Network.Mainnet: "0xf3764eC89B1ad20A31ed633b1466363FAc1741c4",
}.get(constants.CHAINID)

# Factory addresses deployed on the current chain, DAI first, then YFI.
factories: List[HexAddress] = [
    factory for factory in (dai_stream_factory, yfi_stream_factory) if factory
]


def _generate_dates(
    start: dt.datetime, end: dt.datetime, stop_at_today: bool = True
) -> Iterator[dt.datetime]:
    """Yield *start*, *start* + 1 day, *start* + 2 days, ... while below *end*.

    When *stop_at_today* is true, iteration also ends as soon as the next
    candidate's date is later than today's UTC date. *start* itself is always
    yielded whenever ``start < end``.
    """
    one_day = timedelta(days=1)
    day = start
    while day < end:
        yield day
        day = day + one_day
        if not stop_at_today:
            continue
        if day.date() > now(_UTC).date():
            return


# Shape of the bound ``streamToStart`` contract call: (stream id, block) -> int.
_StreamToStart = Callable[[HexStr, Optional[BlockNumber]], Awaitable[int]]

# stream id -> ``streamToStart`` coroutine of the contract holding that stream.
_streamToStart_cache: Final[Dict[HexStr, _StreamToStart]] = {}


def _get_streamToStart(stream_id: HexStr) -> _StreamToStart:
    """Return (and memoize) the ``streamToStart`` coroutine for *stream_id*.

    On a cache miss, the stream's contract is read from the database inside a
    ``db_session``. The ``streamToStart.coroutine`` attribute of that contract
    is then stored in ``_streamToStart_cache``.
    """
    cached = _streamToStart_cache.get(stream_id)
    if cached:
        return cached
    with db_session:
        contract: y.Contract = Stream[stream_id].contract.contract  # type: ignore [misc]
    method = contract.streamToStart.coroutine
    _streamToStart_cache[stream_id] = method
    return method


async def _get_start_timestamp(
    stream_id: HexStr, block: Optional[BlockNumber] = None
) -> int:
    """Return the contract's ``streamToStart`` value for *stream_id* as an int.

    The call is made at *block* (``None`` is passed straight through as the
    block identifier). A ``0x`` prefix is prepended here, so *stream_id* is
    expected to be bare hex.
    """
    fn = _streamToStart_cache.get(stream_id)
    if fn is None:
        # Cache miss: resolving the contract reads the DB, so do it on the
        # streams thread rather than the event loop.
        fn = await _STREAMS_THREAD.run(_get_streamToStart, stream_id)
    raw = await fn(f"0x{stream_id}", block_identifier=block)  # type: ignore [call-arg]
    return int(raw)


def _pause_stream(stream_id: HexStr) -> None:
    """Load stream *stream_id* and call its ``pause()`` inside a db_session."""
    with db_session:
        stream = Stream[stream_id]  # type: ignore [misc]
        stream.pause()


def _stop_stream(stream_id: HexStr, block: BlockNumber) -> None:
    """Call ``Stream.stop_stream(block)`` on stream *stream_id* in a db_session.

    *stream_id* is annotated as ``HexStr`` to match the sibling helpers
    (``_pause_stream``, ``_get_start_timestamp``). ``HexStr`` is a ``str``
    subtype, so existing callers passing plain ``str`` values still work.
    """
    with db_session:
        Stream[stream_id].stop_stream(block)  # type: ignore [misc]


# block number -> UNIX timestamp. Filled lazily by _get_block_timestamp and
# never evicted.
_block_timestamps: Final[Dict[BlockNumber, UnixTimestamp]] = {}


async def _get_block_timestamp(block: BlockNumber) -> UnixTimestamp:
    """Return the timestamp of *block*, memoized in ``_block_timestamps``.

    The cache hit test uses ``is not None`` rather than truthiness. A stored
    timestamp of ``0`` (e.g. a genesis block) is therefore served from the
    cache instead of being refetched on every call.
    """
    cached = _block_timestamps.get(block)
    if cached is not None:
        return cached
    timestamp = await dank_mids.eth.get_block_timestamp(block)
    _block_timestamps[block] = timestamp
    return timestamp


"""
class _StreamProcessor(ABC):
    @abstractmethod
    async def _load_streams(self) -> None:
        ...
"""


@final
class LlamaPayProcessor:
    """
    Async processor that syncs LlamaPay stream events and daily streamed funds.

    Construction takes no arguments. It binds a contract object for every
    address in the module-level ``factories`` list.

    :meth:`process_streams` first syncs those contracts' stream events into
    the database. It then records one ``StreamedFunds`` row per stream per
    day.

    An event is only recorded when its ``from`` address passes
    ``TreasuryWallet.check_membership`` at the event's block. Wallets are
    therefore filtered by the period during which they belong to the treasury.
    """

    # Factory events that are synced into the database.
    handled_events: Final = (
        "StreamCreated",
        "StreamCreatedWithReason",
        "StreamModified",
        "StreamPaused",
        "StreamCancelled",
    )
    # Known factory events that are deliberately ignored. Any event name found
    # in neither tuple makes ``_load_contract_events`` raise.
    skipped_events: Final = (
        "PayerDeposit",
        "PayerWithdraw",
        "Withdraw",
    )
    # Minimum pause (seconds) between retries while no block can be found for
    # a timestamp. Without a floor the computed delay is <= 0 once the target
    # time has passed, and the retry loop would spin on the RPC.
    _NO_BLOCK_RETRY_SECONDS: Final = 5.0

    def __init__(self) -> None:
        # One contract object per factory deployment on this chain.
        self.stream_contracts: Final = {Contract(addr) for addr in factories}

    async def _get_streams(self) -> None:
        """Sync events from every factory contract concurrently."""
        await igather(
            self._load_contract_events(stream_contract)
            for stream_contract in self.stream_contracts
        )

    @staticmethod
    def _is_treasury_event(
        from_address: ChecksumAddress, event: _EventItem
    ) -> bool:
        """Return whether *from_address* is a treasury wallet at *event*'s block."""
        return TreasuryWallet.check_membership(from_address, event.block_number)

    async def _load_contract_events(self, stream_contract: y.Contract) -> None:
        """Fetch all logs of *stream_contract* and apply them to the database.

        Raises:
            NotImplementedError: If the contract emitted an event name that is
                in neither ``handled_events`` nor ``skipped_events``.
        """
        events = decode_logs(
            await get_logs_asap(stream_contract.address, None, sync=False)
        )
        keys: Set[str] = set(events.keys())
        for k in keys:
            if k not in self.handled_events and k not in self.skipped_events:
                raise NotImplementedError(f"Need to handle event: {k}")

        # NOTE(review): events are applied grouped by type (creations, then
        # modifications, pauses, cancellations), not in global chronological
        # order. Confirm this is intended for streams that are paused and
        # later re-created under the same id.

        # Creation events. The payer's Address row is ensured first because
        # ``_get_stream`` resolves it by dbid.
        for name in ("StreamCreated", "StreamCreatedWithReason"):
            if name not in keys:
                continue
            for event in events[name]:
                from_address, *_ = event.values()
                from_address = Address.get_or_insert(from_address).address
                if not self._is_treasury_event(from_address, event):
                    continue
                await _STREAMS_THREAD.run(self._get_stream, event)

        if "StreamModified" in keys:
            for event in events["StreamModified"]:
                from_address, _, _, old_stream_id, *_ = event.values()
                if not self._is_treasury_event(from_address, event):
                    continue
                # A modification replaces the stream: stop the old id at this
                # block, then fetch or create the row for the new id.
                await _STREAMS_THREAD.run(
                    _stop_stream, old_stream_id.hex(), event.block_number
                )
                await _STREAMS_THREAD.run(self._get_stream, event)

        if "StreamPaused" in keys:
            for event in events["StreamPaused"]:
                from_address, *_, stream_id = event.values()
                if not self._is_treasury_event(from_address, event):
                    continue
                await _STREAMS_THREAD.run(_pause_stream, stream_id.hex())

        if "StreamCancelled" in keys:
            for event in events["StreamCancelled"]:
                from_address, *_, stream_id = event.values()
                if not self._is_treasury_event(from_address, event):
                    continue
                await _STREAMS_THREAD.run(
                    _stop_stream, stream_id.hex(), event.block_number
                )

    def _get_stream(self, log: _EventItem) -> Stream:
        """Return the Stream row for a creation/modification *log*, creating it.

        For ``StreamModified`` the new stream inherits the old stream's
        ``reason``.

        Raises:
            NotImplementedError: If *log* is not one of ``StreamCreated``,
                ``StreamCreatedWithReason`` or ``StreamModified``.
        """
        with db_session:
            if log.name == "StreamCreated":
                from_address, to_address, amount_per_second, stream_id = log.values()
                reason = None
            elif log.name == "StreamCreatedWithReason":
                from_address, to_address, amount_per_second, stream_id, reason = (
                    log.values()
                )
            elif log.name == "StreamModified":
                (
                    from_address,
                    _,
                    _,
                    old_stream_id,
                    to_address,
                    amount_per_second,
                    stream_id,
                ) = log.values()
                reason = Stream[old_stream_id.hex()].reason  # type: ignore [misc]
            else:
                raise NotImplementedError("This is not an appropriate event log.")

            stream_id_hex = stream_id.hex()
            try:
                return Stream[stream_id_hex]  # type: ignore [misc]
            except ObjectNotFound:
                entity = Stream(
                    stream_id=stream_id_hex,
                    contract=Address.get_dbid(log.address),
                    start_block=log.block_number,
                    token=Token.get_dbid(Contract(log.address).token()),
                    from_address=Address.get_dbid(from_address),
                    to_address=Address.get_dbid(to_address),
                    amount_per_second=amount_per_second,
                    txgroup=must_sort_outbound_txgroup_dbid,
                )
                if reason is not None:
                    entity.reason = reason
                commit()
                return entity

    def streams_for_recipient(
        self, recipient: ChecksumAddress, at_block: Optional[BlockNumber] = None
    ) -> List[Stream]:
        """Return streams paying *recipient*.

        If *at_block* is given, streams whose ``end_block`` is before it are
        excluded. Only ``end_block`` is compared; ``start_block`` is not.
        """
        with db_session:
            streams = Stream.select(lambda s: s.to_address.address == recipient)
            if at_block is None:
                return list(streams)
            return [
                s for s in streams if (s.end_block is None or at_block <= s.end_block)
            ]

    def streams_for_token(
        self, token: ChecksumAddress, include_inactive: bool = False
    ) -> List[Stream]:
        """Return streams denominated in *token*.

        Only streams whose ``is_alive`` is truthy are returned unless
        *include_inactive* is set.
        """
        with db_session:
            streams = Stream.select(lambda s: s.token.address.address == token)
            return (
                list(streams)
                if include_inactive
                else [s for s in streams if s.is_alive]
            )

    async def process_streams(self, run_forever: bool = False) -> None:
        """Sync stream events, then process every known stream concurrently."""
        logger.info("Processing stream events and streamed funds...")
        # Always sync events before processing
        await self._get_streams()
        with db_session:
            streams = [s.stream_id for s in Stream.select()]
        await tqdm_asyncio.gather(
            *(
                self.process_stream(stream_id, run_forever=run_forever)
                for stream_id in streams
            ),
            desc="LlamaPay Streams",
        )

    async def process_stream(
        self, stream_id: HexStr, run_forever: bool = False
    ) -> None:
        """Record StreamedFunds for each day of *stream_id*'s life.

        Days come from ``_generate_dates``; days after today are skipped
        unless *run_forever* is set. Processing stops at the first day for
        which :meth:`process_stream_for_date` returns ``None``.
        """
        start, end = await _STREAMS_THREAD.run(Stream._get_start_and_end, stream_id)
        for date_obj in _generate_dates(start, end, stop_at_today=not run_forever):
            if await self.process_stream_for_date(stream_id, date_obj) is None:
                return

    async def _wait_for_block_at(self, check_at: dt.datetime) -> BlockNumber:
        """Return the block at *check_at*, waiting until such a block exists.

        Sleeps until *check_at* if it is in the future, then retries
        ``get_block_at_timestamp`` on ``NoBlockFound``. Each retry waits at
        least ``_NO_BLOCK_RETRY_SECONDS``.
        """
        if check_at > now(tz=_UTC):
            await sleep((check_at - now(tz=_UTC)).total_seconds())
        while True:
            try:
                return await get_block_at_timestamp(check_at, sync=False)
            except NoBlockFound:
                sleep_time = max(
                    (check_at - now(tz=_UTC)).total_seconds(),
                    self._NO_BLOCK_RETRY_SECONDS,
                )
                logger.debug(
                    "no block found for %s, sleeping %ss", check_at, sleep_time
                )
                await sleep(sleep_time)

    async def process_stream_for_date(
        self, stream_id: HexStr, date_obj: dt.datetime
    ) -> Optional[StreamedFunds]:
        """Return the StreamedFunds row for *stream_id* on *date_obj*'s day.

        An existing row is returned as-is. Otherwise the stream is inspected
        at the last block of the day (``date_obj`` + 1 day - 1 second), and a
        new row is created from the token price at that block and the
        number of seconds the stream was active that day.

        Returns ``None`` when the stream's start timestamp is 0 at that block
        and ``Stream.check_closed`` reports the stream closed.
        """
        entity = await _STREAMS_THREAD.run(
            StreamedFunds.get_entity, stream_id, date_obj
        )
        if entity:
            return entity

        stream_token, start_date = await _STREAMS_THREAD.run(
            Stream._get_token_and_start_date, stream_id
        )
        # Final second of the day that begins at date_obj.
        check_at = date_obj + timedelta(days=1) - timedelta(seconds=1)
        block = await self._wait_for_block_at(check_at)

        # Fetch the price concurrently with the start-timestamp lookups below.
        price_fut = create_task(get_price(stream_token, block, sync=False))
        try:
            start_timestamp = await _get_start_timestamp(stream_id, block)
            if start_timestamp == 0:
                if await _STREAMS_THREAD.run(Stream.check_closed, stream_id):
                    price_fut.cancel()
                    return None
                # Stream ended during this day: walk back block by block to the
                # last block at which it still had a start timestamp.
                while start_timestamp == 0:
                    block -= 1
                    start_timestamp = await _get_start_timestamp(stream_id, block)
                block_datetime = fromtimestamp(
                    await _get_block_timestamp(block), tz=_UTC
                )
                assert block_datetime.date() == date_obj.date()
                seconds_active = (check_at - block_datetime).seconds
                is_last_day = True
            else:
                seconds_active = int(check_at.timestamp()) - start_timestamp
                is_last_day = False
        except BaseException:
            # Never leave the price task running (with its outcome unobserved)
            # when a lookup above fails or this coroutine is cancelled.
            price_fut.cancel()
            raise

        seconds_active_today = min(seconds_active, _ONE_DAY)
        # A partial day that is neither the stream's start day nor its last day
        # is counted as a full day. Presumably this is because the on-chain
        # start timestamp can move forward mid-stream — TODO confirm.
        if (
            seconds_active_today < _ONE_DAY
            and not is_last_day
            and date_obj.date() != start_date
        ):
            seconds_active_today = _ONE_DAY

        # No DB work happens while awaiting the price, so no db_session is held
        # here. Keeping one open across an await would let other coroutines on
        # this thread interleave inside the same session.
        price = Decimal(await price_fut)
        return await _STREAMS_THREAD.run(
            StreamedFunds.create_entity,
            stream_id,
            date_obj,
            price,
            seconds_active_today,
            is_last_day,
        )