def _build_selectors(
    self,
    get_column_name: Callable[[TaxonSlugExpression, str], str],
) -> Select:
    """
    Returns the select part of query.
    """
    selectors = []
    for taxon_slug_expression, taxon in self.projection_taxons.items():
        column_name = get_column_name(taxon_slug_expression, taxon_slug_expression.slug)
        col = literal_column(column_name)
        selectors.append(col.label(taxon.slug_safe_sql_identifier))

    for template in self.dimension_templates:
        # We must render the dimension templates with correct sql columns
        slug_to_column = {
            slug: get_column_name(TaxonSlugExpression(slug), slug)
            for slug in template.used_taxons
        }
        sql_formula = template.render_formula(**slug_to_column)
        col = literal_column(sql_formula).label(template.label)
        selectors.append(col)

    return select(sort_columns(selectors))
def _get_load_associations_query(document, doc_types_to_load):
    query_parents = DBSession. \
        query(
            Association.parent_document_id.label('id'),
            Document.type.label('t'),
            literal_column('1').label('p')). \
        join(
            Document,
            and_(
                Association.child_document_id == document.document_id,
                Association.parent_document_id == Document.document_id,
                Document.type.in_(doc_types_to_load))). \
        subquery()
    query_children = DBSession. \
        query(
            Association.child_document_id.label('id'),
            Document.type.label('t'),
            literal_column('0').label('p')). \
        join(
            Document,
            and_(
                Association.child_document_id == Document.document_id,
                Association.parent_document_id == document.document_id,
                Document.type.in_(doc_types_to_load))). \
        subquery()

    return DBSession \
        .query('id', 't', 'p') \
        .select_from(union(query_parents.select(), query_children.select()))
def _get_load_associations_query(document, doc_types_to_load):
    query_parents = DBSession. \
        query(
            Association.parent_document_id.label('id'),
            Association.parent_document_type.label('t'),
            literal_column('1').label('p')). \
        filter(
            and_(
                Association.child_document_id == document.document_id,
                Association.parent_document_type.in_(doc_types_to_load)
            )
        ). \
        subquery()
    query_children = DBSession. \
        query(
            Association.child_document_id.label('id'),
            Association.child_document_type.label('t'),
            literal_column('0').label('p')). \
        filter(
            and_(
                Association.parent_document_id == document.document_id,
                Association.child_document_type.in_(doc_types_to_load)
            )
        ). \
        subquery()

    return DBSession \
        .query(column('id'), column('t'), column('p')) \
        .select_from(union(query_parents.select(), query_children.select()))
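# Illustrative standalone sketch (not from the project above) of the same pattern:
# both branches of the union expose identical column names ('id', 't', 'p'), and the
# literal 1/0 marker column records which side of the association each row came from.
# Table and column names are hypothetical; assumes SQLAlchemy 1.x.
from sqlalchemy import column, literal_column, select, table, union

assoc = table('associations')

parents = select([
    column('parent_document_id').label('id'),
    column('parent_document_type').label('t'),
    literal_column('1').label('p'),
]).select_from(assoc)

children = select([
    column('child_document_id').label('id'),
    column('child_document_type').label('t'),
    literal_column('0').label('p'),
]).select_from(assoc)

print(union(parents, children))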
def get_partition_by_columns(self, model: HuskyModel):
    return [
        literal_column(model.taxon_sql_accessor(self.ctx, model_attribute.taxon))
        for model_attribute in model.attributes_memoized.values()
        if model_attribute.identifier is True
    ]
async def creator(
    gid: GroupID,
    cluster_access_rights: Optional[Dict[GroupID, ClusterAccessRights]] = None,
) -> Cluster:
    new_cluster = ClusterCreate(
        **{
            "name": faker.name(),
            "type": random.choice(list(ClusterType)),
            "owner": gid,
            "access_rights": cluster_access_rights or {},
        }
    )

    result = postgres_db.execute(
        clusters.insert()
        .values(new_cluster.dict(by_alias=True, exclude={"id", "access_rights"}))
        .returning(literal_column("*"))
    )
    cluster_in_db = result.first()
    assert cluster_in_db is not None
    new_cluster_id = cluster_in_db[clusters.c.id]
    list_of_created_cluster_ids.append(new_cluster_id)

    # when a cluster is created, the DB automatically creates the owner access rights
    for group_id, access_rights in new_cluster.access_rights.items():
        result = postgres_db.execute(
            insert(cluster_to_groups)
            .values(
                **{
                    "cluster_id": new_cluster_id,
                    "gid": group_id,
                    "read": access_rights.read,
                    "write": access_rights.write,
                    "delete": access_rights.delete,
                }
            )
            .on_conflict_do_nothing()
        )

    return Cluster(id=new_cluster_id, **new_cluster.dict(by_alias=True, exclude={"id"}))
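# Minimal sketch of why `.returning(literal_column("*"))` is used above: it compiles to
# `RETURNING *`, so the driver hands back the complete inserted row without listing
# columns explicitly. The table here is hypothetical; assumes SQLAlchemy 1.x and the
# PostgreSQL dialect.
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

metadata = sa.MetaData()
demo_clusters = sa.Table(
    'demo_clusters', metadata,
    sa.Column('id', sa.Integer, primary_key=True),
    sa.Column('name', sa.String),
)

stmt = demo_clusters.insert().values(name='demo').returning(sa.literal_column('*'))
print(stmt.compile(dialect=postgresql.dialect()))
# INSERT INTO demo_clusters (name) VALUES (%(name)s) RETURNING *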
async def test_listen_comp_tasks_task(
    mock_project_subsystem: Dict,
    comp_task_listening_task: asyncio.Task,
    client,
    update_values: Dict[str, Any],
    expected_calls: List[str],
    task_class: NodeClass,
):
    db_engine: aiopg.sa.Engine = client.app[APP_DB_ENGINE_KEY]
    async with db_engine.acquire() as conn:
        # let's put some stuff in there now
        result = await conn.execute(
            comp_tasks.insert()
            .values(outputs=json.dumps({}), node_class=task_class)
            .returning(literal_column("*"))
        )
        row: RowProxy = await result.fetchone()
        task = dict(row)

        # let's update some values
        await conn.execute(
            comp_tasks.update()
            .values(**update_values)
            .where(comp_tasks.c.task_id == task["task_id"])
        )

        # tests whether listener gets hooked calls executed
        for call_name, mocked_call in mock_project_subsystem.items():
            if call_name in expected_calls:
                async for attempt in _async_retry_if_fails():
                    with attempt:
                        mocked_call.assert_awaited()
            else:
                mocked_call.assert_not_called()
async def test_listen_comp_tasks_task(
    mock_project_subsystem: Dict,
    comp_task_listening_task: asyncio.Task,
    client,
    upd_value: Dict[str, Any],
    exp_calls: List[str],
    task_class: NodeClass,
):
    db_engine: aiopg.sa.Engine = client.app[APP_DB_ENGINE_KEY]
    async with db_engine.acquire() as conn:
        # let's put some stuff in there now
        result = await conn.execute(
            comp_tasks.insert()
            .values(outputs=json.dumps({}), node_class=task_class)
            .returning(literal_column("*"))
        )
        row: RowProxy = await result.fetchone()
        task = dict(row)

        # let's update some values
        await conn.execute(
            comp_tasks.update()
            .values(**upd_value)
            .where(comp_tasks.c.task_id == task["task_id"])
        )

        for key, mock_fct in mock_project_subsystem.items():
            if key in exp_calls:
                await _wait_for_call(mock_fct)
            else:
                mock_fct.assert_not_called()
async def create(
    self,
    user_id: UserID,
    project_id: ProjectID,
    cluster_id: ClusterID,
    default_cluster_id: ClusterID,
    iteration: Optional[PositiveInt] = None,
) -> CompRunsAtDB:
    async with self.db_engine.acquire() as conn:
        if iteration is None:
            # let's get the latest if it exists
            last_iteration = await conn.scalar(
                sa.select([comp_runs.c.iteration])
                .where(
                    (comp_runs.c.user_id == user_id)
                    & (comp_runs.c.project_uuid == str(project_id))
                )
                .order_by(desc(comp_runs.c.iteration))
            )
            iteration = (last_iteration or 1) + 1

        result = await conn.execute(
            comp_runs.insert()  # pylint: disable=no-value-for-parameter
            .values(
                user_id=user_id,
                project_uuid=f"{project_id}",
                cluster_id=cluster_id if cluster_id != default_cluster_id else None,
                iteration=iteration,
                result=RUNNING_STATE_TO_DB[RunningState.PUBLISHED],
                started=datetime.utcnow(),
            )
            .returning(literal_column("*"))
        )
        row = await result.first()
        return CompRunsAtDB.from_orm(row)
def select_vc_time_per_user(type_id: int, lit_values: list) -> Select:
    join_time = date_to_secs(VoiceChatEvent.created_at)
    left_time = date_to_secs(VoiceChatEvent.updated_at)
    value_column = func.sum(left_time - join_time).label('value')
    lit_columns = [literal_column(str(v)).label(l) for (l, v) in lit_values]
    select_columns = [value_column, VoiceChatEvent.user_id] + lit_columns

    return select(select_columns).where(
        VoiceChatEvent.type_id == type_id).group_by(VoiceChatEvent.user_id)
async def create_user(db, data):
    async with db() as conn:
        created = await conn.execute(
            user.insert()
            .returning(literal_column('id, email'))
            .values(**data))
        new_user = await created.fetchone()
        return dict(new_user)
def select_membership_time_per_user(type_id: int, lit_values: list) -> Select:
    join_time = date_to_secs(func.max(MemberEvent.created_at))
    current_time = int(datetime.now().timestamp())
    membership_value = cast((current_time - join_time) / 86400, Integer).label('value')
    lit_columns = [literal_column(str(v)).label(l) for (l, v) in lit_values]
    select_columns = [membership_value, MemberEvent.user_id] + lit_columns

    return select(select_columns).where(
        and_(MemberEvent.type_id == type_id, User.roles != None)).group_by(MemberEvent.user_id)
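# Standalone sketch of the `lit_values` pattern used in the two selects above: each
# (label, value) pair becomes a constant column in the result set. Note that str(v) is
# emitted verbatim into the SQL, so callers must pass SQL-safe literals (numbers, or
# already-quoted strings). Names are illustrative; assumes SQLAlchemy 1.x.
from sqlalchemy import column, literal_column, select, table

lit_values = [('guild_id', 42), ('metric', "'voice_seconds'")]
lit_columns = [literal_column(str(v)).label(name) for (name, v) in lit_values]

q = select([column('user_id')] + lit_columns).select_from(table('events'))
print(q)
# SELECT user_id, 42 AS guild_id, 'voice_seconds' AS metric FROM events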
def get_actors(engine):
    q = select(
        [literal_column('actor_id')]
    ) \
        .select_from(table('actor')) \
        .where(
            and_(
                literal_column('first_name') == 'JOHNNY',
                literal_column('last_name') == 'LOLLOBRIGIDA',
            )
        )

    session = get_session(engine=engine)
    try:
        rsp = session.execute(q)
        res = rsp.first()
    finally:
        print('do finally')
        session.close()

    return res
def get_neighbour_version_ids(version_id, document_id, lang):
    """
    Get the previous and next version for a version of a document with a
    specific language.
    """
    next_version = DBSession \
        .query(
            DocumentVersion.id.label('id'),
            literal_column('1').label('t')) \
        .filter(DocumentVersion.id > version_id) \
        .filter(DocumentVersion.document_id == document_id) \
        .filter(DocumentVersion.lang == lang) \
        .order_by(DocumentVersion.id) \
        .limit(1) \
        .subquery()

    previous_version = DBSession \
        .query(
            DocumentVersion.id.label('id'),
            literal_column('-1').label('t')) \
        .filter(DocumentVersion.id < version_id) \
        .filter(DocumentVersion.document_id == document_id) \
        .filter(DocumentVersion.lang == lang) \
        .order_by(DocumentVersion.id.desc()) \
        .limit(1) \
        .subquery()

    query = DBSession \
        .query('id', 't') \
        .select_from(union(
            next_version.select(), previous_version.select()))

    previous_version_id = None
    next_version_id = None

    for version, typ in query:
        if typ == -1:
            previous_version_id = version
        else:
            next_version_id = version

    return previous_version_id, next_version_id
def _get_select_children(waypoint):
    """
    Return a WITH query that selects the document ids of the given waypoint,
    the children and the grand-children of the waypoint.
    See also:
    http://docs.sqlalchemy.org/en/latest/core/selectable.html#sqlalchemy.sql.expression.GenerativeSelect.cte  # noqa
    """
    select_waypoint = DBSession. \
        query(
            literal_column(str(waypoint.document_id)).label('document_id'),
            literal_column('1').label('priority')). \
        cte('waypoint')

    # query to get the direct child waypoints
    select_waypoint_children = DBSession. \
        query(
            Waypoint.document_id,
            literal_column('0').label('priority')). \
        join(
            Association,
            and_(Association.child_document_id == Waypoint.document_id,
                 Association.parent_document_id == waypoint.document_id)). \
        cte('waypoint_children')

    # query to get the grand-child waypoints
    select_waypoint_grandchildren = DBSession. \
        query(
            Waypoint.document_id,
            literal_column('0').label('priority')). \
        select_from(select_waypoint_children). \
        join(
            Association,
            Association.parent_document_id ==
            select_waypoint_children.c.document_id). \
        join(
            Waypoint,
            Association.child_document_id == Waypoint.document_id). \
        cte('waypoint_grandchildren')

    return union(
        select_waypoint.select(),
        select_waypoint_children.select(),
        select_waypoint_grandchildren.select()). \
        cte('select_all_waypoints')
async def update(
    self, user_id: UserID, project_id: ProjectID, iteration: PositiveInt, **values
) -> Optional[CompRunsAtDB]:
    async with self.db_engine.acquire() as conn:
        result = await conn.execute(
            sa.update(comp_runs)
            .where(
                (comp_runs.c.project_uuid == str(project_id))
                & (comp_runs.c.user_id == str(user_id))
                & (comp_runs.c.iteration == iteration)
            )
            .values(**values)
            .returning(literal_column("*"))
        )
        row: RowProxy = await result.first()
        return CompRunsAtDB.from_orm(row) if row else None
def __iter__(self):
    """Main interface to get the data, returns assets and then policies."""
    organization, folders, tables, policies, group_membership = \
        create_table_names(self.snapshot.cycle_timestamp)

    forseti_org = self.session.query(organization).one()
    yield "organizations", forseti_org

    # Folders
    folder_set = (
        self.session.query(folders)
        .filter(folders.parent_type == 'organization')
        .all())
    while folder_set:
        for folder in folder_set:
            yield 'folders', folder
        folder_set = (
            self.session.query(folders)
            .filter(folders.parent_type == 'folder')
            .filter(folders.parent_id.in_(
                [f.folder_id for f in folder_set]))
            .all())

    for res_type, table in tables:
        for item in self.session.query(table).yield_per(PER_YIELD):
            yield res_type, item

    membership, groups = group_membership
    query_groups = (
        self.session.query(groups)
        .with_entities(literal_column("'GROUP'"), groups.group_email))
    principals = query_groups.distinct()
    for kind, email in principals.yield_per(PER_YIELD):
        yield kind.lower(), email

    query = (
        self.session.query(membership, groups)
        .filter(membership.group_id == groups.group_id)
        .order_by(desc(membership.member_email))
        .distinct())

    cur_member = None
    member_groups = []
    for member, group in query.yield_per(PER_YIELD):
        if cur_member and cur_member.member_email != member.member_email:
            if cur_member:
                yield 'membership', (cur_member, member_groups)
            cur_member = None
            member_groups = []
        cur_member = member
        member_groups.append(group)

    for policy_table in policies:
        for policy in self.session.query(policy_table).all():
            yield 'policy', policy
def build_aggregate_group_by(
        self, table_columns: List[FreeAggregateColumn], base_statement: SQLAlchemyStatement
) -> Tuple[bool, SQLAlchemyStatement]:
    non_group_columns = ArrayHelper(table_columns) \
        .filter(lambda x: x.arithmetic is None or x.arithmetic == FreeAggregateArithmetic.NONE) \
        .to_list()
    if len(non_group_columns) != 0 and len(non_group_columns) != len(table_columns):
        # only when aggregation column exists, group by needs to be appended
        return True, base_statement.group_by(
            *ArrayHelper(non_group_columns).map(lambda x: literal_column(x.name)).to_list())
    else:
        # otherwise return statement itself
        return False, base_statement
def _get_select_children(waypoint):
    """
    Return a WITH query that selects the document ids of the given waypoint,
    the children and the grand-children of the waypoint.
    See also:
    http://docs.sqlalchemy.org/en/latest/core/selectable.html#sqlalchemy.sql.expression.GenerativeSelect.cte  # noqa
    """
    select_waypoint = DBSession.query(
        literal_column(str(waypoint.document_id)).label("document_id"),
        literal_column("1").label("priority")
    ).cte("waypoint")

    # query to get the direct child waypoints
    select_waypoint_children = (
        DBSession.query(
            Association.child_document_id.label("document_id"),
            literal_column("0").label("priority"))
        .filter(
            and_(
                Association.child_document_type == WAYPOINT_TYPE,
                Association.parent_document_id == waypoint.document_id
            )
        )
        .cte("waypoint_children")
    )

    # query to get the grand-child waypoints
    select_waypoint_grandchildren = (
        DBSession.query(
            Association.child_document_id.label("document_id"),
            literal_column("0").label("priority"))
        .select_from(select_waypoint_children)
        .join(
            Association,
            and_(
                Association.parent_document_id ==
                select_waypoint_children.c.document_id,
                Association.child_document_type == WAYPOINT_TYPE,
            ),
        )
        .cte("waypoint_grandchildren")
    )

    return union(
        select_waypoint.select(),
        select_waypoint_children.select(),
        select_waypoint_grandchildren.select()
    ).cte("select_all_waypoints")
async def task(
    db_connection: SAConnection,
    db_notification_queue: asyncio.Queue,
    task_class: NodeClass,
) -> Dict:
    result = await db_connection.execute(
        comp_tasks.insert()
        .values(outputs=json.dumps({}), node_class=task_class)
        .returning(literal_column("*")))
    row: RowProxy = await result.fetchone()
    task = dict(row)

    assert db_notification_queue.empty(), \
        "database triggered change although it should only trigger on updates!"

    yield task
def update_areas_of_changes(document):
    """Update the area ids of all feed entries of the given document.
    """
    areas_select = select(
        [
            # concatenate with empty array to avoid null values
            # select ARRAY[]::integer[] || array_agg(area_id)
            literal_column('ARRAY[]::integer[]').op('||')(
                func.array_agg(
                    AreaAssociation.area_id,
                    type_=postgresql.ARRAY(Integer)))
        ]). \
        where(AreaAssociation.document_id == document.document_id)

    DBSession.execute(
        DocumentChange.__table__.update().
        where(DocumentChange.document_id == document.document_id).
        values(area_ids=areas_select.as_scalar())
    )
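# Standalone sketch of the ARRAY concat trick used above: prefixing with
# ARRAY[]::integer[] via the || operator makes array_agg() over zero rows come back as
# an empty array instead of NULL. Table and column names are illustrative; assumes
# SQLAlchemy 1.x with the PostgreSQL dialect, and the printed SQL is approximate.
from sqlalchemy import Integer, column, func, literal_column, select, table
from sqlalchemy.dialects import postgresql

stmt = select([
    literal_column('ARRAY[]::integer[]').op('||')(
        func.array_agg(column('area_id'), type_=postgresql.ARRAY(Integer)))
]).select_from(table('area_associations')).where(column('document_id') == 123)

print(stmt.compile(dialect=postgresql.dialect()))
# SELECT ARRAY[]::integer[] || array_agg(area_id) AS anon_1
# FROM area_associations WHERE document_id = %(document_id_1)s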
def src_vector_intersects(self) -> bool:
    try:
        logger.debug(
            f"Check if tile {self.tile_id} intersects with postgis table")
        conn = psycopg2.connect(
            dbname=self.src.conn.db_name,
            user=self.src.conn.db_user,
            password=self.src.conn.db_password,
            host=self.src.conn.db_host,
            port=self.src.conn.db_port,
        )
        cursor = conn.cursor()

        sql = (
            select([literal_column("1")])
            .select_from(self.src_table())
            .where(self.intersect_filter())
        )
        # exists_query = select([literal_column("exists")]).select_from(select_1)

        logger.debug(str(sql))
        cursor.execute(str(sql))

        try:
            exists = bool(cursor.fetchone()[0])
        except (ProgrammingError, TypeError):
            exists = False

        cursor.close()
        conn.close()
    except psycopg2.Error:
        logger.exception(
            "There was an issue when trying to connect to the database")
        raise

    logger.debug(f"EXISTS: {exists}")

    if exists:
        logger.info(
            f"Tile id {self.tile_id} exists in database table {self.src.schema}.{self.src.table}"
        )
    else:
        logger.info(
            f"Tile id {self.tile_id} does not exist in database table {self.src.schema}.{self.src.table}"
        )

    return exists
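# Reduced sketch of the existence probe above: `SELECT 1 FROM <table> WHERE <filter>`
# returns a row iff something matches, so bool(cursor.fetchone()) answers the
# "does the tile intersect?" question. The table and filter here are hypothetical
# stand-ins for self.src_table() and self.intersect_filter(); assumes SQLAlchemy 1.x.
from sqlalchemy import column, literal_column, select, table

src_table = table('protected_areas')                # stands in for self.src_table()
intersect_filter = column('tile_id') == '10N_010E'  # stands in for self.intersect_filter()

sql = select([literal_column('1')]).select_from(src_table).where(intersect_filter)
print(sql)
# SELECT 1 FROM protected_areas WHERE tile_id = :tile_id_1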
def update_map(topo_map, reset=False):
    """Create associations for the given map with all intersecting documents.
    If `reset` is True, all possible existing associations to this map are
    dropped before creating new associations.
    """
    if reset:
        DBSession.execute(
            TopoMapAssociation.__table__.delete().where(
                TopoMapAssociation.topo_map_id == topo_map.document_id)
        )

    if topo_map.redirects_to:
        # ignore forwarded maps
        return

    map_geom = select([DocumentGeometry.geom_detail]). \
        where(DocumentGeometry.document_id == topo_map.document_id)
    intersecting_documents = DBSession. \
        query(
            DocumentGeometry.document_id,  # id of a document
            literal_column(str(topo_map.document_id))). \
        join(
            Document,
            and_(
                Document.document_id == DocumentGeometry.document_id,
                Document.type != MAP_TYPE)). \
        filter(Document.redirects_to.is_(None)). \
        filter(
            or_(
                DocumentGeometry.geom.ST_Intersects(
                    map_geom.label('t1')),
                DocumentGeometry.geom_detail.ST_Intersects(
                    map_geom.label('t2'))
            ))

    DBSession.execute(
        TopoMapAssociation.__table__.insert().from_select(
            [TopoMapAssociation.document_id,
             TopoMapAssociation.topo_map_id],
            intersecting_documents))

    # update cache key for now associated docs
    update_cache_version_for_map(topo_map)
def _append_returning(
    self, columns: Union[str, List[str]], query: UpdateBase
) -> Tuple[UpdateBase, bool]:
    column_names: List[str] = _normalize(columns)

    is_scalar: bool = len(column_names) == 1

    if PRIMARY_KEY in column_names:
        # defaults to primary key
        query = query.returning(self._primary_key)
    elif ALL_COLUMNS in column_names:
        query = query.returning(literal_column("*"))
        is_scalar = False
        # NOTE: returning = self._table would also work. less efficient?
    else:
        # selection
        query = query.returning(*[self._table.c[name] for name in column_names])

    return query, is_scalar
async def test_listen_comp_tasks_task(
    mock_project_subsystem: Dict,
    comp_task_listening_task: None,
    client,
    update_values: Dict[str, Any],
    expected_calls: List[str],
    task_class: NodeClass,
):
    db_engine: aiopg.sa.Engine = client.app[APP_DB_ENGINE_KEY]
    async with db_engine.acquire() as conn:
        # let's put some stuff in there now
        result = await conn.execute(
            comp_tasks.insert()
            .values(outputs=json.dumps({}), node_class=task_class)
            .returning(literal_column("*"))
        )
        row: RowProxy = await result.fetchone()
        task = dict(row)

        # let's update some values
        await conn.execute(
            comp_tasks.update()
            .values(**update_values)
            .where(comp_tasks.c.task_id == task["task_id"])
        )

        # tests whether listener gets hooked calls executed
        for call_name, mocked_call in mock_project_subsystem.items():
            if call_name in expected_calls:
                async for attempt in AsyncRetrying(
                    wait=wait_fixed(1),
                    stop=stop_after_delay(10),
                    retry=retry_if_exception_type(AssertionError),
                    before_sleep=before_sleep_log(logger, logging.INFO),
                    reraise=True,
                ):
                    with attempt:
                        mocked_call.assert_awaited()
            else:
                mocked_call.assert_not_called()
def update_area(area, reset=False):
    """Create associations for the given area with all intersecting documents.
    If `reset` is True, all possible existing associations to this area are
    dropped before creating new associations.
    """
    if reset:
        DBSession.execute(
            AreaAssociation.__table__.delete().where(
                AreaAssociation.area_id == area.document_id)
        )

    if area.redirects_to:
        # ignore forwarded areas
        return

    area_geom = select([DocumentGeometry.geom_detail]). \
        where(DocumentGeometry.document_id == area.document_id)
    intersecting_documents = DBSession. \
        query(
            DocumentGeometry.document_id,  # id of a document
            literal_column(str(area.document_id))). \
        join(
            Document,
            and_(
                Document.document_id == DocumentGeometry.document_id,
                Document.type != AREA_TYPE)). \
        filter(Document.redirects_to.is_(None)). \
        filter(
            or_(
                DocumentGeometry.geom.ST_Intersects(
                    area_geom.label('t1')),
                DocumentGeometry.geom_detail.ST_Intersects(
                    area_geom.label('t2'))
            ))

    DBSession.execute(
        AreaAssociation.__table__.insert().from_select(
            [AreaAssociation.document_id,
             AreaAssociation.area_id],
            intersecting_documents))
def update_maps_for_document(document, reset=False):
    """Create associations for the given document with all intersecting maps.
    If `reset` is True, all possible existing associations to this document are
    dropped before creating new associations.
    """
    if reset:
        DBSession.execute(
            TopoMapAssociation.__table__.delete().where(
                TopoMapAssociation.document_id == document.document_id)
        )

    if document.redirects_to:
        # ignore forwarded maps
        return

    document_geom = select([DocumentGeometry.geom]). \
        where(DocumentGeometry.document_id == document.document_id)
    document_geom_detail = select([DocumentGeometry.geom_detail]). \
        where(DocumentGeometry.document_id == document.document_id)
    intersecting_maps = DBSession. \
        query(
            DocumentGeometry.document_id,  # id of a map
            literal_column(str(document.document_id))). \
        join(
            TopoMap,
            TopoMap.document_id == DocumentGeometry.document_id). \
        filter(TopoMap.redirects_to.is_(None)). \
        filter(
            or_(
                DocumentGeometry.geom_detail.ST_Intersects(
                    document_geom.label('t1')),
                DocumentGeometry.geom_detail.ST_Intersects(
                    document_geom_detail.label('t2'))
            ))

    DBSession.execute(
        TopoMapAssociation.__table__.insert().from_select(
            [TopoMapAssociation.topo_map_id,
             TopoMapAssociation.document_id],
            intersecting_maps))
def build_free_aggregate_column(
        self, table_column: FreeAggregateColumn, index: int, prefix_name: str) -> Label:
    name = table_column.name
    alias = f'{prefix_name}_{index + 1}'
    arithmetic = table_column.arithmetic
    if arithmetic == FreeAggregateArithmetic.COUNT:
        return func.count(literal_column(name)).label(alias)
    elif arithmetic == FreeAggregateArithmetic.SUMMARY:
        return func.sum(literal_column(name)).label(alias)
    elif arithmetic == FreeAggregateArithmetic.AVERAGE:
        return func.avg(literal_column(name)).label(alias)
    elif arithmetic == FreeAggregateArithmetic.MAXIMUM:
        return func.max(literal_column(name)).label(alias)
    elif arithmetic == FreeAggregateArithmetic.MINIMUM:
        return func.min(literal_column(name)).label(alias)
    elif arithmetic == FreeAggregateArithmetic.NONE or arithmetic is None:
        return label(alias, literal_column(name))
    else:
        raise UnexpectedStorageException(
            f'Aggregate arithmetic[{arithmetic}] is not supported.')
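# Minimal sketch of what the arithmetic mapping above produces: the raw column name is
# wrapped by the chosen SQL aggregate and re-labelled with the generated alias.
# Names are illustrative; assumes SQLAlchemy 1.x.
from sqlalchemy import func, literal_column, select, table

q = select([
    func.count(literal_column('order_id')).label('column_1'),
    func.sum(literal_column('amount')).label('column_2'),
]).select_from(table('orders'))

print(q)
# SELECT count(order_id) AS column_1, sum(amount) AS column_2 FROM orders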
def translate_straight_column_name(
        self, straight_column: EntityStraightColumn) -> Any:
    if isinstance(straight_column, EntityStraightAggregateColumn):
        if straight_column.arithmetic == EntityColumnAggregateArithmetic.COUNT:
            return func.count(straight_column.columnName) \
                .label(self.get_alias_from_straight_column(straight_column))
        elif straight_column.arithmetic == EntityColumnAggregateArithmetic.SUM:
            return func.sum(straight_column.columnName).label(
                self.get_alias_from_straight_column(straight_column))
        elif straight_column.arithmetic == EntityColumnAggregateArithmetic.AVG:
            return func.avg(straight_column.columnName).label(
                self.get_alias_from_straight_column(straight_column))
        elif straight_column.arithmetic == EntityColumnAggregateArithmetic.MAX:
            return func.max(straight_column.columnName).label(
                self.get_alias_from_straight_column(straight_column))
        elif straight_column.arithmetic == EntityColumnAggregateArithmetic.MIN:
            return func.min(straight_column.columnName).label(
                self.get_alias_from_straight_column(straight_column))
    elif isinstance(straight_column, EntityStraightColumn):
        return literal_column(straight_column.columnName) \
            .label(self.get_alias_from_straight_column(straight_column))

    raise UnsupportedStraightColumnException(
        f'Straight column[{straight_column.to_dict()}] is not supported.')
def _add_aggregation(
    cls,
    inner_query: Select,
    aggregation_columns: List[ColumnClause],
    group_by_columns: List[ColumnClause],
    grouping_sets: Optional[GroupingSets] = None,
) -> Select:
    """
    Aggregates raw metric taxons. Groups by given dimension taxons or grouping sets.

    :param inner_query: Query to aggregate
    :param aggregation_columns: List of columns with applied aggregation function
    :param group_by_columns: List of columns to group by
    :param grouping_sets: Optional list of grouping sets to group by instead

    :return: Aggregated query
    """
    if grouping_sets:
        # Because we union _PANORAMIC_GROUPINGSETS_NULL with columns that can be date(time) or
        # number, we must cast all group columns to text. Some DB engines fail when we do casting
        # and grouping in one query, thus we stringify the group columns in the CTE, and not in
        # the group by query just below...
        group_by_column_names = {col.name for col in group_by_columns}
        stringified_group_columns = []
        for col in inner_query.columns:
            if col.name in group_by_column_names:
                stringified_group_columns.append(
                    cast(col, sqlalchemy.VARCHAR).label(col.name))
            else:
                stringified_group_columns.append(col)

        # common table expression reused by multiple grouping sets queries
        cte_query = (
            Select(columns=sort_columns(stringified_group_columns))
            .select_from(inner_query)
            .cte('__cte_grouping_sets'))

        grouping_sets_queries = []
        for grouping_set in grouping_sets:
            safe_grouping_set = [safe_identifier(col) for col in grouping_set]

            # dimensions in the grouping set, used to aggregate values with group by
            gs_group_columns = [
                col for col in group_by_columns
                if col.name in safe_grouping_set
            ]
            # extra dimensions not in the grouping set, returned as custom null values
            gs_null_columns = [
                literal_column(f"'{_PANORAMIC_GROUPINGSETS_NULL}'").label(col.name)
                for col in group_by_columns
                if col.name not in safe_grouping_set
            ]

            grouping_sets_queries.append(
                Select(columns=sort_columns(
                    gs_group_columns + gs_null_columns + aggregation_columns))
                .select_from(cte_query)
                .group_by(*sort_columns(gs_group_columns)))

        return union_all(*grouping_sets_queries)

    # If grouping sets are not defined, use all dimensions for grouping.
    return (
        Select(columns=sort_columns(group_by_columns + aggregation_columns))
        .select_from(inner_query)
        .group_by(*sort_columns(group_by_columns)))
def blend_dataframes(
    ctx: HuskyQueryContext,
    dataframes: List[Dataframe],
    data_source_formula_templates: Optional[Dict[str, List[SqlFormulaTemplate]]] = None,
) -> Dataframe:
    """
    Produces new blended dataframe from all the given dataframes joined on all
    dimensions that appear at least twice in different dataframes.
    """
    slug_to_dataframes: Dict[TaxonExpressionStr, List[Dataframe]] = _prepare_slug_to_dataframes(dataframes)

    dataframe_to_query: Dict[Dataframe, Selectable] = dict()
    used_model_names: Set[str] = set()
    used_physical_sources: Set[str] = set()
    for idx, df in enumerate(dataframes):
        # Create query for each dataframe, that has alias as 'q<number>'
        dataframe_to_query[df] = df.query.alias(f'q{idx}')
        used_model_names.update(df.used_model_names)
        used_physical_sources.update(df.used_physical_data_sources)

    selectors: List[TextClause] = []
    dimension_columns: List[ColumnClause] = []
    # Prepare list of sql selectors. If it is a metric, do zeroifnull(q0.metric + q1.metric + ...)
    # If it is a dimension, just select it. Because we are using USING clause, no need for coalesce.
    for taxon_slug in sorted(slug_to_dataframes.keys()):
        dataframes_with_slug = slug_to_dataframes[taxon_slug]
        taxon = dataframes_with_slug[0].slug_to_column[taxon_slug].taxon
        taxon_column = quote_identifier(taxon.slug_safe_sql_identifier, ctx.dialect)
        query_aliases = [dataframe_to_query[df].name for df in dataframes_with_slug]
        if taxon.is_dimension:
            if len(query_aliases) > 1:
                # Coalesce must have two or more args
                dimension_coalesce = functions.coalesce(
                    *[literal_column(f'{query_alias}.{taxon_column}') for query_alias in query_aliases]
                )
            else:
                # No need to coalesce now
                dimension_coalesce = literal_column(f'{query_aliases[0]}.{taxon_column}')
            col = dimension_coalesce.label(taxon.slug_safe_sql_identifier)
            dimension_columns.append(col)
            selectors.append(col)
        else:
            if taxon.data_source:
                # do not use coalesce aka zeroifnull when summing namespaced taxons..
                # They are using TEL expr, where null is handled by TEL compilation.
                summed = '+'.join([f'{query_alias}.{taxon_column}' for query_alias in query_aliases])
            else:
                summed = '+'.join([f'coalesce({query_alias}.{taxon_column},0)' for query_alias in query_aliases])
            selectors.append(text(f'sum({summed}) as {taxon_column}'))

    final_columns: List[ColumnClause] = []
    if data_source_formula_templates:
        for pre_formulas in data_source_formula_templates.values():
            for pre_formula in pre_formulas:
                col = column(pre_formula.label)
                dimension_columns.append(col)
                selectors.append(col)
                final_columns.append(column(quote_identifier(pre_formula.label, ctx.dialect)))

    # All taxons in final DF
    final_slug_to_taxon: Dict[TaxonExpressionStr, DataframeColumn] = dataframes[0].slug_to_column.copy()

    # Because of sql alchemy compiler putting extra () around every using select_from, we first join all queries
    # and then define the aggregation selectors (right after this for loop)
    join_query = dataframe_to_query[dataframes[0]]
    for i in range(1, len(dataframes)):
        # Iterate dataframes, and do full outer join on FALSE, effectively meaning UNION-ALL without the need to
        # align all columns
        dataframe_to_join = dataframes[i]
        used_physical_sources.update(dataframe_to_join.used_physical_data_sources)
        final_slug_to_taxon = {**final_slug_to_taxon, **dataframe_to_join.slug_to_column}
        join_from = join_query
        join_to = dataframe_to_query[dataframe_to_join]
        # On purpose joining on value that will always return FALSE => PROD-8136
        join_query = join_from.join(
            join_to,
            dataframe_to_query[dataframes[0]].columns[HUSKY_QUERY_DATA_SOURCE_COLUMN_NAME]
            == join_to.columns[HUSKY_QUERY_DATA_SOURCE_COLUMN_NAME],
            full=True,
        )

    aggregate_join_query = select(selectors).select_from(join_query)
    for dimension_column in dimension_columns:
        aggregate_join_query = aggregate_join_query.group_by(dimension_column)

    # We have to wrap it in one more select, so the alchemy query object has columns referencable via 'c' attribute.
    final_columns.extend(column(id_) for id_ in safe_identifiers_iterable(final_slug_to_taxon.keys()))
    query = select(sort_columns(final_columns)).select_from(aggregate_join_query)

    return Dataframe(query, final_slug_to_taxon, used_model_names, used_physical_sources)
def _get_column_accessor_for_taxon_and_model(
        self, model: HuskyModel, taxon_slug_expression: TaxonSlugExpression) -> ColumnClause:
    return literal_column(
        model.taxon_sql_accessor(self.ctx, taxon_slug_expression.slug))
def _build_query_window_aggregations(
    self,
    taxon_to_model: Dict[TaxonSlugExpression, HuskyModel],
    ordered_query_joins: Sequence[QueryJoins],
) -> Select:
    """
    Generates query for taxons which need window functions for aggregation

    :param taxon_to_model: Map of taxon slugs (key) and models they are coming from (value)
    :param ordered_query_joins: List of joins
    """
    selectors = []

    # generate inner query with window aggregation functions
    for taxon_slug_expression, taxon in sorted(
            self.projection_taxons.items(), key=lambda x: str(x[0])):
        model = taxon_to_model[taxon_slug_expression]
        if (taxon.tel_metadata
                and taxon.tel_metadata.aggregation_definition
                and taxon.tel_metadata.aggregation_definition.params
                and taxon.tel_metadata_aggregation_type
                in self._AGGREGATION_WINDOW_FUNCTIONS):
            # find the order_by columns
            order_by = []
            window_params = cast(
                AggregationParamsSortDimension,
                taxon.tel_metadata.aggregation_definition.params)
            for field in window_params.sort_dimensions:
                col = taxon_to_model[TaxonSlugExpression(field.taxon)].taxon_sql_accessor(
                    self.ctx, field.taxon)

                order_by_dir = field.order_by or TaxonOrderType.asc
                order_by.append(
                    nullslast(ORDER_BY_FUNCTIONS[order_by_dir](literal_column(col))))

            # apply window aggregation functions
            column = self._AGGREGATION_WINDOW_FUNCTIONS[taxon.tel_metadata_aggregation_type](
                literal_column(model.taxon_sql_accessor(self.ctx, taxon.slug))
            ).over(
                partition_by=self.get_partition_by_columns(model),
                order_by=order_by)
        else:
            # otherwise, render the columns "as-is"
            column = literal_column(
                model.taxon_sql_accessor(self.ctx, taxon.slug))

        selectors.append(column.label(taxon.slug_safe_sql_identifier))

    # add joins to the inner query
    inner_query = select(selectors).select_from(
        self._build_from_joins(ordered_query_joins))

    # apply scope filters to the inner query
    inner_query = ScopeGuard.add_scope_row_filters(
        self.ctx, self.scope, inner_query, self.taxon_model_info_map)

    # update taxon model info map, because we're selecting from outer query and not the inner query
    self._rebuild_taxon_info_map_inner_query()

    # then, we prepare the outer query on which we can safely apply GROUP BY
    return self._build_selectors(
        lambda _, taxon_slug: safe_identifier(taxon_slug)).select_from(inner_query)
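# Reduced sketch of the window-aggregation branch above: instead of a GROUP BY, the
# aggregate is applied with .over(partition_by=..., order_by=...), and NULLs are pushed
# to the end of each partition with nullslast(). Column names are illustrative;
# assumes SQLAlchemy 1.x.
from sqlalchemy import func, literal_column, nullslast, select, table

running_total = func.sum(literal_column('revenue')).over(
    partition_by=[literal_column('account_id')],
    order_by=[nullslast(literal_column('created_at').asc())],
).label('revenue_running_total')

print(select([running_total]).select_from(table('facts')))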