def test_blending_2_2(self):
    q1 = Select(
        columns=[
            column('ad_id'),
            column('impressions'),
            column(HUSKY_QUERY_DATA_SOURCE_COLUMN_NAME),
        ],
        from_obj=table('table1'),
    )
    df1 = Dataframe(q1, get_mocked_dataframe_columns_map(['ad_id', 'impressions']), set(), {'SF'})
    q2 = Select(
        columns=[
            column('ad_id'),
            column('campaign_id'),
            column('impressions'),
            column(HUSKY_QUERY_DATA_SOURCE_COLUMN_NAME),
        ],
        from_obj=table('table2'),
    )
    df2 = Dataframe(q2, get_mocked_dataframe_columns_map(['ad_id', 'impressions', 'campaign_id']), set(), {'SF'})
    blended_df = blend_dataframes(SNOWFLAKE_HUSKY_CONTEXT, [df1, df2])
    self.write_test_expectations('query.sql', compile_query(blended_df.query))
    expected_query = self.read_test_expectations('query.sql')
    self.assertEqual(expected_query, compile_query(blended_df.query))
    self.assertEqual({'ad_id', 'impressions', 'campaign_id'}, set(blended_df.slug_to_column.keys()))
def test_user_insert(db_test_client):
    from application import db

    db.insert_user("oppilas", "oppilas", "Tessa", "Testaaja")
    with pytest.raises(IntegrityError):
        db.insert_user("oppilas", "oppilas", "Tessa", "Testaaja")

    j = db.account.join(db.role)
    sql = Select([func.count(db.account.c.username), db.role.c.name]).select_from(j)
    with db.engine.connect() as conn:
        rs = conn.execute(sql)
        row = rs.first()
        count = row[0]
        role = row[1]
        assert 1 == count
        assert "USER" == role

    db.insert_user("opettaja", "opettaja", "Essi", "Esimerkki", role="TEACHER")
    sql = Select([func.count(db.account.c.username), db.role.c.name
                  ]).select_from(j).where(db.role.c.name == "TEACHER")
    with db.engine.connect() as conn:
        rs = conn.execute(sql)
        row = rs.first()
        count = row[0]
        role = row[1]
        assert 1 == count
        assert "TEACHER" == role

    student = db.get_user_by_id(1)
    teacher = db.get_user_by_id(2)
    null = db.get_user_by_id(3)
    assert student.name == "oppilas"
    assert teacher.name == "opettaja"
    assert null is None
def limit_query(query: Select, limit: Optional[str] = None, offset: Optional[str] = None) -> Select:
    if limit:
        validators.raise_if_not_int(limit)
        query = query.limit(int(limit))
    if offset:
        validators.raise_if_not_int(offset)
        query = query.offset(int(offset))
    return query
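A minimal usage sketch for limit_query, assuming the validators helper it calls is importable; the users table and the string values (as they would arrive from a query string) are hypothetical:

from sqlalchemy import Column, Integer, MetaData, String, Table, select

metadata = MetaData()
users = Table("users", metadata, Column("id", Integer), Column("name", String))

# Values arrive as strings and are validated before the int() cast;
# passing None for either argument leaves the query untouched.
query = limit_query(select([users]), limit="10", offset="20")
print(query)  # SELECT users.id, users.name FROM users LIMIT :param_1 OFFSET :param_2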
def paginate(select: Select, page_position: PagePosition) -> Select:
    if page_position.sort:
        order_by_clauses = [
            _sort_direction_map[order.direction](order.field)
            for order in page_position.sort.orders
        ]
        select = select.order_by(*order_by_clauses)
    return select.limit(page_position.limit).offset(page_position.offset)
def get_my_boards_filter(cls, request: Request, query: Select, value: bool = False) -> Select:
    if value:
        query = query.where(cls.board_alias.c.author_id == request.user.id)
    else:
        query = query.where(cls.board_alias.c.author_id != request.user.id)
    return query
def _apply_ip_or_cidr_filter(self, stmt: Select, ip_or_cidr: str) -> Select:
    """Apply an appropriate filter for an IP or CIDR block.

    Notes
    -----
    If there is ever a need to support a database that does not have
    native CIDR membership queries, fallback code (probably using a LIKE
    expression) will need to be added here.
    """
    if "/" in ip_or_cidr:
        return stmt.where(text(":c >> ip_address")).params(c=ip_or_cidr)
    else:
        return stmt.where(TokenChangeHistory.ip_address == str(ip_or_cidr))
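For reference, a sketch of how the CIDR branch compiles against the PostgreSQL dialect; the table below is a stand-in, not the real TokenChangeHistory model. PostgreSQL's >> operator tests network containment, so ":c >> ip_address" is true when the bound CIDR block contains the row's address:

from sqlalchemy import Column, MetaData, String, Table, select, text
from sqlalchemy.dialects import postgresql

metadata = MetaData()
# Hypothetical stand-in for the TokenChangeHistory table.
history = Table("token_change_history", metadata, Column("ip_address", String))

stmt = select([history]).where(text(":c >> ip_address")).params(c="10.0.0.0/8")
print(stmt.compile(dialect=postgresql.dialect()))
# SELECT token_change_history.ip_address
# FROM token_change_history
# WHERE %(c)s >> ip_address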
def test_execute_select_process_result_value(mocked_client, mocker) -> None:
    mocked_client.execute_statement.return_value = {
        'numberOfRecordsUpdated': 0,
        'records': [[{'longValue': 1}, {'stringValue': 'cat'}]],
        'columnMetadata': [
            {
                "arrayBaseColumnType": 0,
                "isAutoIncrement": False,
                "isCaseSensitive": False,
                "isCurrency": False,
                "isSigned": True,
                "label": "id",
                "name": "id",
                "nullable": 1,
                "precision": 11,
                "scale": 0,
                "schemaName": "",
                "tableName": "pets",
                "type": 4,
                "typeName": "INT",
            },
            {
                "arrayBaseColumnType": 0,
                "isAutoIncrement": False,
                "isCaseSensitive": False,
                "isCurrency": False,
                "isSigned": False,
                "label": "name",
                "name": "name",
                "nullable": 1,
                "precision": 255,
                "scale": 0,
                "schemaName": "",
                "tableName": "pets",
                "type": 12,
                "typeName": "VARCHAR",
            },
        ],
    }
    data_api = DataAPI(
        resource_arn='arn:aws:rds:dummy',
        secret_arn='dummy',
        database='test',
        client=mocked_client,
    )
    assert list(data_api.execute(Select([Pets]))[0]) == [1, 'my_type_cat']
    assert mocked_client.execute_statement.call_args == mocker.call(
        continueAfterTimeout=True,
        database='test',
        includeResultMetadata=True,
        resourceArn='arn:aws:rds:dummy',
        secretArn='dummy',
        sql="""SELECT pets.id, pets.name 
FROM pets""",
    )
async def get_exercises(workout_id):
    w = WorkoutExercise.__table__
    e = Exercise.__table__
    query = Select(columns=[*w.c, e.c.name.label('exercise_name')]) \
        .select_from(w.join(e)) \
        .where(w.c.workout_id == workout_id)
    return await db.fetch_all(query)
def apply_join(query: Select, table: Table, join_table: Table, join: TableJoin):
    """
    Performs an inner or outer join between two tables on a given query object.
    TODO: enable multiple joins
    :param query: A SQLAlchemy select object.
    :param table: The Table we are joining from.
    :param join_table: The Table we are joining to.
    :param join: The Join object describing how to join the tables.
    :return: A SQLAlchemy select object modified to join two tables.
    """
    error_msg = 'Invalid join, "{}" is not a column on table "{}"'
    join_conditions = []
    for column_pair in join.column_pairs:
        from_col = table.columns.get(column_pair.from_column)
        to_col = join_table.columns.get(column_pair.to_column)
        if from_col is None:
            raise ValueError(error_msg.format(column_pair.from_column, table.name))
        if to_col is None:
            raise ValueError(error_msg.format(column_pair.to_column, join_table.name))
        join_conditions.append(from_col == to_col)
    return query.select_from(table.join(join_table, onclause=and_(*join_conditions), isouter=join.outer_join))
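Hypothetical usage of apply_join. The ColumnPair/TableJoin dataclasses below are stand-ins inferred from the attributes the function reads (column_pairs, from_column, to_column, outer_join), not the project's real definitions:

from dataclasses import dataclass
from typing import List
from sqlalchemy import Column, Integer, MetaData, String, Table, select

@dataclass
class ColumnPair:          # stand-in for the real join-pair type
    from_column: str
    to_column: str

@dataclass
class TableJoin:           # stand-in for the real join descriptor
    column_pairs: List[ColumnPair]
    outer_join: bool = False

metadata = MetaData()
orders = Table("orders", metadata, Column("id", Integer), Column("user_id", Integer))
users = Table("users", metadata, Column("id", Integer), Column("name", String))

join = TableJoin(column_pairs=[ColumnPair("user_id", "id")])
query = apply_join(select([orders.c.id, users.c.name]), orders, users, join)
# SELECT orders.id, users.name FROM orders JOIN users ON orders.user_id = users.id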
def get_search_clause(self, table: sa.Table, query: Select, search: str,
                      search_fields: Sequence[str]) -> Select:
    if not search:
        return query
    columns = [getattr(table.c, col) for col in search_fields]
    return query.where(or_(*(col.ilike(f"%{search}%") for col in columns)))
def paginate(
    query: Select,
    columns: Dict[Column, Any],
    order: Order,
    limit: int,
) -> Select:
    orderer = get_orderer(order)
    comparator = get_comparator(order)
    for column, value in columns.items():
        query = query.order_by(orderer(column))
        if value is not None:
            query = query.where(comparator(column, value))
    return query.limit(limit)
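A usage sketch of this keyset-style paginate. Order, get_orderer, and get_comparator are not shown above; assume an ascending order maps to sqlalchemy.asc for ordering and a ">" comparison for the cursor predicate:

from datetime import datetime
from sqlalchemy import Column, DateTime, Integer, MetaData, Table, select

metadata = MetaData()
events = Table("events", metadata,
               Column("id", Integer), Column("created_at", DateTime))

# First page: all cursor values are None, so only ORDER BY + LIMIT apply.
# Order.ASC is hypothetical (see the note above).
first_page = paginate(
    select([events]),
    {events.c.created_at: None, events.c.id: None},
    Order.ASC,
    limit=50,
)

# Next page: feed back the last row's values; each non-None value adds a
# `column > value` predicate alongside the same ordering.
next_page = paginate(
    select([events]),
    {events.c.created_at: datetime(2024, 1, 1), events.c.id: 1000},
    Order.ASC,
    limit=50,
)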
def callPGLoadPayloadTuplesBlocking(dbSessionCreator: DbSessionCreator,
                                    sql: Select,
                                    sqlCoreLoadTupleClassmethod: Callable,
                                    payloadFilt: Optional[Dict] = None,
                                    fetchSize=50) -> LoadPayloadTupleResult:
    payloadFileJson = ujson.dumps(payloadFilt if payloadFilt else {})
    sqlStr = str(sql.compile(dialect=postgresql.dialect(),
                             compile_kwargs={"literal_binds": True}))
    loaderModuleClassMethodStr = '.'.join([
        sqlCoreLoadTupleClassmethod.__self__.__module__,
        sqlCoreLoadTupleClassmethod.__self__.__name__,
        sqlCoreLoadTupleClassmethod.__name__
    ])
    session = dbSessionCreator()
    try:
        sqlFunc = func.peek_storage.load_paylaod_tuples(
            sqlStr, payloadFileJson, loaderModuleClassMethodStr,
            __sysPathsJson, fetchSize)
        resultJsonStr: str = next(session.execute(sqlFunc))[0]
        resultJson: Dict = ujson.loads(resultJsonStr)
        if resultJson["encodedPayload"]:
            resultJson["encodedPayload"] = resultJson["encodedPayload"].encode()
        return LoadPayloadTupleResult(**resultJson)
    finally:
        session.close()
async def list(user_id: str):
    e = Exercise.__table__
    w = Workout.__table__
    query = Select(columns=[*e.c, w.c.date.label('last_workout_date')]) \
        .select_from(e.outerjoin(w)) \
        .where(e.c.user_id == user_id) \
        .where(e.c.is_deleted == false()) \
        .order_by(nullslast(desc(w.c.date)))
    return await db.fetch_all(query)
def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str:
    engine = self.get_sqla_engine(schema=schema)
    sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
    # pylint: disable=protected-access
    if engine.dialect.identifier_preparer._double_percents:  # noqa
        sql = sql.replace("%%", "%")
    return sql
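A minimal sketch of what literal_binds compilation produces; the method above additionally undoes %%-escaping on dialects whose identifier preparer doubles percent signs. The in-memory SQLite engine here is only for demonstration:

from sqlalchemy import Column, Integer, MetaData, Table, create_engine, select

metadata = MetaData()
t = Table("t", metadata, Column("x", Integer))
engine = create_engine("sqlite://")

# literal_binds inlines bound parameters into the SQL text.
compiled = select([t]).where(t.c.x == 5).compile(
    engine, compile_kwargs={"literal_binds": True})
print(str(compiled))  # SELECT t.x FROM t WHERE t.x = 5  (modulo newlines)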
async def view_exercise_history(exercise_id: str):
    we = WorkoutExercise.__table__
    workouts = Workout.__table__
    query = Select(columns=[*we.c,
                            workouts.c.date.label('workout_date'),
                            workouts.c.id.label('workout_id')]) \
        .select_from(we.join(workouts)) \
        .where(workouts.c.is_deleted == false()) \
        .where(we.c.exercise_id == exercise_id) \
        .order_by(desc(workouts.c.date))
    return await db.fetch_all(query)
async def check_entity_belongs_to_user(table: Table, entity_id: str, user_id: str):
    query = Select(columns=[func.count().label('cnt')]) \
        .select_from(table) \
        .where(table.c.user_id == user_id) \
        .where(table.c.is_deleted == false()) \
        .where(table.c.id == entity_id)
    row = await db.fetch_one(query)
    result = dict(row)['cnt'] != 0 if row is not None else False
    if not result:
        raise CustomException('Insufficient permissions to view this record')
def create_single_query_mock(data_source_name):
    """Convenience fn to create sqlalchemy's Select clause with some column and table."""
    return Select(
        columns=[
            column(data_source_name + '_column_mock'),
            column(HUSKY_QUERY_DATA_SOURCE_COLUMN_NAME)
        ],
        from_obj=text(data_source_name + '_table_mock'),
    )
async def filter_queryset(self, queryset: Select) -> Select:
    validated_data = self.request['validated_data']
    group_name = validated_data['group']
    version = validated_data.get('version')
    filters = [
        switches.c.groups.contains(f'{{{group_name}}}'),
        switches.c.is_active == true(),
        switches.c.is_hidden == false(),
    ]
    if version is not None:
        filters.append(switches.c.version <= version)
    return queryset.where(and_(*filters))
def _apply_cursor(stmt: Select, cursor: HistoryCursor) -> Select:
    """Apply a cursor to a query."""
    time = datetime_to_db(cursor.time)
    if cursor.previous:
        return stmt.where(
            or_(
                TokenChangeHistory.event_time > time,
                and_(
                    TokenChangeHistory.event_time == time,
                    TokenChangeHistory.id > cursor.id,
                ),
            )
        )
    else:
        return stmt.where(
            or_(
                TokenChangeHistory.event_time < time,
                and_(
                    TokenChangeHistory.event_time == time,
                    TokenChangeHistory.id <= cursor.id,
                ),
            )
        )
def main(**kwargs):
    # use ti for xcom push and pull
    ti = kwargs["ti"]

    # try to read the user buffer csv; if not found, keep user_id = 0
    user_id = 0
    buffer_dir = "/Users/muhammadsyamsularifin/airflow/buffer_data/users.csv"
    try:
        df = pd.read_csv(buffer_dir)
        user_id = int(df.tail(1)["id"])
    except FileNotFoundError:
        pass

    # select data from db, with id greater than variable "user_id"
    query = Select([Users]).where(Users.c.id > user_id).limit(3)
    result = conn.execute(query)

    # prepare a new, empty pandas dataframe
    column = ["id", "username", "address", "is_active", "domicile", "balance", "point"]
    df = pd.DataFrame(columns=column)

    # insert the rows into pandas
    for id, username, address, is_active, domicile, balance, point in result:
        new_row = {
            "id": id,
            "username": username,
            "address": address,
            "is_active": is_active,
            "domicile": domicile,
            "balance": float(balance),
            "point": float(point),
        }
        df = df.append(new_row, ignore_index=True)

    # if no data was extracted, tell xcom that extraction is done
    if len(df["id"]) == 0:
        ti.xcom_push(key="extract_user_done", value=1)
        return None

    # save to csv file
    df.to_csv(buffer_dir, index=False)
def visit(self):
    self._expr['selectList']['key'] = 'selectList'
    clause = Select(self._expr['selectList']).visit()
    self._expr['selectList'].pop('key')
    for key, expr in self._expr.items():
        if key == 'key':
            continue
        if key == 'selectList':
            continue
        expr['key'] = key
        clause = self.access(clause, expr)
        expr.pop('key')
    return clause
def test_db_file_insert_constraint(db_test_client):
    from application import db

    s = "nothing"
    db.insert_user(s, s, s, s, role="TEACHER")
    user_id = db.get_user(s, s).get_id()
    with pytest.raises(IntegrityError):
        file_insert_helper(db,
                           user_id=user_id,
                           binary_file=io.BytesIO(b"helvetin turhia bitteja").read())
    sql = Select([db.file])
    with db.engine.connect() as conn:
        rs = conn.execute(sql)
        row = rs.first()
        assert row is None
def _project_columns(
    cls, query: Select, dataframe: Dataframe, return_taxons: Dict[TaxonExpressionStr, Taxon]
) -> Tuple[List[ColumnAndDataframeColumn], Select]:
    projected_sql_and_df_columns: List[ColumnAndDataframeColumn] = [
        cls._project_column(query, taxon, dataframe.slug_to_column.get(taxon_slug_expression))
        for taxon_slug_expression, taxon in return_taxons.items()
    ]
    return (
        projected_sql_and_df_columns,
        Select(columns=sort_columns([col for col, _ in projected_sql_and_df_columns])),
    )
def _apply_limit(query: Select, args: dict) -> Select:
    """
    If a limit has been supplied by the 'length' parameter (and an offset
    supplied by the 'start' parameter) from a DataTable AJAX request, this
    adds it to the given query. A length of -1 represents no limit or offset.
    The offset index is 0-based.

    :param query: the DataTable SQL query
    :param args: the query parameters from a DataTable AJAX request
    :return: the query with the limit (if it exists) applied
    """
    limit = int(args['length'])
    if limit != -1:
        offset = int(args.get('start', 0))
        query = query.limit(limit).offset(offset)
    return query
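Hypothetical DataTables parameters illustrating the two paths described in the docstring (a concrete page request, and length == -1 disabling paging); the records table is a stand-in:

from sqlalchemy import Column, Integer, MetaData, Table, select

metadata = MetaData()
records = Table("records", metadata, Column("id", Integer))

paged = _apply_limit(select([records]), {"length": "10", "start": "20"})
# ... LIMIT 10 OFFSET 20

unpaged = _apply_limit(select([records]), {"length": "-1"})
# query returned unchanged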
async def paginate(
    db: Database,
    query: Select,
    params: Optional[AbstractParams] = None,
    *,
    convert_to_mapping: bool = True,
) -> AbstractPage:
    params = resolve_params(params)
    total = await db.fetch_val(select([func.count()]).select_from(query.alias()))
    items = await db.fetch_all(paginate_query(query, params))
    if convert_to_mapping:
        items = [{**item} for item in items]
    return create_page(items, total, params)
def _apply_text_filter(query: Select, args: dict, column: Column) -> Select:
    """
    If a value has been specified by the ['search']['value'] parameter from a
    DataTable AJAX request, this adds "column LIKE %['search']['value']%" to
    the WHERE clause of the given query.

    :param query: the DataTable SQL query
    :param args: the query parameters from a DataTable AJAX request
    :param column: the column to be filtered
    :return: the query with the filter (if it exists) applied
    """
    search_value = args['search']['value']
    if search_value:
        query = query.where(column.like('%{}%'.format(search_value)))
    return query
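A companion sketch for the text filter with its own stand-in table; an empty search value leaves the query untouched:

from sqlalchemy import Column, Integer, MetaData, String, Table, select

metadata = MetaData()
people = Table("people", metadata, Column("id", Integer), Column("name", String))

filtered = _apply_text_filter(
    select([people]),
    {"search": {"value": "smith"}},
    people.c.name,
)
# SELECT ... FROM people WHERE people.name LIKE :name_1  (:name_1 = '%smith%')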
def _paginate(
    query: Select,
    session: Session,
    total_items: int,
    offset: int,
    limit: int,
) -> Page[T]:
    total_pages = math.ceil(total_items / limit)
    page_number = offset // limit + 1  # integer division: 1-based page index
    query = query.offset(offset).limit(limit)
    result = session.execute(query)
    return Page[T](
        data=iter(result.unique().scalars()),
        meta={
            "offset": offset,
            "total_items": total_items,
            "total_pages": total_pages,
            "page_number": page_number,
        },
    )
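A worked example of the page arithmetic above, no database required:

import math

total_items, offset, limit = 95, 20, 10
assert math.ceil(total_items / limit) == 10  # total_pages: 9 full pages + 1 partial
assert offset // limit + 1 == 3              # page_number: third page of ten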
async def paginate(query: Select, params: Optional[AbstractParams] = None) -> AbstractPage:
    params = resolve_params(params)

    try:
        is_loader_used = query._execution_options["loader"]._distinct  # type: ignore
    except (AttributeError, KeyError):
        is_loader_used = False

    if is_loader_used:
        # FIXME: find better way to fetch rows when loader is used
        items = await query.gino.all()  # type: ignore
        return base_paginate(items, params)

    total = await func.count().select().select_from(query.alias()).gino.scalar()
    items = await paginate_query(query, params).gino.all()  # type: ignore
    return create_page(items, total, params)
def calculate_dataframe(
    cls,
    dimension_formulas: List[PreFormula],
    override_mappings_tel_data: OverrideMappingTelData,
    override_mapping_cte_map: Dict[OverrideMappingSlug, Select],
    df: Dataframe,
) -> Dataframe:
    select_columns = []
    select_columns.extend(df.query.columns)
    for dim_formula in dimension_formulas:
        col = dim_formula.formula.label(dim_formula.label)
        select_columns.append(col)

    # add joins to relevant override mapping CTEs
    select_from_query = OverrideMappingSql.insert_cte_joins(
        df.query, override_mappings_tel_data, override_mapping_cte_map)

    query = Select(columns=sort_columns(select_columns)).select_from(select_from_query)
    return Dataframe(query, df.slug_to_column, df.used_model_names, df.used_physical_data_sources)
def paginate(
    self,
    query: Select,
    item_factory: Callable[[Row], ModelT],
    *,
    sort_model: Base = None,
    custom_sort: str = None,
) -> Page[ModelT]:
    total = Session.execute(
        select(func.count()).select_from(query.subquery())
    ).scalar_one()
    try:
        if sort_model:
            sort_col = getattr(sort_model, self.sort)
        elif custom_sort:
            sort_col = text(custom_sort)
        else:
            sort_col = self.sort
        limit = self.size or total
        items = [
            item_factory(row)
            for row in Session.execute(
                query.order_by(sort_col).offset(limit * (self.page - 1)).limit(limit)
            )
        ]
    except (AttributeError, CompileError):
        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, 'Invalid sort column')
    return Page(
        items=items,
        total=total,
        page=self.page,
        pages=ceil(total / limit) if limit else 0,
    )
def _apply_ordering(query: Select, args: dict, default_sort_column_name: str,
                    default_direction: str = 'DESC') -> Select:
    """
    If an ordering has been supplied by the ['order'] parameter from a
    DataTable AJAX request, this adds it to the given query. Otherwise, it
    adds the ordering specified by default_sort_column_name and
    default_direction.

    :param query: the DataTable SQL query
    :param args: the query parameters from a DataTable AJAX request
    :param default_sort_column_name: the name of the column for the default ordering
    :param default_direction: the default sort direction (DESC if not specified)
    :return: the query with the ordering applied
    """
    order_args = args['order']  # renamed from `ord`, which shadows the builtin
    default_sort = '{} {}'.format(default_sort_column_name, default_direction)
    order_by = ', '.join(_get_orderings(args)) if order_args else default_sort
    return query.order_by(order_by + ' NULLS LAST')
def augment_query(
    cls,
    ctx: HuskyQueryContext,
    query: Select,
    taxon_model_info_map: Dict[str, TaxonModelInfo],
    filter_clause: Optional[FilterClause],
) -> Select:
    """
    Adds filters to the query

    :param ctx: Husky query runtime
    :param query: Original query
    :param taxon_model_info_map: Map of taxon slug expression to taxon model info
    :param filter_clause: Filter clauses

    :return: New query with all modifiers applied
    """
    if filter_clause:
        return query.where(filter_clause.generate(ctx, query, taxon_model_info_map))
    else:
        return query
def render_direct_mapping(cls, mapping: OverrideMapping) -> Select:
    """Renders CTE for direct mapping as union of all values"""
    selects = []
    for original, changed in mapping.definition:
        # using "literal" instead of "literal_column" here to force SQLAlchemy
        # to bind constants as params (safe)
        if original is None:
            original_column = literal_column('CAST(NULL AS VARCHAR)')
        else:
            original_column = literal(original)

        if changed is None:
            changed_column = literal(cls.PANO_NULL)
        else:
            changed_column = literal(changed)

        selects.append(
            Select([
                original_column.label(cls.ORIGINAL_COLUMN_NAME),
                changed_column.label(cls.CHANGED_COLUMN_NAME),
            ])
        )
    return union_all(*selects)
def change_objects_ownership(engine: Engine, database: str, target_role: str) -> None:
    stmt = Select(
        [
            literal_column("table_type"),
            literal_column("table_schema"),
            literal_column("table_name"),
        ],
        from_obj=text(f"{database}.INFORMATION_SCHEMA.TABLES"),
        whereclause=literal_column("table_owner") == "DBT_PRODUCTION",
    )
    with engine.begin() as tx:
        rp = tx.execute(stmt)
        objects = [
            (
                "TABLE" if object_type == "BASE TABLE" else object_type,
                schema,
                object_name,
            )
            for object_type, schema, object_name in rp.fetchall()
        ]
        for object_type, schema, object_name in objects:
            tx.execute(
                f"GRANT OWNERSHIP ON {object_type} {database}.{schema}.{object_name} "
                f"TO ROLE {target_role} REVOKE CURRENT GRANTS"
            ).fetchall()