Example #1
0
 def test_blending_2_2(self):
     """Blend two dataframes that overlap on 'ad_id' and 'impressions'.

     df1 exposes (ad_id, impressions) from table1; df2 exposes
     (ad_id, campaign_id, impressions) from table2. Both carry the
     data-source marker column and the 'SF' physical data source. The
     blended dataframe must expose the union of slugs and compile to the
     stored expected SQL.
     """
     q1 = Select(
         columns=[
             column('ad_id'),
             column('impressions'),
             # marker column identifying which data source a row came from
             column(HUSKY_QUERY_DATA_SOURCE_COLUMN_NAME)
         ],
         from_obj=table('table1'),
     )
     df1 = Dataframe(
         q1, get_mocked_dataframe_columns_map(['ad_id', 'impressions']),
         set(), {'SF'})
     q2 = Select(
         columns=[
             column('ad_id'),
             column('campaign_id'),
             column('impressions'),
             column(HUSKY_QUERY_DATA_SOURCE_COLUMN_NAME),
         ],
         from_obj=table('table2'),
     )
     df2 = Dataframe(
         q2,
         get_mocked_dataframe_columns_map(
             ['ad_id', 'impressions', 'campaign_id']), set(), {'SF'})
     blended_df = blend_dataframes(SNOWFLAKE_HUSKY_CONTEXT, [df1, df2])
     # NOTE(review): regenerating the expectation file immediately before
     # reading it back makes the following assertEqual trivially pass —
     # confirm this golden-file refresh is intentional, not leftover debug.
     self.write_test_expectations('query.sql',
                                  compile_query(blended_df.query))
     expected_query = self.read_test_expectations('query.sql')
     self.assertEqual(expected_query, compile_query(blended_df.query))
     self.assertEqual({'ad_id', 'impressions', 'campaign_id'},
                      set(blended_df.slug_to_column.keys()))
Example #2
0
def test_user_insert(db_test_client):
    """Duplicate usernames are rejected; role assignment and id lookup work."""
    from application import db
    db.insert_user("oppilas", "oppilas", "Tessa", "Testaaja")
    # Inserting the same username twice must violate the unique constraint.
    with pytest.raises(IntegrityError):
        db.insert_user("oppilas", "oppilas", "Tessa", "Testaaja")
    j = db.account.join(db.role)
    sql = Select([func.count(db.account.c.username),
                  db.role.c.name]).select_from(j)
    with db.engine.connect() as conn:
        rs = conn.execute(sql)
        row = rs.first()
        count = row[0]
        role = row[1]
        assert 1 == count
        assert "USER" == role
    db.insert_user("opettaja", "opettaja", "Essi", "Esimerkki", role="TEACHER")
    sql = Select([func.count(db.account.c.username), db.role.c.name
                  ]).select_from(j).where(db.role.c.name == "TEACHER")
    with db.engine.connect() as conn:
        rs = conn.execute(sql)
        row = rs.first()
        count = row[0]
        role = row[1]
        assert 1 == count
        assert "TEACHER" == role

    student = db.get_user_by_id(1)
    teacher = db.get_user_by_id(2)
    missing = db.get_user_by_id(3)
    assert student.name == "oppilas"
    assert teacher.name == "opettaja"
    # Fix: identity check, not `== None` — equality can be overloaded
    # (SQLAlchemy does exactly that) and PEP 8 mandates `is None`.
    assert missing is None
Example #3
0
def limit_query(query: Select, limit: Optional[str] = None, offset: Optional[str] = None) -> Select:
    """Apply LIMIT and OFFSET clauses parsed from string parameters.

    Each value is validated as an integer before use; falsy values
    (None or '') leave the query untouched.
    """
    for raw_value, clause_name in ((limit, 'limit'), (offset, 'offset')):
        if raw_value:
            validators.raise_if_not_int(raw_value)
            query = getattr(query, clause_name)(int(raw_value))
    return query
Example #4
0
def paginate(select: Select, page_position: PagePosition) -> Select:
    """Order the query per the page position's sort spec, then slice it."""
    sort_spec = page_position.sort
    if sort_spec:
        ordering = []
        for order in sort_spec.orders:
            ordering.append(_sort_direction_map[order.direction](order.field))
        select = select.order_by(*ordering)
    windowed = select.limit(page_position.limit)
    return windowed.offset(page_position.offset)
 def get_my_boards_filter(cls,
                          request: Request,
                          query: Select,
                          value: bool = False) -> Select:
     """Filter boards by ownership: the requester's own boards when *value*
     is True, everyone else's boards otherwise."""
     author_id = cls.board_alias.c.author_id
     if value:
         return query.where(author_id == request.user.id)
     return query.where(author_id != request.user.id)
Example #6
0
    def _apply_ip_or_cidr_filter(
        self, stmt: Select, ip_or_cidr: str
    ) -> Select:
        """Filter the query by an exact IP address or a CIDR block.

        Notes
        -----
        If there is ever a need to support a database that does not have
        native CIDR membership queries, fallback code (probably using a LIKE
        expression) will need to be added here.
        """
        if "/" not in ip_or_cidr:
            # Plain address: exact equality match.
            return stmt.where(TokenChangeHistory.ip_address == str(ip_or_cidr))
        # CIDR block: PostgreSQL's `>>` containment operator.
        return stmt.where(text(":c >> ip_address")).params(c=ip_or_cidr)
Example #7
0
def test_execute_select_process_result_value(mocked_client, mocker) -> None:
    """DataAPI.execute maps Data API records + columnMetadata into rows.

    The mocked response carries one row (id=1, name='cat'); the assertion
    expects 'my_type_cat', so the Pets model presumably applies a custom
    column type whose process_result_value prefixes 'my_type_' —
    TODO confirm against the Pets declaration.
    """
    # Canned boto3 rds-data execute_statement response: one record with a
    # long and a string value, plus metadata for both columns.
    mocked_client.execute_statement.return_value = {
        'numberOfRecordsUpdated':
        0,
        'records': [[{
            'longValue': 1
        }, {
            'stringValue': 'cat'
        }]],
        'columnMetadata': [
            {
                "arrayBaseColumnType": 0,
                "isAutoIncrement": False,
                "isCaseSensitive": False,
                "isCurrency": False,
                "isSigned": True,
                "label": "id",
                "name": "id",
                "nullable": 1,
                "precision": 11,
                "scale": 0,
                "schemaName": "",
                "tableName": "pets",
                "type": 4,
                "typeName": "INT",
            },
            {
                "arrayBaseColumnType": 0,
                "isAutoIncrement": False,
                "isCaseSensitive": False,
                "isCurrency": False,
                "isSigned": False,
                "label": "name",
                "name": "name",
                "nullable": 1,
                "precision": 255,
                "scale": 0,
                "schemaName": "",
                "tableName": "pets",
                "type": 12,
                "typeName": "VARCHAR",
            },
        ],
    }
    data_api = DataAPI(
        resource_arn='arn:aws:rds:dummy',
        secret_arn='dummy',
        database='test',
        client=mocked_client,
    )
    assert list(data_api.execute(Select([Pets]))[0]) == [1, 'my_type_cat']
    # The generated SQL and call kwargs must match exactly (note the
    # trailing space after 'pets.name' in the compiled SQL).
    assert mocked_client.execute_statement.call_args == mocker.call(
        continueAfterTimeout=True,
        database='test',
        includeResultMetadata=True,
        resourceArn='arn:aws:rds:dummy',
        secretArn='dummy',
        sql="""SELECT pets.id, pets.name 
FROM pets""",
    )
Example #8
0
 async def get_exercises(workout_id):
     """Return the workout's exercise rows, each joined with its exercise name."""
     workout_exercise = WorkoutExercise.__table__
     exercise = Exercise.__table__
     query = (
         Select(columns=[*workout_exercise.c,
                         exercise.c.name.label('exercise_name')])
         .select_from(workout_exercise.join(exercise))
         .where(workout_exercise.c.workout_id == workout_id)
     )
     return await db.fetch_all(query)
Example #9
0
def apply_join(query: Select, table: Table, join_table: Table, join: TableJoin):
    """
    Perform an inner or outer join between two tables on a given query object.

    TODO: enable multiple joins

    :param query: A SQLAlchemy select object.
    :param table: The Table we are joining from.
    :param join_table: The Table we are joining to.
    :param join: The Join object describing how to join the tables.
    :return: A SQLAlchemy select object modified to join two tables.
    :raises ValueError: if a referenced column is missing on either table.
    """
    def _resolve(source, column_name):
        # Look the column up by name; a missing column is a caller error.
        col = source.columns.get(column_name)
        if col is None:
            raise ValueError(
                'Invalid join, "{}" is not a column on table "{}"'.format(
                    column_name, source.name))
        return col

    on_clause = and_(*(
        _resolve(table, pair.from_column) == _resolve(join_table, pair.to_column)
        for pair in join.column_pairs
    ))
    return query.select_from(
        table.join(join_table, onclause=on_clause, isouter=join.outer_join))
Example #10
0
    def get_search_clause(self, table: sa.Table, query: Select, search: str,
                          search_fields: Sequence[str]) -> Select:
        """Add a case-insensitive substring filter over *search_fields*;
        an empty search term leaves the query unchanged."""
        if not search:
            return query

        pattern = f"%{search}%"
        matchers = [getattr(table.c, name).ilike(pattern)
                    for name in search_fields]
        return query.where(or_(*matchers))
Example #11
0
def paginate(
    query: Select,
    columns: Dict[Column, Any],
    order: Order,
    limit: int,
) -> Select:
    """Apply keyset pagination: order by each column and, when a boundary
    value is supplied, filter past it using the order's comparator."""
    order_key = get_orderer(order)
    compare = get_comparator(order)

    for col, boundary in columns.items():
        query = query.order_by(order_key(col))
        if boundary is None:
            continue
        query = query.where(compare(col, boundary))

    return query.limit(limit)
def callPGLoadPayloadTuplesBlocking(dbSessionCreator: DbSessionCreator,
                                    sql: Select,
                                    sqlCoreLoadTupleClassmethod: Callable,
                                    payloadFilt: Optional[Dict] = None,
                                    fetchSize=50) -> LoadPayloadTupleResult:
    """Run a payload-tuple load server-side in Postgres and block for the result.

    The SELECT is compiled to literal SQL and handed, along with the dotted
    path of the loader classmethod, to the peek_storage stored procedure,
    which executes it and returns a JSON result.

    :param dbSessionCreator: factory producing a fresh DB session.
    :param sql: SELECT to compile (with literal binds) for server-side use.
    :param sqlCoreLoadTupleClassmethod: bound classmethod whose
        module/class/method path identifies the loader to invoke.
    :param payloadFilt: optional filter dict, serialised to JSON.
    :param fetchSize: rows per server-side fetch batch.
    :return: LoadPayloadTupleResult built from the returned JSON.
    """
    payloadFileJson = ujson.dumps(payloadFilt if payloadFilt else {})

    # literal_binds: the SQL string must be self-contained when re-executed
    # inside the stored procedure (no client-side bind parameters).
    sqlStr = str(
        sql.compile(dialect=postgresql.dialect(),
                    compile_kwargs={"literal_binds": True}))

    # "module.Class.method" path the server uses to import the loader.
    loaderModuleClassMethodStr = '.'.join([
        sqlCoreLoadTupleClassmethod.__self__.__module__,
        sqlCoreLoadTupleClassmethod.__self__.__name__,
        sqlCoreLoadTupleClassmethod.__name__
    ])

    session = dbSessionCreator()
    try:
        # NOTE(review): 'load_paylaod_tuples' looks misspelled, but it must
        # match the installed stored-procedure name — verify before renaming.
        sqlFunc = func.peek_storage.load_paylaod_tuples(
            sqlStr, payloadFileJson, loaderModuleClassMethodStr,
            __sysPathsJson, fetchSize)

        resultJsonStr: str = next(session.execute(sqlFunc))[0]

        resultJson: Dict = ujson.loads(resultJsonStr)
        # The payload travels as text; callers expect bytes.
        if resultJson["encodedPayload"]:
            resultJson["encodedPayload"] = resultJson["encodedPayload"].encode(
            )

        return LoadPayloadTupleResult(**resultJson)

    finally:
        session.close()
 async def list(user_id: str):
     """Return the user's non-deleted exercises together with the date of
     the most recent workout each appears in (NULLs sorted last)."""
     exercise = Exercise.__table__
     workout = Workout.__table__
     query = (
         Select(columns=[*exercise.c, workout.c.date.label('last_workout_date')])
         .select_from(exercise.outerjoin(workout))
         .where(exercise.c.user_id == user_id)
         .where(exercise.c.is_deleted == false())
         .order_by(nullslast(desc(workout.c.date)))
     )
     return await db.fetch_all(query)
Example #14
0
    def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str:
        """Render *qry* against this database's engine with literals inlined."""
        engine = self.get_sqla_engine(schema=schema)
        compiled = qry.compile(engine, compile_kwargs={"literal_binds": True})
        rendered = str(compiled)

        # pylint: disable=protected-access
        # Dialects that escape '%' as '%%' need the doubling undone here.
        if engine.dialect.identifier_preparer._double_percents:  # noqa
            rendered = rendered.replace("%%", "%")

        return rendered
 async def view_exercise_history(exercise_id: str):
     """Return all workout-exercise rows for *exercise_id* from non-deleted
     workouts, newest workout first."""
     workout_exercise = WorkoutExercise.__table__
     workout = Workout.__table__
     selected = [
         *workout_exercise.c,
         workout.c.date.label('workout_date'),
         workout.c.id.label('workout_id'),
     ]
     query = (
         Select(columns=selected)
         .select_from(workout_exercise.join(workout))
         .where(workout.c.is_deleted == false())
         .where(workout_exercise.c.exercise_id == exercise_id)
         .order_by(desc(workout.c.date))
     )
     return await db.fetch_all(query)
Example #16
0
 async def check_entity_belongs_to_user(table: Table, entity_id: str,
                                        user_id: str):
     """Raise CustomException unless the entity exists in *table*, belongs
     to *user_id*, and is not soft-deleted."""
     query = (
         Select(columns=[func.count().label('cnt')])
         .select_from(table)
         .where(table.c.user_id == user_id)
         .where(table.c.is_deleted == false())
         .where(table.c.id == entity_id)
     )
     row = await db.fetch_one(query)
     belongs = row is not None and dict(row)['cnt'] != 0
     if not belongs:
         raise CustomException('Недостаточно прав для просмотра записи')
Example #17
0
def create_single_query_mock(data_source_name):
    """
    Convenience fn to build a SQLAlchemy Select clause with one mock column
    and a mock table, both derived from the data source name.
    """
    mock_columns = [
        column(data_source_name + '_column_mock'),
        column(HUSKY_QUERY_DATA_SOURCE_COLUMN_NAME),
    ]
    return Select(columns=mock_columns,
                  from_obj=text(data_source_name + '_table_mock'))
Example #18
0
    async def filter_queryset(self, queryset: Select) -> Select:
        """Restrict the queryset to active, non-hidden switches for the
        validated group, optionally capped at the requested version."""
        data = self.request['validated_data']
        group = data['group']
        version = data.get('version')

        conditions = [
            switches.c.groups.contains(f'{{{group}}}'),
            switches.c.is_active == true(),
            switches.c.is_hidden == false(),
        ]
        if version is not None:
            conditions.append(switches.c.version <= version)

        return queryset.where(and_(*conditions))
Example #19
0
 def _apply_cursor(stmt: Select, cursor: HistoryCursor) -> Select:
     """Apply a keyset-pagination cursor to a history query.

     Ordering ties on event_time are broken by id; the "previous" direction
     excludes the cursor row, the forward direction includes it.
     """
     anchor = datetime_to_db(cursor.time)
     if cursor.previous:
         # Rows strictly after the cursor position.
         tie_break = and_(
             TokenChangeHistory.event_time == anchor,
             TokenChangeHistory.id > cursor.id,
         )
         return stmt.where(
             or_(TokenChangeHistory.event_time > anchor, tie_break)
         )
     # Rows at or before the cursor position (cursor row included).
     tie_break = and_(
         TokenChangeHistory.event_time == anchor,
         TokenChangeHistory.id <= cursor.id,
     )
     return stmt.where(
         or_(TokenChangeHistory.event_time < anchor, tie_break)
     )
Example #20
0
def main(**kwargs):
    """Extract the next batch of (up to 3) users into the buffer CSV.

    Resumes after the last id already buffered; when the query returns no
    rows, pushes an XCom flag signalling extraction is complete.

    :param kwargs: Airflow context; only "ti" (task instance) is used.
    """
    # ti is used for XCom push/pull.
    ti = kwargs["ti"]
    buffer_dir = "/Users/muhammadsyamsularifin/airflow/buffer_data/users.csv"

    # Resume after the last extracted id; start from 0 if no buffer exists.
    user_id = 0
    try:
        buffered = pd.read_csv(buffer_dir)
        user_id = int(buffered.tail(1)["id"])
    except FileNotFoundError:
        pass

    # Select the next 3 rows with id greater than the resume point.
    query = Select([Users]).where(Users.c.id > user_id).limit(3)
    result = conn.execute(query)

    columns = [
        "id",
        "username",
        "address",
        "is_active",
        "domicile",
        "balance",
        "point"
    ]

    # Collect rows into a list first: DataFrame.append was removed in
    # pandas 2.0 and was O(n^2); also avoid shadowing the builtin `id`.
    rows = [
        {
            "id": row_id,
            "username": username,
            "address": address,
            "is_active": is_active,
            "domicile": domicile,
            "balance": float(balance),
            "point": float(point),
        }
        for row_id, username, address, is_active, domicile, balance, point
        in result
    ]
    df = pd.DataFrame(rows, columns=columns)

    # If no data was extracted, tell XCom that extraction is done.
    if len(df["id"]) == 0:
        ti.xcom_push(key="extract_user_done", value=1)
        return None

    df.to_csv(buffer_dir, index=False)
Example #21
0
    def visit(self):
        """Build a clause tree from self._expr.

        The select list is visited first to seed the clause; every other
        entry is then folded in via self.access. Each sub-expression dict is
        temporarily tagged with its own key under 'key' so the visitor can
        identify which clause it is building, and the tag is popped again to
        leave self._expr unmodified afterwards.
        """
        # Tag, visit, then untag the select list.
        self._expr['selectList']['key'] = 'selectList'
        clause = Select(self._expr['selectList']).visit()
        self._expr['selectList'].pop('key')

        for key, expr in self._expr.items():
            # 'key' is the tag slot itself; 'selectList' was handled above.
            if key == 'key':
                continue
            if key == 'selectList':
                continue
            expr['key'] = key
            clause = self.access(clause, expr)
            expr.pop('key')
        return clause
Example #22
0
def test_db_file_insert_constraint(db_test_client):
    """A file insert that violates a constraint leaves no row behind."""
    from application import db

    name = "nothing"
    db.insert_user(name, name, name, name, role="TEACHER")
    user_id = db.get_user(name, name).get_id()

    payload = io.BytesIO(b"helvetin turhia bitteja").read()
    with pytest.raises(IntegrityError):
        file_insert_helper(db, user_id=user_id, binary_file=payload)

    with db.engine.connect() as conn:
        first_row = conn.execute(Select([db.file])).first()
        assert first_row is None
Example #23
0
    def _project_columns(
        cls, query: Select, dataframe: Dataframe,
        return_taxons: Dict[TaxonExpressionStr, Taxon]
    ) -> Tuple[List[ColumnAndDataframeColumn], Select]:
        """Project each requested taxon into a column and wrap the projected
        columns (sorted) in a new Select."""
        projected: List[ColumnAndDataframeColumn] = []
        for slug_expression, taxon in return_taxons.items():
            df_column = dataframe.slug_to_column.get(slug_expression)
            projected.append(cls._project_column(query, taxon, df_column))

        select_query = Select(
            columns=sort_columns([col for col, _ in projected]))
        return projected, select_query
Example #24
0
def _apply_limit(query: Select,
                 args: dict) -> Select:
    """
    If a limit has been supplied by the 'length' parameter (and an offset
    supplied by the 'start' parameter) from a DataTable AJAX request,
    this adds it to the given query. A length of -1 represents no limit or
    offset. The offset index is 0-based.

    :param query: the DataTable SQL query
    :param args: the query parameters from a DataTable AJAX request
    :return: the query with the limit (if it exists) applied
    """
    limit = int(args['length'])
    if limit != -1:
        offset = int(args.get('start', 0))
        query = query.limit(limit).offset(offset)
    return query
Example #25
0
async def paginate(
    db: Database,
    query: Select,
    params: Optional[AbstractParams] = None,
    *,
    convert_to_mapping: bool = True,
) -> AbstractPage:
    """Fetch one page of *query* together with the total row count."""
    params = resolve_params(params)

    count_stmt = select([func.count()]).select_from(query.alias())
    total = await db.fetch_val(count_stmt)
    items = await db.fetch_all(paginate_query(query, params))

    if convert_to_mapping:
        # Copy rows into plain dicts so callers don't hold row objects.
        items = [{**row} for row in items]

    return create_page(items, total, params)
Example #26
0
def _apply_text_filter(query: Select,
                       args: dict,
                       column: Column) -> Select:
    """
    If a value has been specified by the ['search']['value'] parameter from
    a DataTable AJAX request, this adds "column LIKE %['search']['value']%"
    to the WHERE clause of the given query.

    :param query: the DataTable SQL query
    :param args: the query parameters from a DataTable AJAX request
    :param column: the column to be filtered
    :return: the query with the filter (if it exists) applied
    """
    search_value = args['search']['value']
    if search_value:
        query = query.where(
            column.like('%{}%'.format(search_value))
        )
    return query
Example #27
0
def _paginate(
    query: Select,
    session: Session,
    total_items: int,
    offset: int,
    limit: int,
) -> Page[T]:
    """Execute *query* with OFFSET/LIMIT applied and wrap the rows in a Page.

    :param query: base SELECT statement.
    :param session: session used to execute the query.
    :param total_items: precomputed total row count of the unpaged query.
    :param offset: rows to skip; expected to be a multiple of *limit*.
    :param limit: page size; must be > 0.
    :return: Page of unique scalar results plus pagination metadata.
    """
    total_pages = math.ceil(total_items / limit)
    # Fix: integer division — `offset / limit + 1` produced a float page
    # number (e.g. 2.5) instead of a whole page index.
    page_number = offset // limit + 1
    query = query.offset(offset).limit(limit)
    result = session.execute(query)
    return Page[T](
        data=iter(result.unique().scalars()),
        meta={
            "offset": offset,
            "total_items": total_items,
            "total_pages": total_pages,
            "page_number": page_number,
        },
    )
Example #28
0
async def paginate(query: Select,
                   params: Optional[AbstractParams] = None) -> AbstractPage:
    """Paginate a Gino query, with a fallback when a distinct loader is set.

    The count/slice strategy is only valid when row multiplicity matches the
    final result; when a distinct relationship loader is attached the rows
    are fetched whole and paginated in memory instead.
    """
    params = resolve_params(params)

    # Peek at Gino's private execution options to detect a distinct loader;
    # absent options or loader mean the normal path is safe.
    try:
        is_loader_used = query._execution_options[
            "loader"]._distinct  # type: ignore
    except (AttributeError, KeyError):
        is_loader_used = False

    if is_loader_used:
        # FIXME: find better way to fetch rows when loader is used
        items = await query.gino.all()  # type: ignore
        return base_paginate(items, params)

    # COUNT(*) over the original query as a subquery, then fetch one page.
    total = await func.count().select().select_from(query.alias()
                                                    ).gino.scalar()
    items = await paginate_query(query, params).gino.all()  # type: ignore

    return create_page(items, total, params)
    def calculate_dataframe(
        cls,
        dimension_formulas: List[PreFormula],
        override_mappings_tel_data: OverrideMappingTelData,
        override_mapping_cte_map: Dict[OverrideMappingSlug, Select],
        df: Dataframe,
    ) -> Dataframe:
        """Extend the dataframe's query with labeled dimension-formula
        columns and join in the relevant override-mapping CTEs."""
        formula_columns = [
            dim_formula.formula.label(dim_formula.label)
            for dim_formula in dimension_formulas
        ]
        select_columns = [*df.query.columns, *formula_columns]

        # Add joins to the override-mapping CTEs referenced by the TEL data.
        select_from_query = OverrideMappingSql.insert_cte_joins(
            df.query, override_mappings_tel_data, override_mapping_cte_map)

        query = Select(columns=sort_columns(select_columns)).select_from(
            select_from_query)
        return Dataframe(query, df.slug_to_column, df.used_model_names,
                         df.used_physical_data_sources)
Example #30
0
    def paginate(
            self,
            query: Select,
            item_factory: Callable[[Row], ModelT],
            *,
            sort_model: Base = None,
            custom_sort: str = None,
    ) -> Page[ModelT]:
        """Count, sort, slice and materialise *query* into a Page.

        :param query: base SELECT to paginate.
        :param item_factory: converts each result row into a model instance.
        :param sort_model: model whose attribute named self.sort supplies
            the sort column (takes precedence over custom_sort).
        :param custom_sort: raw SQL sort expression used when no sort_model.
        :raises HTTPException: 422 when the sort column cannot be resolved
            or the sorted query fails to compile.
        """
        total = Session.execute(
            select(func.count()).
            select_from(query.subquery())
        ).scalar_one()

        try:
            if sort_model:
                sort_col = getattr(sort_model, self.sort)
            elif custom_sort:
                sort_col = text(custom_sort)
            else:
                sort_col = self.sort

            # A falsy size means "everything on one page".
            limit = self.size or total

            items = [
                item_factory(row) for row in Session.execute(
                    query.
                    order_by(sort_col).
                    offset(limit * (self.page - 1)).
                    limit(limit)
                )
            ]
        except (AttributeError, CompileError):
            raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, 'Invalid sort column')

        return Page(
            items=items,
            total=total,
            page=self.page,
            # limit can be 0 when the table is empty and no size was given.
            pages=ceil(total / limit) if limit else 0,
        )
Example #31
0
def _apply_ordering(query: Select,
                    args: dict,
                    default_sort_column_name: str,
                    default_direction: str='DESC') -> Select:
    """
    If an ordering has been supplied by the ['order'] parameter from a
    DataTable AJAX request, this adds it to the given query. Otherwise,
    it adds the ordering specified by the default_sort_column_name and
    default_direction.

    :param query: the DataTable SQL query
    :param args: the query parameters from a DataTable AJAX request
    :param default_sort_column_name: the name of the column for the default
                                     ordering
    :param default_direction: the default sort direction (DESC if not
                              specified)
    :return: the query with the ordering applied
    """
    ord = args['order']
    default_sort = '{} {}'.format(default_sort_column_name, default_direction)
    order_by = ', '.join(_get_orderings(args)) if ord else default_sort
    return query.order_by(order_by + ' NULLS LAST')
Example #32
0
    def augment_query(
        cls,
        ctx: HuskyQueryContext,
        query: Select,
        taxon_model_info_map: Dict[str, TaxonModelInfo],
        filter_clause: Optional[FilterClause],
    ) -> Select:
        """Apply the optional filter clause to the query.

        :param ctx: Husky query runtime
        :param query: Original query
        :param taxon_model_info_map: Map of taxon slug expression to taxon
            model info
        :param filter_clause: Filter clauses

        :return: New query with all modifiers applied
        """
        if not filter_clause:
            return query
        predicate = filter_clause.generate(ctx, query, taxon_model_info_map)
        return query.where(predicate)
Example #33
0
    def render_direct_mapping(cls, mapping: OverrideMapping) -> Select:
        """Render the direct-mapping CTE as a UNION ALL of one constant row
        per (original, changed) pair in the mapping definition."""

        def _row(original, changed):
            # "literal" (not "literal_column") makes SQLAlchemy bind the
            # constants as parameters, which is safe for arbitrary values.
            if original is None:
                left = literal_column('CAST(NULL AS VARCHAR)')
            else:
                left = literal(original)
            right = literal(cls.PANO_NULL) if changed is None else literal(changed)
            return Select([
                left.label(cls.ORIGINAL_COLUMN_NAME),
                right.label(cls.CHANGED_COLUMN_NAME),
            ])

        return union_all(*(_row(original, changed)
                           for original, changed in mapping.definition))
Example #34
0
def change_objects_ownership(engine: Engine, database: str,
                             target_role: str) -> None:
    """Transfer ownership of every DBT_PRODUCTION object to *target_role*.

    Scans INFORMATION_SCHEMA.TABLES for objects owned by DBT_PRODUCTION and
    issues a GRANT OWNERSHIP ... REVOKE CURRENT GRANTS per object, all
    inside a single transaction.

    :param engine: SQLAlchemy engine for the target account.
    :param database: database whose INFORMATION_SCHEMA is scanned.
    :param target_role: role receiving ownership.
    """
    stmt = Select(
        [
            literal_column("table_type"),
            literal_column("table_schema"),
            literal_column("table_name"),
        ],
        from_obj=text(f"{database}.INFORMATION_SCHEMA.TABLES"),
        whereclause=literal_column("table_owner") == "DBT_PRODUCTION",
    )

    with engine.begin() as tx:
        rp = tx.execute(stmt)
        # INFORMATION_SCHEMA reports "BASE TABLE"; GRANT syntax wants "TABLE".
        objects = [(
            "TABLE" if object_type == "BASE TABLE" else object_type,
            schema,
            object_name,
        ) for object_type, schema, object_name in rp.fetchall()]

        for object_type, schema, object_name in objects:
            # NOTE(review): identifiers are interpolated directly into the
            # statement. They come from INFORMATION_SCHEMA, but names with
            # special characters would break or mis-target this GRANT —
            # consider quoting/escaping the identifiers.
            tx.execute(
                f"GRANT OWNERSHIP ON {object_type} {database}.{schema}.{object_name} TO ROLE {target_role} REVOKE CURRENT GRANTS"
            ).fetchall()