async def create_pool(cls, pool: Pool) -> "FakeAsyncPGPool":
    """Build ``cls(pool)`` and attach a connection with a started transaction.

    A connection is acquired from the new object's ``_pool`` attribute, a
    transaction is started on it, and both are stored on the object as
    ``_conn`` and ``_tx`` before it is returned.
    """
    fake_pool = cls(pool)
    connection = await fake_pool._pool.acquire()
    transaction = connection.transaction()
    await transaction.start()
    fake_pool._conn, fake_pool._tx = connection, transaction
    return fake_pool
Beispiel #2
0
async def test_user_can_not_change_following_state_to_the_same_twice(
    app: FastAPI,
    authorized_client: AsyncClient,
    pool: Pool,
    test_user: UserInDB,
    api_method: str,
    route_name: str,
    following: bool,
) -> None:
    """Check that the parametrized follow/unfollow request is rejected with HTTP 400.

    A second user is created; when ``following`` is true, ``test_user`` is
    added to that user's followers before the request is sent.
    """
    async with pool.acquire() as connection:
        target = await UsersRepository(connection).create_user(
            username="******",
            email="*****@*****.**",
            password="******",
        )

        # Only the "following" variant needs the relation to exist up front.
        if following:
            await ProfilesRepository(connection).add_user_into_followers(
                target_user=target, requested_user=test_user)

    response = await authorized_client.request(
        api_method,
        app.url_path_for(route_name, username=target.username),
    )

    assert response.status_code == status.HTTP_400_BAD_REQUEST
Beispiel #3
0
    async def get_batches(cls,
                          pool: Pool = None,
                          nax_batches: int = 30) -> Dict[str, List[str]]:
        """Return the newest ``algo_run`` rows grouped by ``batch_id``.

        ``pool`` falls back to ``config.db_conn_pool`` when falsy.
        ``nax_batches`` is passed as the SQL ``LIMIT``, so it caps the number
        of rows fetched rather than the number of distinct batches. Each dict
        value holds, per row, the list of column values after ``batch_id``.
        """
        query = """
                        SELECT batch_id, algo_run_id, algo_name, algo_env, build_number, start_time
                        FROM algo_run
                        ORDER BY start_time DESC
                        LIMIT $1
                    """
        grouped: Dict = {}
        db_pool = pool or config.db_conn_pool
        async with db_pool.acquire() as con, con.transaction():
            for row in await con.fetch(query, nax_batches):
                # First column is the batch id; the rest is the per-run payload.
                grouped.setdefault(row[0], []).append(list(row.values())[1:])

        return grouped
    async def save(self, pool: Pool) -> bool:
        """Insert this ticker into ``ticker_data`` or update the row with the same symbol.

        Returns True once the statement has run. When acquiring/executing
        raises ``TooManyConnectionsError`` the failure is logged through
        ``tlog`` and False is returned.
        """
        upsert_sql = """
                            INSERT INTO ticker_data (symbol, name, description, tags,
                                                     similar_tickers, industry, sector, exchange)
                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                            ON CONFLICT (symbol)
                            DO UPDATE
                                SET name=$2, description=$3, tags=$4, similar_tickers=$5, industry=$6,
                                    sector=$7, exchange=$8, modify_tstamp='now()'
                        """
        try:
            async with pool.acquire() as con, con.transaction():
                await con.execute(
                    upsert_sql,
                    self.symbol,
                    self.name,
                    self.description,
                    self.tags,
                    self.similar_tickers,
                    self.industry,
                    self.sector,
                    self.exchange,
                )
        except TooManyConnectionsError as e:
            tlog(
                f"too many db connections: {e}. Failed to write ticker {self.symbol}, will re-try"
            )
            return False

        return True
    async def save(
        self,
        pool: Pool = None,
    ) -> None:
        """Insert this OHLC row into ``stock_ohlc`` or update the existing one.

        The conflict key is ``(symbol, symbol_date)``. ``indicators`` is
        stored as a JSON string. ``pool`` falls back to
        ``config.db_conn_pool`` when falsy.
        """
        upsert_sql = """
                        INSERT INTO stock_ohlc (symbol, symbol_date, open, high, low, close, volume, indicators)
                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                        ON CONFLICT (symbol, symbol_date)
                        DO UPDATE
                            SET open=$3, high=$4, low=$5, close=$6, volume=$7, indicators=$8, modify_tstamp='now()'
                    """
        db_pool = pool or config.db_conn_pool

        async with db_pool.acquire() as con, con.transaction():
            await con.execute(
                upsert_sql,
                self.symbol,
                self.symbol_date,
                self.open,
                self.high,
                self.low,
                self.close,
                self.volume,
                json.dumps(self.indicators),
            )
async def abstract_query(command, query, args=(), pool: Pool = None):
    """Run *query* via the connection method named *command*, inside a transaction.

    :param command: name of the connection method to call; it is looked up
        on the acquired connection with ``getattr``.
    :param query: SQL text passed as the first argument to that method.
    :param args: iterable of query arguments, unpacked after *query*.
    :param pool: pool to acquire the connection from. When ``None`` the pool
        returned by ``PgPool.get_pool()`` is used.
    :return: whatever the selected connection method returns.
    """
    # Only fall back to the shared pool when the caller did not supply one,
    # so an explicitly passed pool is actually honoured.
    if pool is None:
        pool = await PgPool.get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            result = await getattr(conn, command)(query, *args)
    return result
Beispiel #7
0
async def test_update_game(pool: Pool, test_auth0_user: Auth0User):
    """Create a game, change several of its fields, and check ``update_game`` returns it.

    Also checks that exactly one game is stored for the database afterwards.
    """
    async with pool.acquire() as connection:
        databases = ChessDbRepository(connection)
        games = ChessDbGameRepository(connection)
        db = await databases.create_db(name="db1", user=test_auth0_user)

        # Template holding the initial field values passed to create_game.
        template = ChessDbGame(
            date=datetime.date(1999, 1, 1),
            result="1-0",
            event="Some Tournament",
            black="Some guy",
            white="another dude",
            chess_db_id=db.id,
            user_id=test_auth0_user.id,
            id=11
        )
        game = await games.create_game(
            db_id=template.chess_db_id,
            user_id=template.user_id,
            white=template.white,
            black=template.black,
            event=template.event,
            date=template.date,
            result=template.result,
        )

        game.date = datetime.date(2000, 1, 2)
        game.result = "*"
        game.event = "Madrid"
        game.black = "A player"
        game.white = "another player"

        updated = await games.update_game(game=game)
        assert updated == game

        stored = await games.find_by_db_id(db_id=db.id)
        assert len(stored) == 1
Beispiel #8
0
async def get_data_from_db(uuid: UUID, pool: Pool) -> Union[bool, List]:
    """
    Look up a local uuid and return the matching data from the database.

    :param UUID uuid: local uuid used to find the data in the database.
    :param Pool pool: database connection pool.
    :return: a ``ListEvent`` built from the fetched rows, or ``False`` when
        either argument is ``None`` or the lookup raises.
    """
    import logging

    if uuid is None or pool is None:
        return False
    try:
        j = parse_event_table.join(
            names_table, parse_event_table.c.primary_names_id == names_table.c.id
        )
        query_data = parse_event_table.select().with_only_columns(
            [
                parse_event_table.c.bankrupt_name,
                parse_event_table.c.title,
                parse_event_table.c.number,
                parse_event_table.c.text,
                parse_event_table.c.data_publish,
                parse_event_table.c.guid,
            ]
        ).select_from(j).where(
            names_table.c.local_uuid == uuid
        ).where(
            parse_event_table.c.type == 'BankruptcyMessage'
        )
        async with pool.acquire() as conn:
            values = ListEvent(data=await conn.fetch(query_data))
        return values

    except Exception:
        # Best-effort lookup: callers still get False on failure, but the
        # cause is recorded instead of being discarded silently.
        logging.getLogger(__name__).exception(
            "get_data_from_db failed for uuid %s", uuid
        )
        return False
Beispiel #9
0
async def get_prices(
    dataset_name: SessionDatasetNames,
    session_id,
    request_params: DataRequestSchema,
    pool: Pool,
) -> List[PricepointSchema]:
    """Query time-bucketed price and volume rows from ``{dataset_name}_ticks``.

    Rows are filtered by session id, label, data type and the
    ``from_datetime``/``to_datetime`` range of ``request_params``; the bucket
    width is ``request_params.period`` minutes. When ``to_datetime`` is falsy
    it is set to ``datetime.now()`` on the passed object before querying.
    """
    if not request_params.to_datetime:
        request_params.to_datetime = datetime.now()
    async with pool.acquire() as conn:
        query = f"""
        SELECT
            -- time_bucket_gapfill($1, timestamp, now() - INTERVAL '2 hours', now()) AS time,
            time_bucket_gapfill($1, timestamp) AS time,
            locf(avg(price)) as price,
            sum(volume) as volume
        FROM {dataset_name}_ticks
        WHERE session_id = $2
        and label = $3
        and data_type = $4
        and timestamp BETWEEN $5 and $6
        GROUP BY time
        ORDER BY time ASC;
        """
        # Positional arguments line up with $1..$6 in the query above.
        rows = await conn.fetch(
            query,
            timedelta(minutes=request_params.period),
            session_id,
            request_params.label,
            request_params.data_type,
            request_params.from_datetime,
            request_params.to_datetime,
        )
        return parse_obj_as(List[PricepointSchema], list(rows))
Beispiel #10
0
async def test_user_can_change_following_for_another_user(
    app: FastAPI,
    authorized_client: AsyncClient,
    pool: Pool,
    test_user: UserInDB,
    api_method: str,
    route_name: str,
    following: bool,
) -> None:
    """Change the follow state through the API and verify the profile reports it.

    A second user is created; when ``following`` is false, ``test_user`` is
    first added to that user's followers. The parametrized request must
    answer 200, and the profile fetched afterwards must carry the created
    username and a ``following`` flag equal to the parameter.
    """
    async with pool.acquire() as connection:
        target = await UsersRepository(connection).create_user(
            username="******",
            email="*****@*****.**",
            password="******",
        )

        if not following:
            await ProfilesRepository(connection).add_user_into_followers(
                target_user=target, requested_user=test_user)

    change_following_response = await authorized_client.request(
        api_method,
        app.url_path_for(route_name, username=target.username),
    )
    assert change_following_response.status_code == status.HTTP_200_OK

    response = await authorized_client.get(
        app.url_path_for("profiles:get-profile", username=target.username),
    )
    fetched = ProfileInResponse(**response.json())
    assert fetched.profile.username == target.username
    assert fetched.profile.following == following
async def test_create_db_game(authorized_client: AsyncClient, app: FastAPI,
                              test_auth0_user: Auth0User, pool: Pool):
    """Create a game through the API and verify both the response and the stored row count."""
    created_db = await post_create_db(authorized_client, app, "db1")
    db_id = created_db.json()["id"]

    game_fields = {
        "white": "a player",
        "black": "another player",
        "event": "London Classic",
        "date": datetime.date(2018, 6, 12).strftime("%Y-%m-%d"),
        "result": "0-1"
    }
    payload = {"game": game_fields}

    res = await authorized_client.post(
        app.url_path_for("db:create_game", db_id=db_id),
        json=payload,
    )
    assert res.status_code == HTTP_201_CREATED

    body = res.json()
    # Every submitted field must be echoed back unchanged.
    for field in ("white", "black", "event", "date", "result"):
        assert body[field] == payload["game"][field]
    assert body["userId"] == test_auth0_user.id
    assert body["chessDbId"] == db_id
    assert body["id"] >= 0

    async with pool.acquire() as connection:
        stored = await ChessDbGameRepository(connection).find_by_db_id(db_id=db_id)
        assert len(stored) == 1
Beispiel #12
0
    async def load_by_date(cls,
                           symbol_date: date,
                           pool: Pool = None) -> Dict[str, object]:
        """Load every ``stock_ohlc`` row for ``symbol_date``, keyed by symbol.

        Each value is a ``StockOhlc`` built from the row; the ``indicators``
        column is decoded from JSON. ``pool`` falls back to
        ``config.db_conn_pool`` when falsy.
        """
        query = """
                        SELECT symbol, symbol_date, open, high, low, close, volume, indicators
                        FROM stock_ohlc
                        WHERE symbol_date = $1 
                    """
        db_pool = pool or config.db_conn_pool

        async with db_pool.acquire() as con, con.transaction():
            rows = await con.fetch(query, symbol_date)

            # Column positions follow the SELECT list above.
            return {
                row[0]: StockOhlc(
                    symbol=row[0],
                    symbol_date=row[1],
                    open=row[2],
                    high=row[3],
                    low=row[4],
                    close=row[5],
                    volume=row[6],
                    indicators=json.loads(row[7]),
                )
                for row in rows
            }
Beispiel #13
0
    async def save(
        self,
        pool: Pool,
        client_buy_time: str,
        stop_price=None,
        target_price=None,
    ):
        """Insert this trade into ``new_trades`` and store the new id on ``self.trade_id``.

        ``self.indicators`` (or ``{}`` when falsy) is serialized to JSON with
        ``allow_nan=False``; if that raises ``ValueError`` an empty JSON
        object is stored instead.
        """
        async with pool.acquire() as con, con.transaction():
            try:
                indicators_json = json.dumps(self.indicators or {},
                                             allow_nan=False)
            except ValueError:
                # e.g. NaN/Infinity values, which allow_nan=False rejects.
                indicators_json = json.dumps({})

            insert_sql = """
                        INSERT INTO new_trades (algo_run_id, symbol, operation, qty, price, indicators, client_time, stop_price, target_price)
                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
                        RETURNING trade_id
                    """
            self.trade_id = await con.fetchval(
                insert_sql,
                self.algo_run_id,
                self.symbol,
                self.operation,
                self.qty,
                self.price,
                indicators_json,
                client_buy_time,
                stop_price,
                target_price,
            )
Beispiel #14
0
async def table_index(db_pool: Pool) -> Sequence:
    """Fetch the table index.

    Prepares the module-level ``sql_query``, reads its single value and
    returns that value decoded from JSON.
    """
    async with db_pool.acquire() as conn:
        statement = await conn.prepare(sql_query)
        raw_index = await statement.fetchval()

    return json.loads(raw_index)
Beispiel #15
0
    async def set_max_sequence(self, dst_pool: Pool):
        """Set the primary key's serial sequence to ``self.max_pk + 100000``.

        The sequence name is looked up first; nothing is changed when the
        lookup returns no row or an empty first column. If building the
        lookup SQL raises ``AttributeError`` a warning is logged and the
        method returns without touching the database.
        """
        async with dst_pool.acquire() as connection:
            try:
                seq_lookup_sql = SQLRepository.get_serial_sequence_sql(
                    table_name=self.name,
                    pk_column_name=self.primary_key.name,
                )
            except AttributeError:
                # NOTE(review): presumably raised when primary_key is missing - confirm.
                logger.warning(
                    f'AttributeError --- {self.name} --- set_max_sequence')
                return

            seq_row = await connection.fetchrow(seq_lookup_sql)
            if not (seq_row and seq_row[0]):
                return

            bump_sql = SQLRepository.get_set_sequence_value_sql(
                seq_name=seq_row[0],
                seq_val=self.max_pk + 100000,
            )
            await connection.execute(bump_sql)
Beispiel #16
0
    async def get_batch_list_by_date(
        cls,
        batch_date: date,
        pool: Pool = None,
    ) -> Dict[str, List[str]]:
        """Return the ``algo_run`` rows started on ``batch_date``, grouped by ``batch_id``.

        Rows are selected with ``batch_date <= start_time < batch_date + 1 day``
        and ordered newest first. Each dict value holds, per row, the list of
        column values after ``batch_id``. ``pool`` falls back to
        ``config.db_conn_pool`` when falsy.
        """
        rc: Dict = {}
        if not pool:
            pool = config.db_conn_pool
        async with pool.acquire() as con:
            async with con.transaction():
                # WHERE must precede ORDER BY; the reverse order is rejected
                # by PostgreSQL as a syntax error.
                rows = await con.fetch(
                    """
                        SELECT batch_id, algo_run_id, algo_name, algo_env, build_number, start_time
                        FROM algo_run
                        WHERE start_time >= $1 and start_time < $2
                        ORDER BY start_time DESC
                    """,
                    batch_date,
                    batch_date + timedelta(days=1),
                )

                for row in rows:
                    rc.setdefault(row[0], []).append(list(row.values())[1:])

        return rc
Beispiel #17
0
    async def get_batch_details(
        cls, batch_id: str, pool: Pool = None
    ) -> List[Tuple[int, datetime, datetime, str]]:
        """Return the runs of ``batch_id``, newest first.

        Each tuple holds the first four selected columns (``algo_run_id``,
        ``start_time``, ``end_time``, ``parameters``). ``pool`` falls back to
        ``config.db_conn_pool`` when falsy.
        """
        query = """
                        SELECT algo_run_id, start_time, end_time, parameters, algo_name
                        FROM algo_run
                        WHERE batch_id = $1
                        ORDER BY start_time DESC
                    """
        db_pool = pool or config.db_conn_pool
        details: List = []
        async with db_pool.acquire() as con, con.transaction():
            rows = await con.fetch(query, batch_id)

            if rows:
                # algo_name (fifth column) is selected but not returned.
                details = [(row[0], row[1], row[2], row[3]) for row in rows]

        return details
Beispiel #18
0
    async def get_batches(
            cls,
            pool: Pool = None,
            nax_batches: int = 30,
            start_date: date = date(2019, 1, 1),
    ) -> List:
        """Return up to ``nax_batches`` batch summaries started after ``start_date``.

        Runs are grouped by batch id, algorithm name, environment and build
        number, ordered by the earliest start time (truncated to the minute),
        newest first. Every column value is converted to ``str``. ``pool``
        falls back to ``config.db_conn_pool`` when falsy.
        """
        query = """
                        SELECT build_number, batch_id, algo_name, algo_env, date_trunc('minute', min(start_time)) as start
                        FROM algo_run
                        WHERE start_time > $2
                        GROUP BY batch_id, algo_name, algo_env, build_number
                        ORDER BY start DESC
                        LIMIT $1
                    """
        db_pool = pool or config.db_conn_pool
        async with db_pool.acquire() as con, con.transaction():
            rows = await con.fetch(query, nax_batches, start_date)

            if rows:
                return [[str(value) for value in row.values()] for row in rows]

        return []
async def test_empty_feed_if_user_has_not_followings(
    app: FastAPI,
    authorized_client: Client,
    test_article: Article,
    test_user: UserInDB,
    pool: Pool,
) -> None:
    """The feed must be empty after other users publish articles and nobody is followed.

    Five users with five articles each are created; no follow relation is
    set up before the feed endpoint is requested.
    """
    async with pool.acquire() as conn:
        users = UsersRepository(conn)
        articles_repo = ArticlesRepository(conn)

        for i in range(5):
            author = await users.create_user(
                username=f"user-{i}",
                email=f"user-{i}@email.com",
                password="******",
            )
            for j in range(5):
                await articles_repo.create_article(
                    slug=f"slug-{i}-{j}",
                    title="tmp",
                    description="tmp",
                    body="tmp",
                    author=author,
                    tags=[f"tag-{i}-{j}"],
                )

    feed_url = app.url_path_for("articles:get-user-feed-articles")
    response = await authorized_client.get(feed_url)

    feed = ListOfArticlesInResponse(**response.json())
    assert feed.articles == []
async def test_article_will_contain_only_attached_tags(
        app: FastAPI, authorized_client: Client, test_user: UserInDB,
        pool: Pool) -> None:
    """An article fetched by slug must list exactly the tags it was created with.

    Five more articles with their own tags are created afterwards, so tags
    of other articles would show up if they leaked into the response.
    """
    attached_tags = ["tag1", "tag3"]

    # The article under test first, then five unrelated ones.
    fixtures = [("test-slug", attached_tags)]
    fixtures += [(f"slug-{i}", [f"tag-{i}"]) for i in range(5)]

    async with pool.acquire() as conn:
        repo = ArticlesRepository(conn)
        for slug, tags in fixtures:
            await repo.create_article(
                slug=slug,
                title="tmp",
                description="tmp",
                body="tmp",
                author=test_user,
                tags=tags,
            )

    article_url = app.url_path_for("articles:get-article", slug="test-slug")
    response = await authorized_client.get(article_url)
    fetched_tags = ArticleInResponse(**response.json()).article.tags
    assert len(fetched_tags) == len(attached_tags)
    assert set(fetched_tags) == set(attached_tags)
async def test_user_can_not_modify_article_that_is_not_authored_by_him(
    app: FastAPI,
    authorized_client: Client,
    pool: Pool,
    api_method: str,
    route_name: str,
) -> None:
    """A request against an article created by a different user must yield HTTP 403."""
    async with pool.acquire() as conn:
        author = await UsersRepository(conn).create_user(
            username="******",
            email="*****@*****.**",
            password="******",
        )
        await ArticlesRepository(conn).create_article(
            slug="test-slug",
            title="Test Slug",
            description="Slug for tests",
            body="Test " * 100,
            author=author,
            tags=["tests", "testing", "pytest"],
        )

    target_url = app.url_path_for(route_name, slug="test-slug")
    update_payload = {"article": {"title": "Updated Title"}}
    response = await authorized_client.request(
        api_method, target_url, json=update_payload)
    assert response.status_code == status.HTTP_403_FORBIDDEN
async def test_user_can_change_favorite_state(
    app: FastAPI,
    authorized_client: Client,
    test_article: Article,
    test_user: UserInDB,
    pool: Pool,
    api_method: str,
    route_name: str,
    favorite_state: bool,
) -> None:
    """After the parametrized request the article must report ``favorite_state``.

    When ``favorite_state`` is false the article is first added to
    ``test_user``'s favorites. ``favorites_count`` is expected to equal
    ``int(favorite_state)``.
    """
    if not favorite_state:
        async with pool.acquire() as conn:
            await ArticlesRepository(conn).add_article_into_favorites(
                article=test_article, user=test_user)

    change_url = app.url_path_for(route_name, slug=test_article.slug)
    await authorized_client.request(api_method, change_url)

    article_url = app.url_path_for("articles:get-article", slug=test_article.slug)
    response = await authorized_client.get(article_url)

    fetched = ArticleInResponse(**response.json()).article

    assert fetched.favorited == favorite_state
    assert fetched.favorites_count == int(favorite_state)
Beispiel #23
0
async def test_user_can_not_take_already_used_credentials(
    app: FastAPI,
    authorized_client: AsyncClient,
    pool: Pool,
    token: str,
    credentials_part: str,
    credentials_value: str,
) -> None:
    """Updating the current user to a value held by another user must yield HTTP 400.

    Another user is created with ``credentials_part`` set to
    ``credentials_value``; the update request then sends that same pair.
    """
    # Base credentials with the parametrized field overridden.
    user_dict = {
        "username": "******",
        "password": "******",
        "email": "*****@*****.**",
        credentials_part: credentials_value,
    }
    async with pool.acquire() as connection:
        await UsersRepository(connection).create_user(**user_dict)

    update_payload = {"user": {credentials_part: credentials_value}}
    response = await authorized_client.put(
        app.url_path_for("users:update-current-user"),
        json=update_payload,
    )
    assert response.status_code == status.HTTP_400_BAD_REQUEST
Beispiel #24
0
    async def load_latest(
            cls, pool: Pool, symbol: str, strategy_name: str
    ) -> Tuple[int, float, float, float, Dict, datetime]:
        """Return the most recent ``new_trades`` row for ``symbol`` under ``strategy_name``.

        The result is ``(algo_run_id, price, stop_price, target_price,
        indicators, tstamp)`` with ``indicators`` decoded from JSON.

        Raises ``ValueError`` (after logging through ``tlog``) when no row
        matches.
        """
        query = """
                        SELECT t.algo_run_id, t.price, t.stop_price, t.target_price, t.indicators, t.tstamp 
                        FROM new_trades as t, algo_run as a
                        WHERE 
                            t.algo_run_id=a.algo_run_id AND
                            a.algo_name=$2 AND
                            symbol=$1 
                        ORDER BY tstamp DESC LIMIT 1
                    """
        async with pool.acquire() as con, con.transaction():
            row = await con.fetchrow(query, symbol, strategy_name)

            if not row:
                tlog(f"{symbol} no data for strategy {strategy_name}")
                raise ValueError

            return (
                int(row[0]),
                float(row[1]),
                float(row[2]),
                float(row[3]),
                json.loads(row[4]),
                row[5],
            )
async def test_filtering_with_limit_and_offset(app: FastAPI,
                                               authorized_client: Client,
                                               test_user: UserInDB,
                                               pool: Pool) -> None:
    """Listing with ``limit=2, offset=3`` must equal the full listing sliced from index 3.

    Five articles (``slug-5`` .. ``slug-9``) are created for ``test_user``
    before both listings are requested.
    """
    async with pool.acquire() as conn:
        repo = ArticlesRepository(conn)

        for i in range(5, 10):
            await repo.create_article(
                slug=f"slug-{i}",
                title="tmp",
                description="tmp",
                body="tmp",
                author=test_user,
            )

    full_response = await authorized_client.get(
        app.url_path_for("articles:list-articles"))
    everything = ListOfArticlesInResponse(**full_response.json())

    paging = {"limit": 2, "offset": 3}
    response = await authorized_client.get(
        app.url_path_for("articles:list-articles"), params=paging)

    page = ListOfArticlesInResponse(**response.json())
    assert everything.articles[3:] == page.articles
Beispiel #26
0
async def get_market_sectors(pool: Pool) -> List[str]:
    """Return the distinct ``sector`` values of ``ticker_data``, skipping falsy ones."""
    query = """
                    SELECT DISTINCT sector
                    FROM ticker_data
                """
    async with pool.acquire() as conn, conn.transaction():
        records = await conn.fetch(query)
        # filter(None, ...) drops NULL and empty sectors.
        return list(filter(None, (record[0] for record in records)))
Beispiel #27
0
async def db_fetch(db: Pool, sql, *params):
    """Run *sql* with *params* on a pooled connection and return the fetched rows.

    The row count, the elapsed time in milliseconds and the parameters are
    logged at INFO level.
    """
    async with db.acquire() as conn:
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments the way time.time() deltas can.
        started = time.perf_counter()
        rows = await conn.fetch(sql, *params)
        elapsed_ms = round((time.perf_counter() - started) * 1000)
        LOG.info('Found %s rows in %sms for %s parameters', len(rows),
                 elapsed_ms, params)
        return rows
async def test_delete_db(pool: Pool, test_auth0_user: Auth0User):
    """Create a database, delete it, and check the user's database count goes from 1 to 0."""
    async with pool.acquire() as conn:
        repo = ChessDbRepository(conn)
        created = await repo.create_db(name="test", user=test_auth0_user)

        owned_before = await repo.get_db_for_user(user_id=test_auth0_user.id)
        assert len(owned_before) == 1

        deleted = await repo.delete_db(db_id=created.id)
        assert deleted == True

        owned_after = await repo.get_db_for_user(user_id=test_auth0_user.id)
        assert len(owned_after) == 0
async def test_create_db(pool: Pool, test_auth0_user: Auth0User):
    """A created database must carry the requested name, the user's id and a non-negative id."""
    db_name = "testDb"
    async with pool.acquire() as conn:
        created = await ChessDbRepository(conn).create_db(
            name=db_name, user=test_auth0_user)
        assert created.name == db_name
        assert created.user_id == test_auth0_user.id
        assert created.id >= 0
Beispiel #30
0
async def run_query(query: str, pool: Pool, message_store: MessageStore):
    """Run *query* on a pooled connection and report the outcome to *message_store*.

    On success a message with the number of fetched rows is appended; any
    exception is caught and appended as a message instead of propagating.
    """
    async with pool.acquire() as connection:
        try:
            # fetch() returns the list of rows; fetchrow() would return a
            # single record (or None), so len() would not be a row count.
            result = await connection.fetch(query)
            await message_store.append(
                f'Fetched {len(result)} rows from: {query}')
        except Exception as e:
            await message_store.append(f'Got exception {e} from: {query}')