Example #1
0
def process_transaction_signatures(
    solana_client_manager: SolanaClientManager,
    db: SessionManager,
    redis: Redis,
    transaction_signatures: List[List[str]],
):
    """Concurrently processes the transactions to update the DB state for reward transfer instructions

    Each batch is fetched and parsed on a pool of 5 threads, then written in a
    single scoped DB session via ``process_batch_sol_reward_manager_txs``.

    Args:
        solana_client_manager: used to fetch and parse each transaction.
        db: session manager; one scoped session is opened per batch.
        redis: forwarded to the batch processor and used to cache the latest tx.
        transaction_signatures: batches of signatures, processed in list order.

    Returns:
        The parsed instruction whose ``tx_sig`` equals
        ``transaction_signatures[-1][0]`` (also cached in redis), or None when
        that signature produced no parsed instruction or there were no
        signatures.

    Raises:
        Any exception raised while fetching/parsing a transaction is logged and
        re-raised.
    """
    # The first signature of the final batch is tracked as the latest tx.
    # NOTE(review): assumes the caller orders batches oldest-first and the
    # signatures inside a batch newest-first — confirm against the caller.
    last_tx_sig: Optional[str] = None
    last_tx: Optional[RewardManagerTransactionInfo] = None
    if transaction_signatures and transaction_signatures[-1]:
        last_tx_sig = transaction_signatures[-1][0]

    for tx_sig_batch in transaction_signatures:
        logger.info(f"index_rewards_manager.py | processing {tx_sig_batch}")
        batch_start_time = time.time()

        # Collected in completion order, not in tx_sig_batch order
        transfer_instructions: List[RewardManagerTransactionInfo] = []
        # Process each batch in parallel
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            parse_sol_tx_futures = {
                executor.submit(
                    fetch_and_parse_sol_rewards_transfer_instruction,
                    solana_client_manager,
                    tx_sig,
                ): tx_sig
                for tx_sig in tx_sig_batch
            }
            for future in concurrent.futures.as_completed(
                    parse_sol_tx_futures):
                try:
                    # A None result is skipped; anything else is recorded
                    # for the DB update below
                    parsed_solana_transfer_instruction = future.result()
                    if parsed_solana_transfer_instruction is not None:
                        transfer_instructions.append(
                            parsed_solana_transfer_instruction)
                        if (last_tx_sig and last_tx_sig ==
                                parsed_solana_transfer_instruction["tx_sig"]):
                            last_tx = parsed_solana_transfer_instruction
                except Exception as exc:
                    logger.error(f"index_rewards_manager.py | {exc}")
                    # Bare raise re-raises with the original traceback intact
                    raise
        with db.scoped_session() as session:
            process_batch_sol_reward_manager_txs(session,
                                                 transfer_instructions, redis)
        batch_end_time = time.time()
        batch_duration = batch_end_time - batch_start_time
        logger.info(
            f"index_rewards_manager.py | processed batch {len(tx_sig_batch)} txs in {batch_duration}s"
        )

    if last_tx:
        cache_latest_sol_rewards_manager_db_tx(
            redis,
            {
                "signature": last_tx["tx_sig"],
                "slot": last_tx["slot"],
                "timestamp": last_tx["timestamp"],
            },
        )
    return last_tx
Example #2
0
def get_should_update_trending(db: SessionManager, web3: Web3, redis: Redis,
                               interval_seconds: int) -> Optional[int]:
    """
    Checks if the trending job should re-run based off the last trending run's timestamp and
    the most recently indexed block's timestamp.
    If the most recently indexed block (rounded down to the nearest interval) is `interval_seconds`
    ahead of the last trending job run, then the job should re-run.
    Returns an int (the floored block timestamp, in seconds) if the job should re-run, else None.
    Also returns None when no current block has been indexed yet.
    """
    with db.scoped_session() as session:
        current_db_block = (session.query(
            Block.blockhash).filter(Block.is_current == True).first())

    # No indexed block yet, so there is no block timestamp to compare against
    # (previously this raised a TypeError on ``None[0]``).
    if current_db_block is None:
        return None

    # The RPC call happens after the session is released so a DB connection
    # is not held open during network I/O; only a plain column value is used.
    current_block = web3.eth.get_block(current_db_block[0], True)
    current_timestamp = current_block["timestamp"]
    block_datetime = floor_time(datetime.fromtimestamp(current_timestamp),
                                interval_seconds)

    last_trending_datetime = get_last_trending_datetime(redis)
    if not last_trending_datetime:
        # Trending has never run (or its marker is missing): run now
        return int(block_datetime.timestamp())

    duration_since_last_index = block_datetime - last_trending_datetime
    if duration_since_last_index.total_seconds() >= interval_seconds:
        return int(block_datetime.timestamp())

    return None
Example #3
0
def db_mock(monkeypatch):
    """Swap the read-replica getter for an in-memory SQLite session manager.

    Patches ``src.utils.db_session.get_db_read_replica`` (via ``monkeypatch``)
    so it returns a fresh ``SessionManager("sqlite://", {})``, and returns
    that same manager to the caller.
    """
    in_memory_db = SessionManager("sqlite://", {})
    monkeypatch.setattr(
        src.utils.db_session,
        "get_db_read_replica",
        lambda: in_memory_db,
    )
    return in_memory_db
def process_related_artists_queue(db: SessionManager, redis: Redis):
    """Drain user ids from the related-artists Redis queue, recalculating scores where needed.

    Stops when the queue pop returns an empty value, or once 10 users have
    actually been updated (skipped users do not count toward that limit).
    All checks share a single scoped DB session.
    """
    needed_update_count = 0
    with db.scoped_session() as session:
        while needed_update_count < 10:
            # Named to avoid shadowing the builtin ``next``
            queued_value = redis.lpop(INDEX_RELATED_ARTIST_REDIS_QUEUE)
            if not queued_value:
                break
            # Kept separate from the loop condition so that a user_id of 0 no
            # longer ends the drain early (the int used to double as sentinel).
            user_id = int(queued_value)
            logger.debug(
                f"index_related_artists.py | Checking user_id={user_id} for related artists recalculation..."
            )
            needed_update, reason = update_related_artist_scores_if_needed(
                session, user_id)
            if needed_update:
                logger.info(
                    f"index_related_artists.py | Updated related artists for user_id={user_id}"
                )
                needed_update_count += 1
            else:
                logger.info(
                    f"index_related_artists.py | Skipped updating user_id={user_id} reason={reason}"
                )
Example #5
0
def configure_celery(flask_app, celery, test_config=None):
    """Configure ``celery`` and install a resource-carrying base task class.

    Sets the task imports, beat schedule, JSON serialization and Redis broker;
    creates one DB session manager, IPFS client and Redis connection; clears
    task locks left in Redis; and sets ``celery.Task`` to a subclass exposing
    those resources (plus the module-level ``web3``, ``abi_values``,
    ``eth_web3`` and ``solana_client``) as read-only properties.
    ``flask_app`` is not used in this body.
    """
    db_config = shared_config["db"]
    database_url = db_config["url"]
    engine_args_literal = ast.literal_eval(db_config["engine_args_literal"])
    redis_url = shared_config["redis"]["url"]

    # A test config may override only the database url
    if (test_config is not None and "db" in test_config
            and "url" in test_config["db"]):
        database_url = test_config["db"]["url"]

    discprov_config = shared_config["discprov"]
    ipld_interval = int(discprov_config["blacklist_block_indexing_interval"])
    # default is 5 seconds
    indexing_interval_sec = int(
        discprov_config["block_processing_interval_sec"])

    task_modules = [
        "src.tasks.index",
        "src.tasks.index_blacklist",
        "src.tasks.index_plays",
        "src.tasks.index_metrics",
        "src.tasks.index_materialized_views",
        "src.tasks.index_network_peers",
        "src.tasks.index_trending",
        "src.tasks.cache_user_balance",
        "src.monitors.monitoring_queue",
        "src.tasks.cache_trending_playlists",
        "src.tasks.index_solana_plays",
        "src.tasks.index_aggregate_views",
    ]

    # Every beat entry is named after the task it triggers
    schedule_by_task = {
        "update_discovery_provider": timedelta(seconds=indexing_interval_sec),
        "update_ipld_blacklist": timedelta(seconds=ipld_interval),
        "update_play_count": timedelta(seconds=60),
        "update_metrics": crontab(minute=0, hour="*"),
        "aggregate_metrics": timedelta(minutes=METRICS_INTERVAL),
        "synchronize_metrics": crontab(minute=0, hour=1),
        "update_materialized_views": timedelta(seconds=300),
        "update_network_peers": timedelta(seconds=30),
        "index_trending": crontab(minute=15, hour="*"),
        "update_user_balances": timedelta(seconds=60),
        "monitoring_queue": timedelta(seconds=60),
        "cache_trending_playlists": timedelta(minutes=30),
        "index_solana_plays": timedelta(seconds=5),
        "update_aggregate_user": timedelta(seconds=30),
        "update_aggregate_track": timedelta(seconds=30),
        "update_aggregate_playlist": timedelta(seconds=30),
    }
    beat_schedule = {
        task_name: {"task": task_name, "schedule": schedule}
        for task_name, schedule in schedule_by_task.items()
    }

    # Update celery configuration
    celery.conf.update(
        imports=task_modules,
        beat_schedule=beat_schedule,
        task_serializer="json",
        accept_content=["json"],
        broker_url=redis_url,
    )

    # Initialize DB object for celery task context
    db = SessionManager(database_url, engine_args_literal)
    logger.info('Database instance initialized!')
    # Initialize IPFS client for celery task context
    ipfs_client = IPFSClient(shared_config["ipfs"]["host"],
                             shared_config["ipfs"]["port"])

    # Initialize Redis connection
    redis_inst = redis.Redis.from_url(url=redis_url)
    # Clear existing locks used in tasks if present
    stale_locks = (
        "disc_prov_lock",
        "network_peers_lock",
        "materialized_view_lock",
        "update_metrics_lock",
        "update_play_count_lock",
        "ipld_blacklist_lock",
        "update_discovery_lock",
        "aggregate_metrics_lock",
        "synchronize_metrics_lock",
    )
    for lock_name in stale_locks:
        redis_inst.delete(lock_name)
    logger.info('Redis instance initialized!')

    # Initialize custom task context with database object
    class DatabaseTask(Task):
        def __init__(self, *args, **kwargs):
            self._db = db
            self._web3_provider = web3
            self._abi_values = abi_values
            self._shared_config = shared_config
            self._ipfs_client = ipfs_client
            self._redis = redis_inst
            self._eth_web3_provider = eth_web3
            self._solana_client = solana_client

        @property
        def abi_values(self):
            return self._abi_values

        @property
        def web3(self):
            return self._web3_provider

        @property
        def db(self):
            return self._db

        @property
        def shared_config(self):
            return self._shared_config

        @property
        def ipfs_client(self):
            return self._ipfs_client

        @property
        def redis(self):
            return self._redis

        @property
        def eth_web3(self):
            return self._eth_web3_provider

        @property
        def solana_client(self):
            return self._solana_client

    celery.autodiscover_tasks(["src.tasks"], "index", True)

    # Subclassing celery task with discovery provider context
    # Provided through properties defined in 'DatabaseTask'
    celery.Task = DatabaseTask

    celery.finalize()
Example #6
0
def configure_flask(test_config, app, mode="app"):
    """Configure ``app``: ini config, JSON encoder, database creation,
    session managers, exception handlers and blueprints. Returns ``app``.

    ``mode`` is accepted for signature compatibility but not used in this body.
    """
    with app.app_context():
        app.iniconfig.read(config_files)

    # custom JSON serializer for timestamps
    class TimestampJSONEncoder(JSONEncoder):
        # pylint: disable=E0202
        def default(self, o):
            if isinstance(o, datetime.datetime):
                # ISO-8601-style timestamp format (note the space before "Z")
                return o.strftime("%Y-%m-%dT%H:%M:%S Z")
            return JSONEncoder.default(self, o)

    app.json_encoder = TimestampJSONEncoder

    database_url = app.config["db"]["url"]
    if test_config is not None:
        if "db" in test_config:
            if "url" in test_config["db"]:
                database_url = test_config["db"]["url"]

    # Sometimes ECS latency causes the create_database function to fail because db connection is not ready
    # Give it some more time to get set up, up to 5 times.
    # NOTE: if every attempt fails to connect, startup continues anyway.
    for _attempt in range(5):
        try:
            # Create database if necessary
            if not database_exists(database_url):
                create_database(database_url)
            # The database exists now; no need to re-check on another pass
            break
        except exc.OperationalError as e:
            if "could not connect to server" not in str(e):
                raise
            logger.warning(
                "DB connection isn't up yet...setting a temporary timeout and trying again"
            )
            time.sleep(10)

    if test_config is not None:
        # load the test config if passed in
        app.config.update(test_config)

    app.db_session_manager = SessionManager(
        app.config["db"]["url"],
        ast.literal_eval(app.config["db"]["engine_args_literal"]),
    )

    app.db_read_replica_session_manager = SessionManager(
        app.config["db"]["url_read_replica"],
        ast.literal_eval(app.config["db"]["engine_args_literal"]),
    )

    register_exception_handlers(app)
    app.register_blueprint(queries.bp)
    app.register_blueprint(search.bp)
    app.register_blueprint(search_queries.bp)
    app.register_blueprint(notifications.bp)
    app.register_blueprint(health_check.bp)
    app.register_blueprint(block_confirmation.bp)

    app.register_blueprint(api_v1.bp)
    app.register_blueprint(api_v1.bp_full)

    return app
Example #7
0
def enqueue_trending_challenges(db: SessionManager, redis: Redis,
                                challenge_bus: ChallengeEventBus,
                                date: datetime):
    """Dispatch trending challenge events for tracks, underground tracks and playlists.

    For every registered strategy version of each trending type, computes the
    weekly ranking and dispatches events on ``challenge_bus`` (inside its
    scoped dispatch queue), tagged with the latest block number and ``date``.
    Returns early, without dispatching, if the latest block number is
    unavailable.
    """
    logger.info(
        "calculate_trending_challenges.py | Start calculating trending challenges"
    )
    started_at = time.time()
    with db.scoped_session() as session, \
            challenge_bus.use_scoped_dispatch_queue():
        latest_blocknumber = get_latest_blocknumber_via_redis(session, redis)
        if latest_blocknumber is None:
            logger.error(
                "calculate_trending_challenges.py | Unable to get latest block number"
            )
            return

        week = "week"

        # Trending tracks: only the top TRENDING_LIMIT entries are ranked
        for track_version in trending_strategy_factory.get_versions_for_type(
                TrendingType.TRACKS).keys():
            track_strategy = trending_strategy_factory.get_strategy(
                TrendingType.TRACKS, track_version)
            ranked_tracks = _get_trending_tracks_with_session(
                session, {"time": week}, track_strategy)[:TRENDING_LIMIT]
            dispatch_trending_challenges(
                challenge_bus,
                ChallengeEvent.trending_track,
                latest_blocknumber,
                ranked_tracks,
                track_version,
                date,
                TrendingType.TRACKS,
            )

        # Cache underground trending
        for underground_version in trending_strategy_factory.get_versions_for_type(
                TrendingType.UNDERGROUND_TRACKS).keys():
            underground_strategy = trending_strategy_factory.get_strategy(
                TrendingType.UNDERGROUND_TRACKS, underground_version)
            underground_args: GetUndergroundTrendingTrackcArgs = {
                "offset": 0,
                "limit": TRENDING_LIMIT,
            }
            ranked_underground = _get_underground_trending_with_session(
                session, underground_args, underground_strategy, False)
            dispatch_trending_challenges(
                challenge_bus,
                ChallengeEvent.trending_underground,
                latest_blocknumber,
                ranked_underground,
                underground_version,
                date,
                TrendingType.UNDERGROUND_TRACKS,
            )

        # Trending playlists: one event per playlist, ranked from 1
        for playlist_version in trending_strategy_factory.get_versions_for_type(
                TrendingType.PLAYLISTS).keys():
            playlist_strategy = trending_strategy_factory.get_strategy(
                TrendingType.PLAYLISTS, playlist_version)
            playlists_args: GetTrendingPlaylistsArgs = {
                "limit": TRENDING_LIMIT,
                "offset": 0,
                "time": week,
            }
            ranked_playlists = _get_trending_playlists_with_session(
                session, playlists_args, playlist_strategy, False)
            for rank, playlist in enumerate(ranked_playlists, start=1):
                owner_id = playlist["playlist_owner_id"]
                challenge_bus.dispatch(
                    ChallengeEvent.trending_playlist,
                    latest_blocknumber,
                    owner_id,
                    {
                        "id": playlist["playlist_id"],
                        "user_id": owner_id,
                        "rank": rank,
                        "type": str(TrendingType.PLAYLISTS),
                        "version": str(playlist_version),
                        "week": date_to_week(date),
                    },
                )

    elapsed = time.time() - started_at
    logger.info(
        f"calculate_trending_challenges.py | Finished calculating trending in {elapsed} seconds"
    )
Example #8
0
def configure_flask(test_config, app, mode="app"):
    """Configure ``app``: ini config, JSON encoder, database creation, optional
    alembic upgrade, session managers, exception handlers and blueprints.

    When ``mode == "app"`` the database is upgraded to alembic ``head`` using
    ``alembic.ini`` from the current working directory. Returns ``app``.
    """
    with app.app_context():
        app.iniconfig.read(config_files)

    # custom JSON serializer for timestamps
    class TimestampJSONEncoder(JSONEncoder):
        # pylint: disable=E0202
        def default(self, o):
            if isinstance(o, datetime.datetime):
                # ISO-8601-style timestamp format (note the space before "Z")
                return o.strftime("%Y-%m-%dT%H:%M:%S Z")
            return JSONEncoder.default(self, o)

    app.json_encoder = TimestampJSONEncoder

    database_url = app.config["db"]["url"]
    if test_config is not None:
        if "db" in test_config:
            if "url" in test_config["db"]:
                database_url = test_config["db"]["url"]

    # Sometimes ECS latency causes the create_database function to fail because db connection is not ready
    # Give it some more time to get set up, up to 5 times.
    # NOTE: if every attempt fails to connect, startup continues anyway.
    for _attempt in range(5):
        try:
            # Create database if necessary
            if not database_exists(database_url):
                create_database(database_url)
            # The database exists now; no need to re-check on another pass
            break
        except exc.OperationalError as e:
            if "could not connect to server" not in str(e):
                raise
            logger.warning(
                "DB connection isn't up yet...setting a temporary timeout and trying again"
            )
            time.sleep(10)

    # Conditionally perform alembic database upgrade to HEAD during
    # flask initialization
    if mode == "app":
        alembic_dir = os.getcwd()
        alembic_config = alembic.config.Config(f"{alembic_dir}/alembic.ini")
        alembic_config.set_main_option("sqlalchemy.url", str(database_url))
        with helpers.cd(alembic_dir):
            alembic.command.upgrade(alembic_config, "head")

    if test_config is not None:
        # load the test config if passed in
        app.config.update(test_config)

    app.db_session_manager = SessionManager(
        app.config["db"]["url"],
        ast.literal_eval(app.config["db"]["engine_args_literal"]),
    )
    with app.db_session_manager.scoped_session() as session:
        set_search_similarity(session)

    app.db_read_replica_session_manager = SessionManager(
        app.config["db"]["url_read_replica"],
        ast.literal_eval(app.config["db"]["engine_args_literal"]),
    )
    with app.db_read_replica_session_manager.scoped_session() as session:
        set_search_similarity(session)

    exceptions.register_exception_handlers(app)
    app.register_blueprint(queries.bp)
    app.register_blueprint(trending.bp)
    app.register_blueprint(search.bp)
    app.register_blueprint(search_queries.bp)
    app.register_blueprint(notifications.bp)
    app.register_blueprint(health_check.bp)

    app.register_blueprint(api_v1.bp)
    app.register_blueprint(api_v1.bp_full)

    return app
Example #9
0
def index_trending(self, db: SessionManager, redis: Redis, timestamp):
    """Recompute trending tracks and underground trending and cache them in Redis.

    Refreshes the AGGREGATE_INTERVAL_PLAYS and TRENDING_PARAMS views (and the
    score query of every strategy that uses a materialized view), caches the
    unpopulated trending result for every (version, genre incl. None,
    time range) combination, caches underground trending per version, then
    stores the completion time under ``trending_tracks_last_completion_redis_key``
    and records ``timestamp`` via ``set_last_trending_datetime``.

    ``self`` is not used in this body (presumably the bound Celery task).
    """
    logger.info("index_trending.py | starting indexing")
    update_start = time.time()
    metric = PrometheusMetric(
        "index_trending_runtime_seconds",
        "Runtimes for src.task.index_trending:index_trending()",
    )
    with db.scoped_session() as session:
        genres = get_genres(session)

        # Make sure to cache empty genre
        genres.append(None)  # type: ignore

        trending_track_versions = trending_strategy_factory.get_versions_for_type(
            TrendingType.TRACKS).keys()

        update_view(session, AGGREGATE_INTERVAL_PLAYS)
        update_view(session, TRENDING_PARAMS)
        for version in trending_track_versions:
            strategy = trending_strategy_factory.get_strategy(
                TrendingType.TRACKS, version)
            if strategy.use_mat_view:
                strategy.update_track_score_query(session)

        for version in trending_track_versions:
            strategy = trending_strategy_factory.get_strategy(
                TrendingType.TRACKS, version)
            for genre in genres:
                for time_range in time_ranges:
                    cache_start_time = time.time()
                    if strategy.use_mat_view:
                        res = generate_unpopulated_trending_from_mat_views(
                            session, genre, time_range, strategy)
                    else:
                        res = generate_unpopulated_trending(
                            session, genre, time_range, strategy)
                    key = make_trending_cache_key(time_range, genre, version)
                    set_json_cached_key(redis, key, res)
                    cache_end_time = time.time()
                    total_time = cache_end_time - cache_start_time
                    # Implicit concatenation: a backslash continuation inside
                    # the literal used to embed the next line's indentation.
                    logger.info(
                        f"index_trending.py | Cached trending ({version.name} version) "
                        f"for {genre}-{time_range} in {total_time} seconds")

        # Cache underground trending
        underground_trending_versions = trending_strategy_factory.get_versions_for_type(
            TrendingType.UNDERGROUND_TRACKS).keys()
        for version in underground_trending_versions:
            strategy = trending_strategy_factory.get_strategy(
                TrendingType.UNDERGROUND_TRACKS, version)
            cache_start_time = time.time()
            res = make_get_unpopulated_tracks(session, redis, strategy)()
            key = make_underground_trending_cache_key(version)
            set_json_cached_key(redis, key, res)
            cache_end_time = time.time()
            total_time = cache_end_time - cache_start_time
            logger.info(
                f"index_trending.py | Cached underground trending ({version.name} version) "
                f"in {total_time} seconds")

    update_end = time.time()
    update_total = update_end - update_start
    metric.save_time()
    logger.info(
        f"index_trending.py | Finished indexing trending in {update_total} seconds",
        extra={
            "job": "index_trending",
            "total_time": update_total
        },
    )
    # Update cache key to track the last time trending finished indexing
    redis.set(trending_tracks_last_completion_redis_key, int(update_end))
    set_last_trending_datetime(redis, timestamp)
Example #10
0
def get_transaction_signatures(
    solana_client_manager: SolanaClientManager,
    db: SessionManager,
    program: str,
    get_latest_slot: Callable[[Session], int],
    check_tx_exists: Callable[[Session, str], bool],
    min_slot=None,
) -> List[List[str]]:
    """Fetches next batch of transaction signature offset from the previous latest processed slot

    Fetches the latest processed slot for the program (via ``get_latest_slot``).
    Pages backwards from the current tx until an intersection is found with the
    latest processed slot: a tx at or below that slot which ``check_tx_exists``
    reports as already stored, history running out, or a tx at or below both
    the latest processed slot and ``min_slot`` (nothing older can qualify).
    Returns the signatures to process as batches ordered oldest first.
    """
    # List of signatures that will be populated as we traverse recent operations
    transaction_signatures = []

    last_tx_signature = None

    # Loop exit condition
    intersection_found = False

    # Query for solana transactions until an intersection is found
    with db.scoped_session() as session:
        latest_processed_slot = get_latest_slot(session)
        while not intersection_found:
            transactions_history = solana_client_manager.get_signatures_for_address(
                program,
                before=last_tx_signature,
                limit=FETCH_TX_SIGNATURES_BATCH_SIZE)

            transactions_array = transactions_history["result"]
            if not transactions_array:
                intersection_found = True
                logger.info(
                    f"index_rewards_manager.py | No transactions found before {last_tx_signature}"
                )
            else:
                # Current batch of transactions
                transaction_signature_batch = []
                for tx_info in transactions_array:
                    tx_sig = tx_info["signature"]
                    tx_slot = tx_info["slot"]
                    logger.info(
                        f"index_rewards_manager.py | Processing tx={tx_sig} | slot={tx_slot}"
                    )
                    if tx_slot > latest_processed_slot:
                        transaction_signature_batch.append(tx_sig)
                    elif min_slot is None or tx_slot > min_slot:
                        # Check the tx signature for any txs in the latest batch,
                        # and if not present in DB, add to processing
                        logger.info(
                            "index_rewards_manager.py | Latest slot re-traversal "
                            f"slot={tx_slot}, sig={tx_sig}, "
                            f"latest_processed_slot(db)={latest_processed_slot}")
                        if check_tx_exists(session, tx_sig):
                            intersection_found = True
                            break
                        # Ensure this transaction is still processed
                        transaction_signature_batch.append(tx_sig)
                    else:
                        # At or below both latest_processed_slot and min_slot.
                        # Results are newest to oldest, so every remaining tx is
                        # at least as old and would be skipped too; stop instead
                        # of paging back through the program's whole history.
                        intersection_found = True
                        break

                # Restart processing at the end of this transaction signature batch
                last_tx = transactions_array[-1]
                last_tx_signature = last_tx["signature"]

                # Append batch of processed signatures
                if transaction_signature_batch:
                    transaction_signatures.append(transaction_signature_batch)

                # Ensure processing does not grow unbounded
                if len(transaction_signatures) > TX_SIGNATURES_MAX_BATCHES:
                    # Keep only the oldest TX_SIGNATURES_RESIZE_LENGTH batches;
                    # transaction_signatures is sorted from newest to oldest
                    transaction_signatures = transaction_signatures[
                        -TX_SIGNATURES_RESIZE_LENGTH:]

    # Reverse batches aggregated so oldest transactions are processed first
    transaction_signatures.reverse()
    return transaction_signatures
Example #11
0
def parse_sol_tx_batch(
    db: SessionManager,
    solana_client_manager: SolanaClientManager,
    redis: Redis,
    tx_sig_batch_records: List[ConfirmedSignatureForAddressResult],
    solana_logger: SolanaIndexingLogger,
):
    """
    Parse a batch of solana transactions in parallel by calling parse_spl_token_transaction
    with a ThreadPoolExecutor, then:

    - enqueue an immediate balance refresh for current users whose user bank
      token account, or associated sol wallet, appeared in the parsed txs;
    - cache the first record of the batch in redis and persist it as the
      single SPLTokenTransaction checkpoint row.

    tx_sig_batch_records is expected in time DESC order (index 0 is the newest).

    If not every future completes within 45 seconds, ``as_completed`` raises;
    the error is logged and re-raised. No retry is performed here.
    NOTE: the executor's ``with`` exit waits for in-flight tasks before the
    exception propagates, so the timeout does not bound total wall time.

    Returns:
        (update_user_ids, updated_root_accounts, updated_token_accounts);
        the user bank owner address is removed from updated_root_accounts.
    """
    batch_start_time = time.time()
    updated_root_accounts: Set[str] = set()
    updated_token_accounts: Set[str] = set()
    # Process each batch in parallel
    with concurrent.futures.ThreadPoolExecutor() as executor:
        parse_sol_tx_futures = {
            executor.submit(
                parse_spl_token_transaction,
                solana_client_manager,
                tx_sig_record,
            ): tx_sig_record
            for tx_sig_record in tx_sig_batch_records
        }
        try:
            for future in concurrent.futures.as_completed(parse_sol_tx_futures,
                                                          timeout=45):
                _, root_accounts, token_accounts = future.result()
                if root_accounts or token_accounts:
                    updated_root_accounts.update(root_accounts)
                    updated_token_accounts.update(token_accounts)

        except Exception as exc:
            logger.error(
                f"index_spl_token.py | Error parsing sol spl token transaction: {exc}"
            )
            # Bare raise re-raises with the original traceback intact
            raise

    update_user_ids: Set[int] = set()
    with db.scoped_session() as session:
        if updated_token_accounts:
            # Users whose user bank token account was touched
            user_bank_subquery = session.query(
                UserBankAccount.ethereum_address).filter(
                    UserBankAccount.bank_account.in_(
                        list(updated_token_accounts)))

            user_result = (session.query(User.user_id).filter(
                User.is_current == True,
                User.wallet.in_(user_bank_subquery)).all())
            user_set = {user_id for [user_id] in user_result}
            update_user_ids.update(user_set)

        if updated_root_accounts:
            # Remove the user bank owner
            user_bank_owner, _ = get_base_address(SPL_TOKEN_PUBKEY,
                                                  USER_BANK_PUBKEY)
            updated_root_accounts.discard(str(user_bank_owner))

            # Users with a current, non-deleted sol associated wallet that was touched
            associated_wallet_result = (session.query(
                AssociatedWallet.user_id).filter(
                    AssociatedWallet.is_current == True,
                    AssociatedWallet.is_delete == False,
                    AssociatedWallet.chain == WalletChain.sol,
                    AssociatedWallet.wallet.in_(list(updated_root_accounts)),
                )).all()
            associated_wallet_set = {
                user_id
                for [user_id] in associated_wallet_result
            }
            update_user_ids.update(associated_wallet_set)

        user_ids = list(update_user_ids)
        if user_ids:
            logger.info(
                f"index_spl_token.py | Enqueueing user ids {user_ids} to immediate balance refresh queue"
            )
            enqueue_immediate_balance_refresh(redis, user_ids)

        if tx_sig_batch_records:
            # Records are in time DESC order, so index 0 is the most recent
            # tx of this batch; it becomes the cached / persisted checkpoint.
            last_tx = tx_sig_batch_records[0]

            last_scanned_slot = last_tx["slot"]
            last_scanned_signature = last_tx["signature"]
            solana_logger.add_log(
                f"Updating last_scanned_slot to {last_scanned_slot} and signature to {last_scanned_signature}"
            )
            cache_latest_spl_audio_db_tx(
                redis,
                {
                    "signature": last_scanned_signature,
                    "slot": last_scanned_slot,
                    "timestamp": last_tx["blockTime"],
                },
            )

            # Single checkpoint row: update it in place or create it
            record = session.query(SPLTokenTransaction).first()
            if record:
                record.last_scanned_slot = last_scanned_slot
                record.signature = last_scanned_signature
            else:
                record = SPLTokenTransaction(
                    last_scanned_slot=last_scanned_slot,
                    signature=last_scanned_signature,
                )
            session.add(record)

    batch_end_time = time.time()
    batch_duration = batch_end_time - batch_start_time
    solana_logger.add_log(
        f"processed batch {len(tx_sig_batch_records)} txs in {batch_duration}s"
    )

    return (update_user_ids, updated_root_accounts, updated_token_accounts)
Example #12
0
def configure_celery(flask_app, celery, test_config=None):
    """Configure ``celery`` for the discovery provider's background tasks.

    Registers the task modules and the beat schedule, points the broker at
    Redis, and installs a ``DatabaseTask`` base class that exposes the shared
    DB session manager, web3 provider, ABI values, config, IPFS client and
    Redis connection to every task through properties.

    Args:
        flask_app: Accepted for interface compatibility; not used here.
        celery: The Celery application to configure and finalize.
        test_config: Optional overrides; only ``test_config["db"]["url"]`` is
            read, replacing the configured database URL.
    """
    db_settings = shared_config["db"]
    database_url = db_settings["url"]
    engine_args_literal = ast.literal_eval(db_settings["engine_args_literal"])
    redis_url = shared_config["redis"]["url"]

    # Tests may point the workers at a separate database.
    if test_config is not None and "db" in test_config:
        test_db_settings = test_config["db"]
        if "url" in test_db_settings:
            database_url = test_db_settings["url"]

    task_modules = [
        "src.tasks.index",
        "src.tasks.index_blacklist",
        "src.tasks.index_cache",
        "src.tasks.index_plays",
        "src.tasks.index_metrics",
    ]
    beat_schedule = {
        "update_discovery_provider": {
            "task": "update_discovery_provider",
            "schedule": timedelta(seconds=5),
        },
        "update_ipld_blacklist": {
            "task": "update_ipld_blacklist",
            "schedule": timedelta(seconds=60),
        },
        "update_cache": {
            "task": "update_discovery_cache",
            "schedule": timedelta(seconds=60),
        },
        "update_play_count": {
            "task": "update_play_count",
            "schedule": timedelta(seconds=10),
        },
        # Runs at minute 0 of every hour.
        "update_metrics": {
            "task": "update_metrics",
            "schedule": crontab(minute=0, hour="*"),
        },
    }
    celery.conf.update(
        imports=task_modules,
        beat_schedule=beat_schedule,
        task_serializer="json",
        accept_content=["json"],
        broker_url=redis_url,
    )

    # Session manager shared by every task in this worker.
    db = SessionManager(database_url, engine_args_literal)
    logger.info('Database instance initialized!')

    # IPFS gateways: the configured hosts followed by the user metadata service.
    gateway_addrs = [
        *shared_config["ipfs"]["gateway_hosts"].split(','),
        shared_config["discprov"]["user_metadata_service_url"],
    ]
    logger.warning(f"__init__.py | {gateway_addrs}")
    ipfs_client = IPFSClient(
        shared_config["ipfs"]["host"], shared_config["ipfs"]["port"], gateway_addrs
    )

    redis_inst = redis.Redis.from_url(url=redis_url)
    # Clear the indexing lock, if present, at startup.
    redis_inst.delete("disc_prov_lock")
    logger.info('Redis instance initialized!')

    class DatabaseTask(Task):
        """Celery task base class exposing the shared resources built above."""

        def __init__(self, *args, **kwargs):
            self._db = db
            self._web3_provider = web3
            self._abi_values = abi_values
            self._shared_config = shared_config
            self._ipfs_client = ipfs_client
            self._redis = redis_inst

        @property
        def db(self):
            return self._db

        @property
        def redis(self):
            return self._redis

        @property
        def web3(self):
            return self._web3_provider

        @property
        def abi_values(self):
            return self._abi_values

        @property
        def shared_config(self):
            return self._shared_config

        @property
        def ipfs_client(self):
            return self._ipfs_client

    celery.autodiscover_tasks(["src.tasks"], "index", True)

    # Every task now inherits the discovery provider context from DatabaseTask.
    celery.Task = DatabaseTask

    celery.finalize()
def refresh_user_ids(
    redis: Redis,
    db: SessionManager,
    token_contract,
    delegate_manager_contract,
    staking_contract,
    eth_web3,
    waudio_token,
):
    """Recompute and persist token balances for users queued for a refresh.

    Collects user ids from the lazy refresh queue (capped at
    ``MAX_LAZY_REFRESH_USER_IDS``) and the immediate refresh queue, then for
    each user:

    * reads the owner wallet's balance from ``token_contract``;
    * sums, over the associated ETH wallets, token balance + total delegator
      stake (``delegate_manager_contract``) + staked amount
      (``staking_contract``);
    * if ``waudio_token`` is set, sums the balances of the associated token
      accounts derived (with ``WAUDIO_MINT_PUBKEY``) from each associated SOL
      wallet;
    * if the user has a user bank account, reads its balance via
      ``waudio_token``.

    Results are written to ``UserBalance`` (rows with "0" balances are created
    for users that have none) and to one ``UserBalanceChange`` row per user.
    The session is then committed and all fetched ids are removed from both
    Redis refresh sets. A failure for a single user is logged and that user
    is skipped.

    Args:
        redis: Redis client holding the lazy/immediate refresh sets.
        db: Session manager used for all DB reads and writes.
        token_contract: ETH token contract exposing ``balanceOf``.
        delegate_manager_contract: Contract exposing ``getTotalDelegatorStake``.
        staking_contract: Contract exposing ``totalStakedFor``.
        eth_web3: Web3 instance used for checksum addresses and block number.
        waudio_token: Solana token client used for ``get_balance``; may be
            ``None``, in which case SOL balances are not fetched (and an error
            is logged for users that have a bank account).
    """
    with db.scoped_session() as session:
        lazy_refresh_user_ids = get_lazy_refresh_user_ids(
            redis, session)[:MAX_LAZY_REFRESH_USER_IDS]
        immediate_refresh_user_ids = get_immediate_refresh_user_ids(redis)

        logger.info(
            f"cache_user_balance.py | Starting refresh with {len(lazy_refresh_user_ids)} "
            f"lazy refresh user_ids: {lazy_refresh_user_ids} and {len(immediate_refresh_user_ids)} "
            f"immediate refresh user_ids: {immediate_refresh_user_ids}")
        # De-duplicate ids that appear in both queues
        all_user_ids = lazy_refresh_user_ids + immediate_refresh_user_ids
        user_ids = list(set(all_user_ids))

        existing_user_balances: List[UserBalance] = ((
            session.query(UserBalance)).filter(
                UserBalance.user_id.in_(user_ids)).all())
        all_user_balance = existing_user_balances

        # Balances from the current user lookup may not be present in the
        # db yet, so create zeroed rows for those users
        not_present_set = set(user_ids) - {
            user.user_id
            for user in existing_user_balances
        }
        new_balances: List[UserBalance] = [
            UserBalance(
                user_id=user_id,
                balance="0",
                associated_wallets_balance="0",
                associated_sol_wallets_balance="0",
            ) for user_id in not_present_set
        ]
        if new_balances:
            session.add_all(new_balances)
            all_user_balance = existing_user_balances + new_balances
            logger.info(
                f"cache_user_balance.py | adding new users: {not_present_set}")

        # user_id -> UserBalance row (existing or newly created)
        user_balances = {user.user_id: user for user in all_user_balance}

        # Grab the users & associated_wallets we need to refresh.
        # One row per (user, current associated wallet); because of the outer
        # join, a user with no associated wallets still yields a single row
        # whose wallet/chain columns are NULL.
        user_associated_wallet_query: List[Tuple[int, str, str, str]] = (
            session.query(
                User.user_id,
                User.wallet,
                AssociatedWallet.wallet,
                AssociatedWallet.chain,
            ).outerjoin(
                AssociatedWallet,
                and_(
                    AssociatedWallet.user_id == User.user_id,
                    AssociatedWallet.is_current == True,
                    AssociatedWallet.is_delete == False,
                ),
            ).filter(
                User.user_id.in_(user_ids),
                User.is_current == True,
            ).all())

        # user_id -> user bank account (only for users whose wallet has one)
        user_bank_accounts_query: List[Tuple[int, str]] = (session.query(
            User.user_id, UserBankAccount.bank_account).join(
                UserBankAccount,
                UserBankAccount.ethereum_address == User.wallet).filter(
                    User.user_id.in_(user_ids),
                    User.is_current == True,
                ).all())
        user_id_bank_accounts = dict(user_bank_accounts_query)

        # Combine query results for user bank, associated wallets,
        # and primary owner wallet into a single metadata list
        user_id_metadata: Dict[int, UserWalletMetadata] = {}

        for user in user_associated_wallet_query:
            user_id, user_wallet, associated_wallet, wallet_chain = user
            if user_id not in user_id_metadata:
                user_id_metadata[user_id] = {
                    "owner_wallet": user_wallet,
                    "associated_wallets": {
                        "eth": [],
                        "sol": []
                    },
                    "bank_account": None,
                }
                if user_id in user_id_bank_accounts:
                    user_id_metadata[user_id][
                        "bank_account"] = user_id_bank_accounts[user_id]
            # NOTE(review): wallet_chain is used directly as a key into the
            # {"eth", "sol"} dict; presumably AssociatedWallet.chain only takes
            # those values -- a different value would raise KeyError here.
            if associated_wallet:
                if user_id in user_id_metadata:
                    user_id_metadata[user_id]["associated_wallets"][
                        wallet_chain  # type: ignore
                    ].append(associated_wallet)

        # NOTE(review): len(user_associated_wallet_query) counts joined rows
        # (one per associated wallet), not distinct users.
        logger.info(
            f"cache_user_balance.py | fetching for {len(user_associated_wallet_query)} users: {user_ids}"
        )

        # mapping of user_id => balance change
        needs_balance_change_update: Dict[int, Dict] = {}

        # Fetch balances
        # pylint: disable=too-many-nested-blocks
        for user_id, wallets in user_id_metadata.items():
            try:
                owner_wallet = wallets["owner_wallet"]
                owner_wallet = eth_web3.toChecksumAddress(owner_wallet)
                owner_wallet_balance = token_contract.functions.balanceOf(
                    owner_wallet).call()
                associated_balance = 0
                waudio_balance: str = "0"
                associated_sol_balance = 0

                # Always true for the metadata built above; kept as a guard
                if "associated_wallets" in wallets:
                    # ETH associated wallets: token balance + delegated stake
                    # + staked amount
                    for wallet in wallets["associated_wallets"]["eth"]:
                        wallet = eth_web3.toChecksumAddress(wallet)
                        balance = token_contract.functions.balanceOf(
                            wallet).call()
                        delegation_balance = (
                            delegate_manager_contract.functions.
                            getTotalDelegatorStake(wallet).call())
                        stake_balance = staking_contract.functions.totalStakedFor(
                            wallet).call()
                        associated_balance += (balance + delegation_balance +
                                               stake_balance)
                    # SOL associated wallets: balance of the associated token
                    # account derived from each root wallet; failures for a
                    # single wallet are logged and that wallet is skipped
                    if waudio_token is not None:
                        for wallet in wallets["associated_wallets"]["sol"]:
                            try:
                                root_sol_account = PublicKey(wallet)
                                derived_account, _ = PublicKey.find_program_address(
                                    [
                                        bytes(root_sol_account),
                                        bytes(SPL_TOKEN_ID_PK),
                                        bytes(WAUDIO_MINT_PUBKEY
                                              ),  # type: ignore
                                    ],
                                    ASSOCIATED_TOKEN_PROGRAM_ID_PK,
                                )

                                bal_info = waudio_token.get_balance(
                                    derived_account)
                                associated_waudio_balance: str = bal_info[
                                    "result"]["value"]["amount"]
                                associated_sol_balance += int(
                                    associated_waudio_balance)
                            except Exception as e:
                                logger.error(
                                    " ".join([
                                        "cache_user_balance.py | Error fetching associated ",
                                        "wallet balance for user %s, wallet %s: %s",
                                    ]),
                                    user_id,
                                    wallet,
                                    e,
                                )

                # User bank account balance (needs the SPL token client)
                if wallets["bank_account"] is not None:
                    if waudio_token is None:
                        logger.error(
                            "cache_user_balance.py | Missing Required SPL Confirguration"
                        )
                    else:
                        bal_info = waudio_token.get_balance(
                            PublicKey(wallets["bank_account"]))
                        waudio_balance = bal_info["result"]["value"]["amount"]

                # update the balance on the user model
                user_balance = user_balances[user_id]

                # Convert Sol balances to wei
                waudio_in_wei = to_wei(waudio_balance)
                assoc_sol_balance_in_wei = to_wei(associated_sol_balance)
                user_waudio_in_wei = to_wei(user_balance.waudio)
                user_assoc_sol_balance_in_wei = to_wei(
                    user_balance.associated_sol_wallets_balance)

                # Get values for user balance change: the freshly fetched
                # total vs. the total stored on the UserBalance row
                current_total_balance = (owner_wallet_balance +
                                         associated_balance + waudio_in_wei +
                                         assoc_sol_balance_in_wei)
                prev_total_balance = (
                    int(user_balance.balance) +
                    int(user_balance.associated_wallets_balance) +
                    user_waudio_in_wei + user_assoc_sol_balance_in_wei)

                # Write to user_balance_changes table
                # NOTE(review): eth_web3.eth.block_number is fetched once per
                # user, inside this loop.
                needs_balance_change_update[user_id] = {
                    "user_id": user_id,
                    "blocknumber": eth_web3.eth.block_number,
                    "current_balance": str(current_total_balance),
                    "previous_balance": str(prev_total_balance),
                }

                user_balance.balance = str(owner_wallet_balance)
                user_balance.associated_wallets_balance = str(
                    associated_balance)
                user_balance.waudio = waudio_balance
                user_balance.associated_sol_wallets_balance = str(
                    associated_sol_balance)

            # A failure for one user is logged and does not stop the others
            except Exception as e:
                logger.error(
                    f"cache_user_balance.py | Error fetching balance for user {user_id}: {(e)}"
                )

        # Outside the loop, batch update the UserBalanceChanges:

        # Get existing user balances
        user_balance_ids = list(needs_balance_change_update.keys())
        existing_user_balance_changes: List[UserBalanceChange] = (
            session.query(UserBalanceChange).filter(
                UserBalanceChange.user_id.in_(user_balance_ids)).all())
        # Find all the IDs that don't already exist in the DB
        to_create_ids = set(user_balance_ids) - {
            e.user_id
            for e in existing_user_balance_changes
        }
        logger.info(
            f"cache_user_balance.py | UserBalanceChanges needing update: {user_balance_ids},\
                    existing: {[e.user_id for e in existing_user_balance_changes]}, to create: {to_create_ids}"
        )

        # Create new entries for those IDs
        balance_changes_to_add = [
            UserBalanceChange(
                user_id=user_id,
                blocknumber=needs_balance_change_update[user_id]
                ["blocknumber"],
                current_balance=needs_balance_change_update[user_id]
                ["current_balance"],
                previous_balance=needs_balance_change_update[user_id]
                ["previous_balance"],
            ) for user_id in to_create_ids
        ]
        session.add_all(balance_changes_to_add)
        # Lastly, update all the existing entries
        for change in existing_user_balance_changes:
            new_values = needs_balance_change_update[change.user_id]
            change.blocknumber = new_values["blocknumber"]
            change.current_balance = new_values["current_balance"]
            change.previous_balance = new_values["previous_balance"]

        # Commit the new balances
        session.commit()

        # Remove the fetched balances from Redis set.
        # NOTE: every queued id is removed, including users whose fetch failed
        # above (those errors are only logged).
        logger.info(
            f"cache_user_balance.py | Got balances for {len(user_associated_wallet_query)} users, removing from Redis."
        )
        if lazy_refresh_user_ids:
            redis.srem(LAZY_REFRESH_REDIS_PREFIX, *lazy_refresh_user_ids)
        if immediate_refresh_user_ids:
            redis.srem(IMMEDIATE_REFRESH_REDIS_PREFIX,
                       *immediate_refresh_user_ids)
Example #14
0
def configure_celery(celery, test_config=None):
    """Configure ``celery`` with task modules, beat schedule and task context.

    Reads DB, Redis, Solana and indexing-interval settings from
    ``shared_config``; when given, ``test_config["db"]["url"]`` overrides the
    database URL. Besides configuring ``celery`` this function also:

    * clears the last scanned ETH block stored in Redis,
    * builds the Anchor program indexer for ``index_solana_user_data``,
    * deletes a fixed list of task lock keys from Redis,
    * installs ``WrappedDatabaseTask`` as ``celery.Task`` and finalizes the app.
    """
    database_url = shared_config["db"]["url"]
    redis_url = shared_config["redis"]["url"]

    # Tests may point the workers at a separate database.
    if (test_config is not None and "db" in test_config
            and "url" in test_config["db"]):
        database_url = test_config["db"]["url"]

    discprov_settings = shared_config["discprov"]
    ipld_interval = int(discprov_settings["blacklist_block_indexing_interval"])
    # default is 5 seconds
    indexing_interval_sec = int(discprov_settings["block_processing_interval_sec"])

    task_modules = [
        "src.tasks.index",
        "src.tasks.index_blacklist",
        "src.tasks.index_metrics",
        "src.tasks.index_materialized_views",
        "src.tasks.aggregates.index_aggregate_plays",
        "src.tasks.index_aggregate_monthly_plays",
        "src.tasks.index_hourly_play_counts",
        "src.tasks.vacuum_db",
        "src.tasks.index_network_peers",
        "src.tasks.index_trending",
        "src.tasks.cache_user_balance",
        "src.monitors.monitoring_queue",
        "src.tasks.cache_trending_playlists",
        "src.tasks.index_solana_plays",
        "src.tasks.index_aggregate_views",
        "src.tasks.index_aggregate_user",
        "src.tasks.aggregates.index_aggregate_track",
        "src.tasks.index_challenges",
        "src.tasks.index_user_bank",
        "src.tasks.index_eth",
        "src.tasks.index_oracles",
        "src.tasks.index_rewards_manager",
        "src.tasks.index_related_artists",
        "src.tasks.calculate_trending_challenges",
        "src.tasks.index_listen_count_milestones",
        "src.tasks.user_listening_history.index_user_listening_history",
        "src.tasks.prune_plays",
        "src.tasks.index_spl_token",
        "src.tasks.index_solana_user_data",
        "src.tasks.index_aggregate_tips",
        "src.tasks.index_reactions",
    ]

    beat_schedule = {
        "update_discovery_provider": {
            "task": "update_discovery_provider",
            "schedule": timedelta(seconds=indexing_interval_sec),
        },
        "update_ipld_blacklist": {
            "task": "update_ipld_blacklist",
            "schedule": timedelta(seconds=ipld_interval),
        },
        "update_metrics": {
            "task": "update_metrics",
            "schedule": crontab(minute=0, hour="*"),
        },
        "aggregate_metrics": {
            "task": "aggregate_metrics",
            "schedule": timedelta(minutes=METRICS_INTERVAL),
        },
        "synchronize_metrics": {
            "task": "synchronize_metrics",
            "schedule": timedelta(minutes=SYNCHRONIZE_METRICS_INTERVAL),
        },
        "update_materialized_views": {
            "task": "update_materialized_views",
            "schedule": timedelta(seconds=300),
        },
        "update_aggregate_plays": {
            "task": "update_aggregate_plays",
            "schedule": timedelta(seconds=15),
        },
        "index_hourly_play_counts": {
            "task": "index_hourly_play_counts",
            "schedule": timedelta(seconds=30),
        },
        "vacuum_db": {
            "task": "vacuum_db",
            "schedule": timedelta(days=1),
        },
        "update_network_peers": {
            "task": "update_network_peers",
            "schedule": timedelta(seconds=30),
        },
        "index_trending": {
            "task": "index_trending",
            "schedule": timedelta(seconds=10),
        },
        "update_user_balances": {
            "task": "update_user_balances",
            "schedule": timedelta(seconds=60),
        },
        "monitoring_queue": {
            "task": "monitoring_queue",
            "schedule": timedelta(seconds=60),
        },
        "cache_trending_playlists": {
            "task": "cache_trending_playlists",
            "schedule": timedelta(minutes=30),
        },
        "index_solana_plays": {
            "task": "index_solana_plays",
            "schedule": timedelta(seconds=5),
        },
        "update_aggregate_user": {
            "task": "update_aggregate_user",
            "schedule": timedelta(seconds=30),
        },
        "update_aggregate_track": {
            "task": "update_aggregate_track",
            "schedule": timedelta(seconds=30),
        },
        "update_aggregate_playlist": {
            "task": "update_aggregate_playlist",
            "schedule": timedelta(seconds=30),
        },
        "index_user_bank": {
            "task": "index_user_bank",
            "schedule": timedelta(seconds=5),
        },
        "index_challenges": {
            "task": "index_challenges",
            "schedule": timedelta(seconds=5),
        },
        "index_eth": {
            "task": "index_eth",
            "schedule": timedelta(seconds=10),
        },
        "index_oracles": {
            "task": "index_oracles",
            "schedule": timedelta(minutes=5),
        },
        "index_rewards_manager": {
            "task": "index_rewards_manager",
            "schedule": timedelta(seconds=5),
        },
        "index_related_artists": {
            "task": "index_related_artists",
            "schedule": timedelta(seconds=60),
        },
        "index_listen_count_milestones": {
            "task": "index_listen_count_milestones",
            "schedule": timedelta(seconds=5),
        },
        "index_user_listening_history": {
            "task": "index_user_listening_history",
            "schedule": timedelta(seconds=5),
        },
        "index_aggregate_monthly_plays": {
            "task": "index_aggregate_monthly_plays",
            "schedule": crontab(minute=0, hour=0),  # daily at midnight
        },
        "prune_plays": {
            "task": "prune_plays",
            "schedule": crontab(
                minute="*/15",
                hour="14, 15",
            ),  # 8x a day during non peak hours
        },
        "index_spl_token": {
            "task": "index_spl_token",
            "schedule": timedelta(seconds=5),
        },
        "index_aggregate_tips": {
            "task": "index_aggregate_tips",
            "schedule": timedelta(seconds=5),
        },
        "index_reactions": {
            "task": "index_reactions",
            "schedule": timedelta(seconds=5),
        },
        # UNCOMMENT BELOW FOR MIGRATION DEV WORK
        # "index_solana_user_data": {
        #     "task": "index_solana_user_data",
        #     "schedule": timedelta(seconds=5),
        # },
    }

    celery.conf.update(
        imports=task_modules,
        beat_schedule=beat_schedule,
        task_serializer="json",
        accept_content=["json"],
        broker_url=redis_url,
    )

    # Session manager shared by every task in this worker
    engine_args = ast.literal_eval(shared_config["db"]["engine_args_literal"])
    db = SessionManager(database_url, engine_args)
    logger.info("Database instance initialized!")

    redis_inst = redis.Redis.from_url(url=redis_url)

    cid_metadata_client = CIDMetadataClient(
        eth_web3,
        shared_config,
        redis_inst,
        eth_abi_values,
    )

    # Clear last scanned redis block on startup
    delete_last_scanned_eth_block_redis(redis_inst)

    anchor_program_indexer = AnchorProgramIndexer(
        shared_config["solana"]["anchor_data_program_id"],
        shared_config["solana"]["anchor_admin_storage_public_key"],
        "index_solana_user_data",
        redis_inst,
        db,
        solana_client_manager,
        cid_metadata_client,
    )

    # Task locks cleared at startup, deleted one at a time in this order
    stale_task_locks = (
        "disc_prov_lock",
        "network_peers_lock",
        "materialized_view_lock",
        "update_metrics_lock",
        "update_play_count_lock",
        "index_hourly_play_counts_lock",
        "ipld_blacklist_lock",
        "update_discovery_lock",
        "aggregate_metrics_lock",
        "synchronize_metrics_lock",
        "solana_plays_lock",
        "index_challenges_lock",
        "user_bank_lock",
        "index_eth_lock",
        "index_oracles_lock",
        "solana_rewards_manager_lock",
        "calculate_trending_challenges_lock",
        "index_user_listening_history_lock",
        "prune_plays_lock",
        "update_aggregate_table:aggregate_user_tips",
        INDEX_REACTIONS_LOCK,
    )
    for lock_key in stale_task_locks:
        redis_inst.delete(lock_key)

    logger.info("Redis instance initialized!")

    class WrappedDatabaseTask(DatabaseTask):
        """DatabaseTask bound to the resources created in this function."""

        def __init__(self, *args, **kwargs):
            # A fresh challenge event bus is set up for each task instance.
            super().__init__(
                db=db,
                web3=web3,
                abi_values=abi_values,
                eth_abi_values=eth_abi_values,
                shared_config=shared_config,
                cid_metadata_client=cid_metadata_client,
                redis=redis_inst,
                eth_web3_provider=eth_web3,
                solana_client_manager=solana_client_manager,
                challenge_event_bus=setup_challenge_bus(),
                anchor_program_indexer=anchor_program_indexer,
            )

    celery.autodiscover_tasks(["src.tasks"], "index", True)

    # Every task now inherits the discovery provider context from DatabaseTask.
    celery.Task = WrappedDatabaseTask

    celery.finalize()
def index_listen_count_milestones(db: SessionManager, redis: Redis):
    """Save newly reached listen-count milestones for tracks queued in Redis.

    Reads the cached play-indexing position (``CURRENT_PLAY_INDEXING``) and the
    track ids returned by ``get_track_listen_ids``. For every such track that
    has an ``AggregatePlays`` row, ``get_next_track_milestone`` is given the
    track's play count and its highest recorded ``LISTEN_COUNT_MILESTONE``
    threshold (or ``None``). When it returns a threshold, a ``Milestone``
    stamped with the indexed slot and timestamp is bulk-saved.

    Afterwards:

    * the indexed slot is stored under ``PROCESSED_LISTEN_MILESTONE``;
    * the checked ids are removed from ``TRACK_LISTEN_IDS``;
    * if a latest plays slot was cached, it is copied to
      ``latest_sol_listen_count_milestones_slot_key``.

    Returns early, without any of these writes, when no play-indexing slot is
    cached.
    """
    logger.info(
        "index_listen_count_milestones.py | Start calculating listen count milestones"
    )
    # Read up front; copied to the milestones slot key once the job finishes
    latest_plays_slot = redis.get(latest_sol_plays_slot_key)
    job_start = time.time()
    with db.scoped_session() as session:
        # Pull off current play indexed slot number from redis
        current_play_indexing = get_json_cached_key(redis,
                                                    CURRENT_PLAY_INDEXING)
        if not current_play_indexing or current_play_indexing["slot"] is None:
            return

        # Pull off track ids to check from redis
        check_track_ids = get_track_listen_ids(redis)

        milestones = {}
        play_counts = {}
        # Nothing queued: skip both queries (and their empty IN clauses)
        if check_track_ids:
            # Highest listen-count milestone threshold already saved per track
            existing_milestone = (session.query(
                Milestone.id, func.max(Milestone.threshold)).filter(
                    Milestone.name == LISTEN_COUNT_MILESTONE,
                    Milestone.id.in_(check_track_ids),
                ).group_by(Milestone.id).all())

            aggregate_play_counts = (session.query(
                AggregatePlays.play_item_id,
                AggregatePlays.count,
            ).filter(AggregatePlays.play_item_id.in_(check_track_ids)).all())

            milestones = dict(existing_milestone)
            play_counts = dict(aggregate_play_counts)

        # Bulk fetch track's next milestone threshold
        listen_milestones = []
        for track_id in check_track_ids:
            # Tracks without an AggregatePlays row are skipped
            if track_id not in play_counts:
                continue
            next_milestone_threshold = get_next_track_milestone(
                play_counts[track_id], milestones.get(track_id))
            if next_milestone_threshold:
                listen_milestones.append(
                    Milestone(
                        id=track_id,
                        threshold=next_milestone_threshold,
                        name=LISTEN_COUNT_MILESTONE,
                        slot=current_play_indexing["slot"],
                        timestamp=datetime.utcfromtimestamp(
                            int(current_play_indexing["timestamp"])),
                    ))

        if listen_milestones:
            session.bulk_save_objects(listen_milestones)

        redis.set(PROCESSED_LISTEN_MILESTONE, current_play_indexing["slot"])
        if check_track_ids:
            redis.srem(TRACK_LISTEN_IDS, *check_track_ids)

    job_end = time.time()
    job_total = job_end - job_start
    logger.info(
        f"index_listen_count_milestones.py | Finished calculating listen count milestones in {job_total} seconds",
        extra={
            "job": "index_listen_count_milestones",
            "total_time": job_total
        },
    )
    if latest_plays_slot is not None:
        redis.set(latest_sol_listen_count_milestones_slot_key,
                  int(latest_plays_slot))
def test_migration_idempotency():
    """
    Test the migrations are idempotent -- we can re-run them and they
    succeed. This is a useful test in making sure that during service upgrade
    a service provider may retry a migration multiple times.

    Because not all migrations are historically idempotent, this checking begins at
    the migration following START_MIGRATION.

    For every revision from START_MIGRATION onwards the test:

    * upgrades to the revision;
    * rewrites ``alembic_version`` back to the previous revision;
    * upgrades to the same revision again.

    The test database is dropped at the end, including when a migration fails.
    """

    # Drop DB, ensuring migration performed at start
    if database_exists(DB_URL):
        drop_database(DB_URL)

    create_database(DB_URL)
    # Always drop the test database, even if a migration below fails, so a
    # failed run does not leave a half-migrated database behind
    try:
        session_manager = SessionManager(DB_URL, {})

        # Run db migrations because the db gets dropped at the start of the tests
        alembic_dir = os.getcwd()
        alembic_config = alembic.config.Config(f"{alembic_dir}/alembic.ini")
        alembic_config.set_main_option("sqlalchemy.url", str(DB_URL))
        alembic_config.set_main_option("mode", "test")

        # Alembic commands print out instead of returning, so capture the
        # history output, restoring stdout even if the command fails
        buf = steal_stdout()
        alembic_config.stdout = buf
        try:
            alembic.command.history(alembic_config)
        finally:
            alembic_config.stdout = sys.stdout
        # Rows of this output look like
        # b3084b7bc025 -> 5add54e23282, add stems support
        history_output = buf.getvalue().decode("utf-8")

        revision_pattern = re.compile(
            r"(?P<old_version>.{12}) -> (?P<new_version>.{12}).*")

        def get_version(line):
            m = revision_pattern.search(line.strip())
            if m:
                return m.group("new_version")
            return None

        versions = list(
            filter(None, map(get_version, history_output.split("\n"))))
        # Ordered (chronological) list of all alembic revisions
        versions_in_chronological_order = list(reversed(versions))

        # Find migration to start at (first revision if it is not listed)
        try:
            start_index = versions_in_chronological_order.index(START_MIGRATION)
        except ValueError:
            start_index = 0

        # Apply migrations 1 by 1, each time resetting the stored alembic version
        # in the database and replaying the migration twice to test its idempotency
        prev_version = START_MIGRATION
        for version in versions_in_chronological_order[start_index:]:
            print(f"Running migration {version}")
            alembic.command.upgrade(alembic_config, version)

            # Revert to prev_version using bound parameters rather than
            # string-built SQL
            with session_manager.scoped_session() as session:
                session.execute(
                    sqlalchemy.text(
                        "UPDATE alembic_version SET version_num = :prev_version "
                        "WHERE version_num = :version"),
                    {
                        "prev_version": prev_version,
                        "version": version
                    },
                )

            print(f"Running migration {version}")
            alembic.command.upgrade(alembic_config, version)

            prev_version = version
    finally:
        drop_database(DB_URL)