示例#1
0
    def _build_selectors(
        self,
        get_column_name: Callable[[TaxonSlugExpression, str], str],
    ) -> Select:
        """
        Build the SELECT part of the query.

        Every projection taxon becomes a literal column labelled with the
        taxon's SQL-safe identifier. Every dimension template is rendered with
        the real SQL column names of the taxons it uses and labelled with the
        template's label.

        :param get_column_name: maps a taxon slug expression and its slug to
            the SQL column name that should be selected.
        :return: a ``select`` over all collected columns, passed through
            ``sort_columns``.
        """
        projection_columns = [
            literal_column(
                get_column_name(slug_expr, slug_expr.slug)
            ).label(taxon.slug_safe_sql_identifier)
            for slug_expr, taxon in self.projection_taxons.items()
        ]

        template_columns = []
        for dimension_template in self.dimension_templates:
            # Templates reference taxons by slug; substitute the actual SQL columns.
            column_by_slug = {}
            for used_slug in dimension_template.used_taxons:
                column_by_slug[used_slug] = get_column_name(
                    TaxonSlugExpression(used_slug), used_slug)
            rendered_formula = dimension_template.render_formula(**column_by_slug)
            template_columns.append(
                literal_column(rendered_formula).label(dimension_template.label))

        return select(sort_columns(projection_columns + template_columns))
示例#2
0
def _get_load_associations_query(document, doc_types_to_load):
    """
    Build a query over the documents associated with ``document``.

    Parents and children whose type is in ``doc_types_to_load`` are merged
    with a UNION. Each result row has three columns:

    * ``id`` -- document id of the associated document,
    * ``t`` -- its document type,
    * ``p`` -- ``1`` if it is a parent of ``document``, ``0`` if a child.
    """
    query_parents = DBSession. \
        query(
            Association.parent_document_id.label('id'),
            Document.type.label('t'),
            literal_column('1').label('p')). \
        join(
            Document,
            and_(
                Association.child_document_id == document.document_id,
                Association.parent_document_id == Document.document_id,
                Document.type.in_(doc_types_to_load))). \
        subquery()
    query_children = DBSession. \
        query(
            Association.child_document_id.label('id'),
            Document.type.label('t'),
            literal_column('0').label('p')). \
        join(
            Document,
            and_(
                Association.child_document_id == Document.document_id,
                Association.parent_document_id == document.document_id,
                Document.type.in_(doc_types_to_load))). \
        subquery()

    # Declare the outer columns as explicit SQL expressions: plain strings
    # passed to ``Query`` are deprecated (and rejected by SQLAlchemy 1.4+).
    return DBSession \
        .query(
            literal_column('id'),
            literal_column('t'),
            literal_column('p')) \
        .select_from(union(query_parents.select(), query_children.select()))
示例#3
0
def _get_load_associations_query(document, doc_types_to_load):
    """
    Return a query over the documents associated with ``document``.

    The result has the columns ``id`` (associated document id), ``t`` (its
    document type, read from the association row) and ``p`` (``1`` for a
    parent, ``0`` for a child). Only associations whose other end has a type
    in ``doc_types_to_load`` are kept; parents and children are merged with a
    UNION.
    """
    parents = (
        DBSession.query(
            Association.parent_document_id.label('id'),
            Association.parent_document_type.label('t'),
            literal_column('1').label('p'),
        )
        .filter(and_(
            Association.child_document_id == document.document_id,
            Association.parent_document_type.in_(doc_types_to_load),
        ))
        .subquery()
    )
    children = (
        DBSession.query(
            Association.child_document_id.label('id'),
            Association.child_document_type.label('t'),
            literal_column('0').label('p'),
        )
        .filter(and_(
            Association.parent_document_id == document.document_id,
            Association.child_document_type.in_(doc_types_to_load),
        ))
        .subquery()
    )

    outer = DBSession.query(column('id'), column('t'), column('p'))
    return outer.select_from(union(parents.select(), children.select()))
示例#4
0
 def get_partition_by_columns(self, model: HuskyModel):
     """
     Return the PARTITION BY columns for ``model``.

     One literal column is produced per model attribute whose ``identifier``
     is ``True``. Its SQL text comes from ``model.taxon_sql_accessor`` for
     that attribute's taxon.
     """
     identifier_attributes = (
         attribute
         for attribute in model.attributes_memoized.values()
         if attribute.identifier is True
     )
     return [
         literal_column(model.taxon_sql_accessor(self.ctx, attribute.taxon))
         for attribute in identifier_attributes
     ]
示例#5
0
    async def creator(
        gid: GroupID,
        cluster_access_rights: Dict[GroupID, ClusterAccessRights] = None
    ) -> Cluster:
        """
        Insert a random cluster owned by ``gid`` and return it as a ``Cluster``.

        The cluster row is inserted first and its id is appended to
        ``list_of_created_cluster_ids``. Then one ``cluster_to_groups`` row is
        inserted per entry of ``cluster_access_rights``; rows that conflict
        with existing ones are skipped (ON CONFLICT DO NOTHING).
        """
        new_cluster = ClusterCreate(
            name=faker.name(),
            type=random.choice(list(ClusterType)),
            owner=gid,
            access_rights=cluster_access_rights or {},
        )

        insert_cluster = clusters.insert().values(
            new_cluster.dict(by_alias=True, exclude={"id", "access_rights"})
        ).returning(literal_column("*"))
        cluster_in_db = postgres_db.execute(insert_cluster).first()
        assert cluster_in_db is not None
        new_cluster_id = cluster_in_db[clusters.c.id]
        list_of_created_cluster_ids.append(new_cluster_id)

        # The DB already creates the owner's access rights when a cluster is
        # inserted, hence ON CONFLICT DO NOTHING below.
        for group_id, rights in new_cluster.access_rights.items():
            postgres_db.execute(
                insert(cluster_to_groups)
                .values(
                    cluster_id=new_cluster_id,
                    gid=group_id,
                    read=rights.read,
                    write=rights.write,
                    delete=rights.delete,
                )
                .on_conflict_do_nothing()
            )

        return Cluster(
            id=new_cluster_id,
            **new_cluster.dict(by_alias=True, exclude={"id"}),
        )
示例#6
0
async def test_listen_comp_tasks_task(
    mock_project_subsystem: Dict,
    comp_task_listening_task: asyncio.Task,
    client,
    update_values: Dict[str, Any],
    expected_calls: List[str],
    task_class: NodeClass,
):
    """
    Updating a ``comp_tasks`` row must trigger exactly the expected calls.

    A task row is inserted and then updated with ``update_values``. Every
    mock named in ``expected_calls`` must end up awaited (the assertion is
    retried through ``_async_retry_if_fails``). Every other mock must never
    be called.
    """
    db_engine: aiopg.sa.Engine = client.app[APP_DB_ENGINE_KEY]
    async with db_engine.acquire() as conn:
        insert_stmt = (
            comp_tasks.insert()
            .values(outputs=json.dumps({}), node_class=task_class)
            .returning(literal_column("*"))
        )
        inserted = await conn.execute(insert_stmt)
        row: RowProxy = await inserted.fetchone()
        task = dict(row)

        update_stmt = (
            comp_tasks.update()
            .values(**update_values)
            .where(comp_tasks.c.task_id == task["task_id"])
        )
        await conn.execute(update_stmt)

        for call_name, mocked_call in mock_project_subsystem.items():
            if call_name not in expected_calls:
                mocked_call.assert_not_called()
                continue
            # the listener reacts asynchronously, so retry the assertion
            async for attempt in _async_retry_if_fails():
                with attempt:
                    mocked_call.assert_awaited()
async def test_listen_comp_tasks_task(
    mock_project_subsystem: Dict,
    comp_task_listening_task: asyncio.Task,
    client,
    upd_value: Dict[str, Any],
    exp_calls: List[str],
    task_class: NodeClass,
):
    """
    Updating a ``comp_tasks`` row must trigger exactly the expected calls.

    After inserting a task and updating it with ``upd_value``, every mock
    listed in ``exp_calls`` is waited for with ``_wait_for_call``. Every
    other mock must not have been called.
    """
    db_engine: aiopg.sa.Engine = client.app[APP_DB_ENGINE_KEY]
    async with db_engine.acquire() as conn:
        inserted = await conn.execute(
            comp_tasks.insert()
            .values(outputs=json.dumps({}), node_class=task_class)
            .returning(literal_column("*"))
        )
        row: RowProxy = await inserted.fetchone()
        task = dict(row)

        await conn.execute(
            comp_tasks.update()
            .values(**upd_value)
            .where(comp_tasks.c.task_id == task["task_id"])
        )

        for name, mocked in mock_project_subsystem.items():
            if name not in exp_calls:
                mocked.assert_not_called()
            else:
                await _wait_for_call(mocked)
示例#8
0
    async def create(
        self,
        user_id: UserID,
        project_id: ProjectID,
        cluster_id: ClusterID,
        default_cluster_id: ClusterID,
        iteration: Optional[PositiveInt] = None,
    ) -> CompRunsAtDB:
        """
        Insert a new ``comp_runs`` row in PUBLISHED state and return it.

        :param user_id: user owning the run.
        :param project_id: project of the run (stored as its string form).
        :param cluster_id: cluster of the run; stored as NULL when it equals
            ``default_cluster_id``.
        :param default_cluster_id: the cluster id that is stored as NULL.
        :param iteration: explicit iteration number. When ``None``, the
            highest existing iteration for this user/project plus one is
            used, so the first run gets iteration 1.
        :return: the inserted row as ``CompRunsAtDB``.
        """
        async with self.db_engine.acquire() as conn:
            if iteration is None:
                # let's get the latest if it exists
                last_iteration = await conn.scalar(
                    sa.select([comp_runs.c.iteration])
                    .where(
                        (comp_runs.c.user_id == user_id)
                        & (comp_runs.c.project_uuid == str(project_id))
                    )
                    .order_by(desc(comp_runs.c.iteration))
                )
                # no previous run -> start counting at 1
                iteration = (last_iteration or 0) + 1

            result = await conn.execute(
                comp_runs.insert()  # pylint: disable=no-value-for-parameter
                .values(
                    user_id=user_id,
                    project_uuid=f"{project_id}",
                    cluster_id=(
                        cluster_id if cluster_id != default_cluster_id else None
                    ),
                    iteration=iteration,
                    result=RUNNING_STATE_TO_DB[RunningState.PUBLISHED],
                    started=datetime.utcnow(),
                )
                .returning(literal_column("*"))
            )
            row = await result.first()
            return CompRunsAtDB.from_orm(row)
示例#9
0
def select_vc_time_per_user(type_id: int, lit_values: list) -> Select:
    """
    Build a per-user query over voice chat events of ``type_id``.

    The ``value`` column is ``sum(updated_at - created_at)``, with both
    timestamps converted by ``date_to_secs``. ``lit_values`` is a list of
    ``(label, value)`` pairs; each becomes an extra constant column labelled
    ``label``.

    NOTE: ``str(value)`` is inlined verbatim into the SQL, so the values
    must come from a trusted source.
    """
    joined_at = date_to_secs(VoiceChatEvent.created_at)
    left_at = date_to_secs(VoiceChatEvent.updated_at)
    total_time = func.sum(left_at - joined_at).label('value')
    constant_columns = [
        literal_column(str(value)).label(label) for label, value in lit_values
    ]
    query = select([total_time, VoiceChatEvent.user_id, *constant_columns])
    return query.where(VoiceChatEvent.type_id == type_id).group_by(
        VoiceChatEvent.user_id)
示例#10
0
async def create_user(db, data):
    """
    Insert a ``user`` row built from ``data``.

    Returns the new row's ``id`` and ``email`` as a dict.
    """
    async with db() as conn:
        statement = (
            user.insert()
            .returning(literal_column('id, email'))
            .values(**data)
        )
        cursor = await conn.execute(statement)
        inserted_row = await cursor.fetchone()
        return dict(inserted_row)
示例#11
0
def select_membership_time_per_user(type_id: int, lit_values: list) -> Select:
    """
    Build a per-user query of days since the latest member event of ``type_id``.

    ``value`` is ``(now - date_to_secs(max(created_at))) / 86400`` cast to
    Integer. ``now`` is the current Unix timestamp, taken when the query is
    built. ``lit_values`` is a list of ``(label, value)`` pairs appended as
    constant columns (``str(value)`` is inlined verbatim into the SQL).
    """
    latest_event_secs = date_to_secs(func.max(MemberEvent.created_at))
    now_secs = int(datetime.now().timestamp())
    days_as_member = cast((now_secs - latest_event_secs) / 86400,
                          Integer).label('value')
    constant_columns = [
        literal_column(str(value)).label(label) for label, value in lit_values
    ]
    query = select([days_as_member, MemberEvent.user_id, *constant_columns])
    # NOTE(review): ``User`` appears only in this WHERE clause, with no join
    # condition to ``MemberEvent``. This presumably yields a cross join, so the
    # ``roles`` filter does not restrict rows per member -- confirm intent.
    return query.where(
        and_(MemberEvent.type_id == type_id,
             User.roles != None)  # noqa: E711 (SQL "IS NOT NULL")
    ).group_by(MemberEvent.user_id)
示例#12
0
文件: script.py 项目: karimould/til-1
def get_actors(engine):
    """
    Return the first ``actor_id`` row of ``actor`` for JOHNNY LOLLOBRIGIDA.

    The result is ``None`` when nothing matches. The session is closed even
    if the query fails.
    """
    query = (
        select([literal_column('actor_id')])
        .select_from(table('actor'))
        .where(and_(
            literal_column('first_name') == 'JOHNNY',
            literal_column('last_name') == 'LOLLOBRIGIDA',
        ))
    )
    session = get_session(engine=engine)
    try:
        return session.execute(query).first()
    finally:
        print('do finally')
        session.close()
示例#13
0
def get_neighbour_version_ids(version_id, document_id, lang):
    """
    Get the previous and next version for a version of a document with a
    specific language.

    :param version_id: id of the reference ``DocumentVersion``.
    :param document_id: document the versions belong to.
    :param lang: language the versions must have.
    :return: ``(previous_version_id, next_version_id)``. "Previous" is the
        highest id below ``version_id`` and "next" the lowest id above it;
        each is ``None`` when no such version exists.
    """
    next_version = DBSession \
        .query(
            DocumentVersion.id.label('id'),
            literal_column('1').label('t')) \
        .filter(DocumentVersion.id > version_id) \
        .filter(DocumentVersion.document_id == document_id) \
        .filter(DocumentVersion.lang == lang) \
        .order_by(DocumentVersion.id) \
        .limit(1) \
        .subquery()

    previous_version = DBSession \
        .query(
            DocumentVersion.id.label('id'),
            literal_column('-1').label('t')) \
        .filter(DocumentVersion.id < version_id) \
        .filter(DocumentVersion.document_id == document_id) \
        .filter(DocumentVersion.lang == lang) \
        .order_by(DocumentVersion.id.desc()) \
        .limit(1) \
        .subquery()

    # Declare the outer columns as explicit SQL expressions: plain strings
    # passed to ``Query`` are deprecated (and rejected by SQLAlchemy 1.4+).
    query = DBSession \
        .query(literal_column('id'), literal_column('t')) \
        .select_from(union(
            next_version.select(), previous_version.select()))

    previous_version_id = None
    next_version_id = None
    # ``t`` tags each row: -1 for the previous version, 1 for the next one
    for version, typ in query:
        if typ == -1:
            previous_version_id = version
        else:
            next_version_id = version

    return previous_version_id, next_version_id
示例#14
0
def _get_select_children(waypoint):
    """
    Return a WITH query that selects the document ids of the given waypoint,
    the children and the grand-children of the waypoint.

    Each row has ``document_id`` and ``priority``: 1 for the waypoint itself
    and 0 for child and grand-child waypoints. The three parts are combined
    with UNION into the CTE ``select_all_waypoints``.
    See also: http://docs.sqlalchemy.org/en/latest/core/selectable.html#sqlalchemy.sql.expression.GenerativeSelect.cte  # noqa
    """
    # the waypoint itself (priority 1)
    itself = DBSession.query(
        literal_column(str(waypoint.document_id)).label('document_id'),
        literal_column('1').label('priority'),
    ).cte('waypoint')

    # waypoints associated as direct children of the waypoint
    children = (
        DBSession.query(
            Waypoint.document_id,
            literal_column('0').label('priority'),
        )
        .join(
            Association,
            and_(
                Association.child_document_id == Waypoint.document_id,
                Association.parent_document_id == waypoint.document_id,
            ),
        )
        .cte('waypoint_children')
    )

    # waypoints associated as children of those children
    grandchildren = (
        DBSession.query(
            Waypoint.document_id,
            literal_column('0').label('priority'),
        )
        .select_from(children)
        .join(
            Association,
            Association.parent_document_id == children.c.document_id,
        )
        .join(
            Waypoint,
            Association.child_document_id == Waypoint.document_id,
        )
        .cte('waypoint_grandchildren')
    )

    combined = union(itself.select(), children.select(), grandchildren.select())
    return combined.cte('select_all_waypoints')
示例#15
0
def get_neighbour_version_ids(version_id, document_id, lang):
    """
    Get the previous and next version for a version of a document with a
    specific language.

    :param version_id: id of the reference ``DocumentVersion``.
    :param document_id: document the versions belong to.
    :param lang: language the versions must have.
    :return: ``(previous_version_id, next_version_id)``. "Previous" is the
        highest id below ``version_id`` and "next" the lowest id above it;
        each is ``None`` when no such version exists.
    """
    next_version = DBSession \
        .query(
            DocumentVersion.id.label('id'),
            literal_column('1').label('t')) \
        .filter(DocumentVersion.id > version_id) \
        .filter(DocumentVersion.document_id == document_id) \
        .filter(DocumentVersion.lang == lang) \
        .order_by(DocumentVersion.id) \
        .limit(1) \
        .subquery()

    previous_version = DBSession \
        .query(
            DocumentVersion.id.label('id'),
            literal_column('-1').label('t')) \
        .filter(DocumentVersion.id < version_id) \
        .filter(DocumentVersion.document_id == document_id) \
        .filter(DocumentVersion.lang == lang) \
        .order_by(DocumentVersion.id.desc()) \
        .limit(1) \
        .subquery()

    # Declare the outer columns as explicit SQL expressions: plain strings
    # passed to ``Query`` are deprecated (and rejected by SQLAlchemy 1.4+).
    query = DBSession \
        .query(literal_column('id'), literal_column('t')) \
        .select_from(union(
            next_version.select(), previous_version.select()))

    previous_version_id = None
    next_version_id = None
    # ``t`` tags each row: -1 for the previous version, 1 for the next one
    for version, typ in query:
        if typ == -1:
            previous_version_id = version
        else:
            next_version_id = version

    return previous_version_id, next_version_id
示例#16
0
def _get_select_children(waypoint):
    """
    Return a WITH query that selects the document ids of the given waypoint,
    the children and the grand-children of the waypoint.

    Each row has ``document_id`` and ``priority``: 1 for the waypoint itself
    and 0 for child and grand-child waypoints. The three parts are combined
    with UNION into the CTE ``select_all_waypoints``.
    See also: http://docs.sqlalchemy.org/en/latest/core/selectable.html#sqlalchemy.sql.expression.GenerativeSelect.cte  # noqa
    """
    # the waypoint itself (priority 1)
    itself = DBSession.query(
        literal_column(str(waypoint.document_id)).label('document_id'),
        literal_column('1').label('priority'),
    ).cte('waypoint')

    # waypoints associated as direct children of the waypoint
    children = (
        DBSession.query(
            Waypoint.document_id,
            literal_column('0').label('priority'),
        )
        .join(
            Association,
            and_(
                Association.child_document_id == Waypoint.document_id,
                Association.parent_document_id == waypoint.document_id,
            ),
        )
        .cte('waypoint_children')
    )

    # waypoints associated as children of those children
    grandchildren = (
        DBSession.query(
            Waypoint.document_id,
            literal_column('0').label('priority'),
        )
        .select_from(children)
        .join(
            Association,
            Association.parent_document_id == children.c.document_id,
        )
        .join(
            Waypoint,
            Association.child_document_id == Waypoint.document_id,
        )
        .cte('waypoint_grandchildren')
    )

    combined = union(itself.select(), children.select(), grandchildren.select())
    return combined.cte('select_all_waypoints')
示例#17
0
 async def update(self, user_id: UserID, project_id: ProjectID,
                  iteration: PositiveInt,
                  **values) -> Optional[CompRunsAtDB]:
     """
     Update the ``comp_runs`` row identified by project, user and iteration.

     ``values`` are the column values to set. Returns the updated row as
     ``CompRunsAtDB``, or ``None`` when no row matched.
     """
     async with self.db_engine.acquire() as conn:
         statement = (
             sa.update(comp_runs)
             .where(
                 (comp_runs.c.project_uuid == str(project_id))
                 & (comp_runs.c.user_id == str(user_id))
                 & (comp_runs.c.iteration == iteration)
             )
             .values(**values)
             .returning(literal_column("*"))
         )
         result = await conn.execute(statement)
         updated: RowProxy = await result.first()
         if not updated:
             return None
         return CompRunsAtDB.from_orm(updated)
示例#18
0
    def __iter__(self):
        """Main interface to get the data, returns assets and then policies.

        Yields ``(kind, item)`` pairs in this order:

        * ``('organizations', row)`` for the organization,
        * ``('folders', row)`` level by level, starting with the folders
          whose parent is the organization,
        * ``(res_type, row)`` for every row of every resource table,
        * ``('group', email)`` for every distinct group email,
        * ``('membership', (member, groups))`` once per member email, with
          every group joined to that member,
        * ``('policy', row)`` for every row of every policy table.
        """
        organization, folders, tables, policies, group_membership = \
            create_table_names(self.snapshot.cycle_timestamp)

        forseti_org = self.session.query(organization).one()
        yield "organizations", forseti_org

        # Folders: one query per hierarchy level, children of the previous level
        folder_set = (self.session.query(folders).filter(
            folders.parent_type == 'organization').all())

        while folder_set:
            for folder in folder_set:
                yield 'folders', folder

            folder_set = (self.session.query(folders).filter(
                folders.parent_type == 'folder').filter(
                    folders.parent_id.in_([f.folder_id
                                           for f in folder_set])).all())

        for res_type, table in tables:
            for item in self.session.query(table).yield_per(PER_YIELD):
                yield res_type, item

        membership, groups = group_membership
        # constant 'GROUP' column tags every row; lower-cased when yielded
        query_groups = (self.session.query(groups).with_entities(
            literal_column("'GROUP'"), groups.group_email))
        principals = query_groups.distinct()
        for kind, email in principals.yield_per(PER_YIELD):
            yield kind.lower(), email

        query = (self.session.query(
            membership,
            groups).filter(membership.group_id == groups.group_id).order_by(
                desc(membership.member_email)).distinct())

        # Rows are ordered by member email, so each member's groups are
        # contiguous: accumulate them and emit when the email changes.
        cur_member = None
        member_groups = []
        for member, group in query.yield_per(PER_YIELD):
            if cur_member and cur_member.member_email != member.member_email:
                yield 'membership', (cur_member, member_groups)
                member_groups = []

            cur_member = member
            member_groups.append(group)

        # Emit the last member: no following row triggers its flush above.
        if cur_member:
            yield 'membership', (cur_member, member_groups)

        for policy_table in policies:
            for policy in self.session.query(policy_table).all():
                yield 'policy', policy
示例#19
0
	def build_aggregate_group_by(
			self, table_columns: List[FreeAggregateColumn], base_statement: SQLAlchemyStatement
	) -> Tuple[bool, SQLAlchemyStatement]:
		"""
		Append a GROUP BY clause when plain and aggregated columns are mixed.

		A column is "plain" when its arithmetic is ``None`` or
		``FreeAggregateArithmetic.NONE``. When some, but not all, columns are
		plain, the statement is grouped by the plain columns' names.

		:return: ``(True, grouped statement)`` when GROUP BY was added,
			otherwise ``(False, base_statement)`` unchanged.
		"""
		plain_columns = ArrayHelper(table_columns) \
			.filter(lambda column: column.arithmetic is None or column.arithmetic == FreeAggregateArithmetic.NONE) \
			.to_list()
		plain_count = len(plain_columns)
		if plain_count == 0 or plain_count == len(table_columns):
			# no plain column at all, or no aggregation: nothing to group by
			return False, base_statement
		group_columns = ArrayHelper(plain_columns).map(lambda column: literal_column(column.name)).to_list()
		return True, base_statement.group_by(*group_columns)
示例#20
0
文件: waypoint.py 项目: c2corg/v6_api
def _get_select_children(waypoint):
    """
    Return a WITH query that selects the document ids of the given waypoint,
    the children and the grand-children of the waypoint.

    Rows have ``document_id`` and ``priority`` (1 for the waypoint itself, 0
    for children and grand-children). Only associations whose child document
    type is ``WAYPOINT_TYPE`` are followed. The parts are merged with UNION
    into the CTE ``select_all_waypoints``.
    See also: http://docs.sqlalchemy.org/en/latest/core/selectable.html#sqlalchemy.sql.expression.GenerativeSelect.cte  # noqa
    """
    self_cte = DBSession.query(
        literal_column(str(waypoint.document_id)).label("document_id"),
        literal_column("1").label("priority"),
    ).cte("waypoint")

    # direct children: waypoint-typed children of the waypoint
    children_cte = (
        DBSession.query(
            Association.child_document_id.label("document_id"),
            literal_column("0").label("priority"),
        )
        .filter(
            and_(
                Association.child_document_type == WAYPOINT_TYPE,
                Association.parent_document_id == waypoint.document_id,
            )
        )
        .cte("waypoint_children")
    )

    # grand-children: waypoint-typed children of the direct children
    grandchildren_cte = (
        DBSession.query(
            Association.child_document_id.label("document_id"),
            literal_column("0").label("priority"),
        )
        .select_from(children_cte)
        .join(
            Association,
            and_(
                Association.parent_document_id == children_cte.c.document_id,
                Association.child_document_type == WAYPOINT_TYPE,
            ),
        )
        .cte("waypoint_grandchildren")
    )

    parts = [self_cte.select(), children_cte.select(), grandchildren_cte.select()]
    return union(*parts).cte("select_all_waypoints")
示例#21
0
async def task(
    db_connection: SAConnection,
    db_notification_queue: asyncio.Queue,
    task_class: NodeClass,
) -> Dict:
    """
    Insert a ``comp_tasks`` row of class ``task_class`` and yield it as a dict.

    The row has empty JSON outputs. The insert must not put anything on
    ``db_notification_queue``, because the trigger should fire on updates only.
    """
    insert_stmt = (
        comp_tasks.insert()
        .values(outputs=json.dumps({}), node_class=task_class)
        .returning(literal_column("*"))
    )
    result = await db_connection.execute(insert_stmt)
    row: RowProxy = await result.fetchone()
    created_task = dict(row)

    assert (
        db_notification_queue.empty()
    ), "database triggered change although it should only trigger on updates!"

    yield created_task
示例#22
0
def update_areas_of_changes(document):
    """Update the area ids of all feed entries of the given document.

    ``area_ids`` of every ``DocumentChange`` row of the document is set to
    the array of ``AreaAssociation.area_id`` values for the document.
    """
    # Concatenate with an empty array so the result is an array, not NULL,
    # when there is no area:  ARRAY[]::integer[] || array_agg(area_id)
    empty_int_array = literal_column('ARRAY[]::integer[]')
    area_ids_expr = empty_int_array.op('||')(
        func.array_agg(
            AreaAssociation.area_id, type_=postgresql.ARRAY(Integer)))
    areas_select = select([area_ids_expr]).where(
        AreaAssociation.document_id == document.document_id)

    update_stmt = (
        DocumentChange.__table__.update()
        .where(DocumentChange.document_id == document.document_id)
        .values(area_ids=areas_select.as_scalar())
    )
    DBSession.execute(update_stmt)
示例#23
0
    def src_vector_intersects(self) -> bool:
        """
        Check whether the tile intersects any row of the source PostGIS table.

        Opens a dedicated psycopg2 connection and runs
        ``SELECT 1 FROM <src_table> WHERE <intersect_filter>``. The result is
        whether a first row came back. The cursor and connection are closed
        even when the query fails.

        :return: ``True`` if at least one row matches, else ``False``.
        :raises psycopg2.Error: connection/query errors are logged and re-raised.
        """
        try:
            logger.debug(
                f"Check if tile {self.tile_id} intersects with postgis table")
            conn = psycopg2.connect(
                dbname=self.src.conn.db_name,
                user=self.src.conn.db_user,
                password=self.src.conn.db_password,
                host=self.src.conn.db_host,
                port=self.src.conn.db_port,
            )
            try:
                # a psycopg2 cursor used in ``with`` is closed on exit
                with conn.cursor() as cursor:
                    sql = (select([literal_column("1")]).select_from(
                        self.src_table()).where(self.intersect_filter()))

                    logger.debug(str(sql))

                    cursor.execute(str(sql))

                    try:
                        exists = bool(cursor.fetchone()[0])
                    except (ProgrammingError, TypeError):
                        # no result set, or fetchone() returned None (no row)
                        exists = False
            finally:
                # release the connection even if the query raised
                conn.close()
        except psycopg2.Error:
            logger.exception(
                "There was an issue when trying to connect to the database")
            raise

        logger.debug(f"EXISTS: {exists}")

        if exists:
            logger.info(
                f"Tile id {self.tile_id} exists in database table {self.src.schema}.{self.src.table}"
            )
        else:
            logger.info(
                f"Tile id {self.tile_id} does not exists in database table {self.src.schema}.{self.src.table}"
            )
        return exists
示例#24
0
def update_map(topo_map, reset=False):
    """Create associations for the given map with all intersecting documents.

    If `reset` is True, all possible existing associations to this map are
    dropped before creating new associations.

    Maps that redirect elsewhere get no new associations. A non-map,
    non-redirected document is associated when its ``geom`` or
    ``geom_detail`` intersects the map's ``geom_detail``. Afterwards
    ``update_cache_version_for_map`` is called for the map.
    """
    if reset:
        DBSession.execute(
            TopoMapAssociation.__table__.delete()
            .where(TopoMapAssociation.topo_map_id == topo_map.document_id)
        )

    if topo_map.redirects_to:
        # forwarded maps get no associations
        return

    map_geom = (
        select([DocumentGeometry.geom_detail])
        .where(DocumentGeometry.document_id == topo_map.document_id)
    )
    intersecting_documents = (
        DBSession.query(
            DocumentGeometry.document_id,  # id of a document
            literal_column(str(topo_map.document_id)),
        )
        .join(
            Document,
            and_(
                Document.document_id == DocumentGeometry.document_id,
                Document.type != MAP_TYPE,
            ),
        )
        .filter(Document.redirects_to.is_(None))
        .filter(
            or_(
                DocumentGeometry.geom.ST_Intersects(map_geom.label('t1')),
                DocumentGeometry.geom_detail.ST_Intersects(map_geom.label('t2')),
            )
        )
    )

    # insert (document_id, topo_map_id) pairs straight from the query
    DBSession.execute(
        TopoMapAssociation.__table__.insert().from_select(
            [TopoMapAssociation.document_id, TopoMapAssociation.topo_map_id],
            intersecting_documents,
        )
    )

    # update cache key for now associated docs
    update_cache_version_for_map(topo_map)
示例#25
0
文件: feed.py 项目: c2corg/v6_api
def update_areas_of_changes(document):
    """Update the area ids of all feed entries of the given document.

    ``area_ids`` of every ``DocumentChange`` row of the document is set to
    the array of ``AreaAssociation.area_id`` values for the document.
    """
    # Concatenate with an empty array so the result is an array, not NULL,
    # when there is no area:  ARRAY[]::integer[] || array_agg(area_id)
    empty_int_array = literal_column('ARRAY[]::integer[]')
    area_ids_expr = empty_int_array.op('||')(
        func.array_agg(
            AreaAssociation.area_id, type_=postgresql.ARRAY(Integer)))
    areas_select = select([area_ids_expr]).where(
        AreaAssociation.document_id == document.document_id)

    update_stmt = (
        DocumentChange.__table__.update()
        .where(DocumentChange.document_id == document.document_id)
        .values(area_ids=areas_select.as_scalar())
    )
    DBSession.execute(update_stmt)
示例#26
0
    def _append_returning(self, columns: Union[str, List[str]],
                          query: UpdateBase) -> Tuple[UpdateBase, bool]:
        """
        Add a RETURNING clause for ``columns`` to ``query``.

        ``columns`` is normalized to a list of names. If it contains
        ``PRIMARY_KEY``, only the primary key is returned. Otherwise, if it
        contains ``ALL_COLUMNS``, every column is returned (``RETURNING *``).
        Otherwise the named table columns are returned.

        :return: the new query, and whether a single scalar is expected
            (exactly one requested name that is not ``ALL_COLUMNS``).
        """
        names: List[str] = _normalize(columns)
        single_name = len(names) == 1

        if PRIMARY_KEY in names:
            # defaults to the primary key
            return query.returning(self._primary_key), single_name

        if ALL_COLUMNS in names:
            # NOTE: returning(self._table) would also work; possibly less efficient
            return query.returning(literal_column("*")), False

        selected_columns = [self._table.c[name] for name in names]
        return query.returning(*selected_columns), single_name
示例#27
0
async def test_listen_comp_tasks_task(
    mock_project_subsystem: Dict,
    comp_task_listening_task: None,
    client,
    update_values: Dict[str, Any],
    expected_calls: List[str],
    task_class: NodeClass,
):
    """
    Updating a ``comp_tasks`` row must trigger exactly the expected calls.

    After inserting a task and updating it with ``update_values``, each mock
    in ``expected_calls`` must be awaited. The assertion is retried every
    second for up to 10 s. Every other mock must never be called.
    """
    db_engine: aiopg.sa.Engine = client.app[APP_DB_ENGINE_KEY]
    async with db_engine.acquire() as conn:
        inserted = await conn.execute(
            comp_tasks.insert()
            .values(outputs=json.dumps({}), node_class=task_class)
            .returning(literal_column("*"))
        )
        row: RowProxy = await inserted.fetchone()
        task = dict(row)

        await conn.execute(
            comp_tasks.update()
            .values(**update_values)
            .where(comp_tasks.c.task_id == task["task_id"])
        )

        for call_name, mocked_call in mock_project_subsystem.items():
            if call_name not in expected_calls:
                mocked_call.assert_not_called()
                continue
            # the listener reacts asynchronously: retry on AssertionError
            async for attempt in AsyncRetrying(
                wait=wait_fixed(1),
                stop=stop_after_delay(10),
                retry=retry_if_exception_type(AssertionError),
                before_sleep=before_sleep_log(logger, logging.INFO),
                reraise=True,
            ):
                with attempt:
                    mocked_call.assert_awaited()
示例#28
0
def update_area(area, reset=False):
    """Create associations for the given area with all intersecting documents.

    If `reset` is True, all possible existing associations to this area are
    dropped before creating new associations.

    Areas that redirect elsewhere get no new associations. A non-area,
    non-redirected document is associated when its ``geom`` or
    ``geom_detail`` intersects the area's ``geom_detail``.
    """
    if reset:
        DBSession.execute(
            AreaAssociation.__table__.delete()
            .where(AreaAssociation.area_id == area.document_id)
        )

    if area.redirects_to:
        # forwarded areas get no associations
        return

    area_geom = (
        select([DocumentGeometry.geom_detail])
        .where(DocumentGeometry.document_id == area.document_id)
    )
    intersecting_documents = (
        DBSession.query(
            DocumentGeometry.document_id,  # id of a document
            literal_column(str(area.document_id)),
        )
        .join(
            Document,
            and_(
                Document.document_id == DocumentGeometry.document_id,
                Document.type != AREA_TYPE,
            ),
        )
        .filter(Document.redirects_to.is_(None))
        .filter(
            or_(
                DocumentGeometry.geom.ST_Intersects(area_geom.label('t1')),
                DocumentGeometry.geom_detail.ST_Intersects(area_geom.label('t2')),
            )
        )
    )

    # insert (document_id, area_id) pairs straight from the query
    DBSession.execute(
        AreaAssociation.__table__.insert().from_select(
            [AreaAssociation.document_id, AreaAssociation.area_id],
            intersecting_documents,
        )
    )
示例#29
0
def update_maps_for_document(document, reset=False):
    """Associate ``document`` with every non-redirected topo map that intersects it.

    A map matches when its ``geom_detail`` intersects either the ``geom`` or the
    ``geom_detail`` of the document. The matching rows are inserted into
    ``TopoMapAssociation`` directly in SQL via INSERT ... SELECT.

    :param document: the document to associate; its ``document_id`` and
        ``redirects_to`` are read.
    :param reset: when true, every existing ``TopoMapAssociation`` row of this
        document is deleted first. The deletion also happens when the document is
        redirected.
    """
    if reset:
        drop_existing = TopoMapAssociation.__table__.delete().where(
            TopoMapAssociation.document_id == document.document_id)
        DBSession.execute(drop_existing)

    if document.redirects_to:
        # A forwarded (redirected) document gets no new map associations.
        return

    geom_subquery = select([DocumentGeometry.geom]).where(
        DocumentGeometry.document_id == document.document_id)
    detail_subquery = select([DocumentGeometry.geom_detail]).where(
        DocumentGeometry.document_id == document.document_id)
    matching_maps = (
        DBSession
        .query(
            DocumentGeometry.document_id,  # id of the intersecting map
            literal_column(str(document.document_id)),  # constant document id
        )
        .join(TopoMap, TopoMap.document_id == DocumentGeometry.document_id)
        .filter(TopoMap.redirects_to.is_(None))
        .filter(
            or_(
                DocumentGeometry.geom_detail.ST_Intersects(geom_subquery.label('t1')),
                DocumentGeometry.geom_detail.ST_Intersects(detail_subquery.label('t2')),
            )
        )
    )

    insert_links = TopoMapAssociation.__table__.insert().from_select(
        [TopoMapAssociation.topo_map_id, TopoMapAssociation.document_id],
        matching_maps)
    DBSession.execute(insert_links)
示例#30
0
	def build_free_aggregate_column(self, table_column: FreeAggregateColumn, index: int, prefix_name: str) -> Label:
		"""
		Build the labelled select expression for one free aggregate column.

		The expression is aliased ``<prefix_name>_<index + 1>``. COUNT / SUMMARY / AVERAGE / MAXIMUM / MINIMUM wrap
		the raw column in the matching SQL aggregate function. NONE, or a missing arithmetic, selects the raw column
		under the alias.

		:raises UnexpectedStorageException: when the arithmetic is none of the supported values
		"""
		column_name = table_column.name
		alias_name = f'{prefix_name}_{index + 1}'
		arithmetic = table_column.arithmetic
		# Checked in order with ==, exactly like an if/elif chain would.
		aggregators = (
			(FreeAggregateArithmetic.COUNT, func.count),
			(FreeAggregateArithmetic.SUMMARY, func.sum),
			(FreeAggregateArithmetic.AVERAGE, func.avg),
			(FreeAggregateArithmetic.MAXIMUM, func.max),
			(FreeAggregateArithmetic.MINIMUM, func.min),
		)
		for candidate, aggregate in aggregators:
			if arithmetic == candidate:
				return aggregate(literal_column(column_name)).label(alias_name)
		if arithmetic == FreeAggregateArithmetic.NONE or arithmetic is None:
			return label(alias_name, literal_column(column_name))
		raise UnexpectedStorageException(f'Aggregate arithmetic[{arithmetic}] is not supported.')
示例#31
0
    def translate_straight_column_name(
            self, straight_column: EntityStraightColumn) -> Any:
        """
        Translate a straight column definition into a labelled SQLAlchemy expression.

        Aggregate columns (``EntityStraightAggregateColumn``) are wrapped in the SQL
        aggregate matching their arithmetic (COUNT/SUM/AVG/MAX/MIN). Plain straight
        columns are selected as-is. In both cases the column name is rendered as raw
        SQL via ``literal_column``. A bare Python string passed to ``func.*`` would
        instead be bound as a string parameter, aggregating a constant rather than
        the column.

        :param straight_column: the column definition to translate
        :return: the expression labelled with ``get_alias_from_straight_column``
        :raises UnsupportedStraightColumnException: when the column type, or the
            aggregate arithmetic of an aggregate column, is not supported
        """
        if isinstance(straight_column, EntityStraightAggregateColumn):
            aggregators = (
                (EntityColumnAggregateArithmetic.COUNT, func.count),
                (EntityColumnAggregateArithmetic.SUM, func.sum),
                (EntityColumnAggregateArithmetic.AVG, func.avg),
                (EntityColumnAggregateArithmetic.MAX, func.max),
                (EntityColumnAggregateArithmetic.MIN, func.min),
            )
            for arithmetic, aggregate in aggregators:
                if straight_column.arithmetic == arithmetic:
                    return aggregate(literal_column(straight_column.columnName)).label(
                        self.get_alias_from_straight_column(straight_column))
            # Any other arithmetic is unsupported and falls through to the raise below.
        elif isinstance(straight_column, EntityStraightColumn):
            return literal_column(straight_column.columnName).label(
                self.get_alias_from_straight_column(straight_column))

        raise UnsupportedStraightColumnException(
            f'Straight column[{straight_column.to_dict()}] is not supported.')
    def _add_aggregation(
        cls,
        inner_query: Select,
        aggregation_columns: List[ColumnClause],
        group_by_columns: List[ColumnClause],
        grouping_sets: Optional[GroupingSets] = None,
    ) -> Select:
        """
        Aggregate raw metric taxons, grouping by all dimensions or by each given grouping set.

        Without grouping sets, the result is one SELECT ... GROUP BY over every column in
        ``group_by_columns``. With grouping sets, the inner query is first wrapped in the CTE
        ``__cte_grouping_sets``, where every group-by column is cast to VARCHAR. Then one grouped
        SELECT is built per grouping set. In it, the dimensions outside the set are replaced by the
        quoted ``_PANORAMIC_GROUPINGSETS_NULL`` string literal. Those per-set SELECTs are combined
        with UNION ALL.

        :param inner_query: Query to aggregate
        :param aggregation_columns: List of columns with applied aggregation function
        :param group_by_columns: List of columns to group by
        :param grouping_sets: Optional list of grouping sets to group by instead; their entries are
            passed through ``safe_identifier`` before being compared with column names
        :return: Aggregated query
        """
        if not grouping_sets:
            # No grouping sets: a single query grouped by every dimension.
            plain_selection = sort_columns(group_by_columns + aggregation_columns)
            return Select(columns=plain_selection).select_from(inner_query).group_by(
                *sort_columns(group_by_columns))

        # The placeholder for dimensions outside a grouping set is a string, while the dimension may be a
        # date(time) or a number, so all group columns must be text. Some DB engines fail when casting and
        # grouping happen in the same query, hence the cast lives in the CTE, not in the grouped queries.
        names_to_stringify = {dimension.name for dimension in group_by_columns}
        cte_selection = [
            cast(inner_col, sqlalchemy.VARCHAR).label(inner_col.name)
            if inner_col.name in names_to_stringify
            else inner_col
            for inner_col in inner_query.columns
        ]
        shared_cte = Select(columns=sort_columns(cte_selection)).select_from(inner_query).cte(
            '__cte_grouping_sets')

        per_set_queries = []
        for grouping_set in grouping_sets:
            kept_names = [safe_identifier(slug) for slug in grouping_set]
            # Dimensions inside this grouping set take part in the GROUP BY.
            grouped_dimensions = [dimension for dimension in group_by_columns if dimension.name in kept_names]
            # Dimensions outside it are returned as the placeholder string literal.
            placeholder_dimensions = [
                literal_column(f"'{_PANORAMIC_GROUPINGSETS_NULL}'").label(dimension.name)
                for dimension in group_by_columns
                if dimension.name not in kept_names
            ]
            set_selection = sort_columns(grouped_dimensions + placeholder_dimensions + aggregation_columns)
            per_set_queries.append(
                Select(columns=set_selection).select_from(shared_cte).group_by(
                    *sort_columns(grouped_dimensions)))
        return union_all(*per_set_queries)
示例#33
0
def blend_dataframes(
    ctx: HuskyQueryContext,
    dataframes: List[Dataframe],
    data_source_formula_templates: Optional[Dict[str, List[SqlFormulaTemplate]]] = None,
) -> Dataframe:
    """
    Blend several dataframes into one aggregated dataframe.

    Each dataframe becomes a subquery aliased ``q<index>``. The subqueries are chained with FULL OUTER JOINs on the
    data-source column. That join condition is intended never to match (PROD-8136), so rows are effectively stacked.
    The result is grouped by every dimension:
    - a dimension present in several dataframes is merged with ``coalesce``;
    - metrics are summed across dataframes, with NULLs turned into 0 unless the taxon has a data source.

    :param ctx: query context; its dialect is used to quote identifiers
    :param dataframes: dataframes to blend; the first one is indexed directly, so the list must not be empty
    :param data_source_formula_templates: optional pre-formula templates whose labels are selected and grouped by
    :return: dataframe wrapping the blended query
    """
    slug_to_dataframes: Dict[TaxonExpressionStr, List[Dataframe]] = _prepare_slug_to_dataframes(dataframes)
    dataframe_to_query: Dict[Dataframe, Selectable] = {}
    used_model_names: Set[str] = set()
    used_physical_sources: Set[str] = set()
    for position, frame in enumerate(dataframes):
        # Every dataframe is referenced through its 'q<position>' alias.
        dataframe_to_query[frame] = frame.query.alias(f'q{position}')
        used_model_names.update(frame.used_model_names)
        used_physical_sources.update(frame.used_physical_data_sources)

    # Labelled dimension columns, textual metric sums and pre-formula columns.
    selectors = []
    dimension_columns: List[ColumnClause] = []
    for taxon_slug in sorted(slug_to_dataframes):
        frames_with_slug = slug_to_dataframes[taxon_slug]
        taxon = frames_with_slug[0].slug_to_column[taxon_slug].taxon
        quoted_column = quote_identifier(taxon.slug_safe_sql_identifier, ctx.dialect)
        aliases = [dataframe_to_query[frame].name for frame in frames_with_slug]

        if not taxon.is_dimension:
            if taxon.data_source:
                # Namespaced taxons use TEL expressions whose compilation already handles NULLs: no zeroifnull.
                terms = [f'{alias}.{quoted_column}' for alias in aliases]
            else:
                terms = [f'coalesce({alias}.{quoted_column},0)' for alias in aliases]
            summed = '+'.join(terms)
            selectors.append(text(f'sum({summed}) as {quoted_column}'))
            continue

        qualified = [literal_column(f'{alias}.{quoted_column}') for alias in aliases]
        # coalesce() requires two or more arguments; a dimension from a single dataframe is used directly.
        merged = functions.coalesce(*qualified) if len(qualified) > 1 else qualified[0]
        dimension_column = merged.label(taxon.slug_safe_sql_identifier)
        dimension_columns.append(dimension_column)
        selectors.append(dimension_column)

    final_columns: List[ColumnClause] = []
    for pre_formulas in (data_source_formula_templates or {}).values():
        for pre_formula in pre_formulas:
            pre_formula_column = column(pre_formula.label)
            dimension_columns.append(pre_formula_column)
            selectors.append(pre_formula_column)
            final_columns.append(column(quote_identifier(pre_formula.label, ctx.dialect)))

    # All taxons present in the final dataframe; later dataframes override earlier ones on equal slugs.
    final_slug_to_taxon: Dict[TaxonExpressionStr, DataframeColumn] = dataframes[0].slug_to_column.copy()

    # The SQLAlchemy compiler wraps every select_from in extra parentheses, so all queries are joined first and the
    # aggregating select is built on top of the finished join afterwards.
    first_query = dataframe_to_query[dataframes[0]]
    join_query = first_query
    for frame in dataframes[1:]:
        used_physical_sources.update(frame.used_physical_data_sources)
        final_slug_to_taxon = {**final_slug_to_taxon, **frame.slug_to_column}
        next_query = dataframe_to_query[frame]
        # Deliberately joined on a condition expected to be always FALSE => PROD-8136. This behaves like UNION ALL
        # without having to align the columns of the dataframes.
        join_query = join_query.join(
            next_query,
            first_query.columns[HUSKY_QUERY_DATA_SOURCE_COLUMN_NAME]
            == next_query.columns[HUSKY_QUERY_DATA_SOURCE_COLUMN_NAME],
            full=True,
        )

    grouped_query = select(selectors).select_from(join_query)
    for dimension_column in dimension_columns:
        grouped_query = grouped_query.group_by(dimension_column)

    # One more wrapping select, so that the resulting query exposes its columns through the 'c' attribute.
    final_columns.extend(column(identifier) for identifier in safe_identifiers_iterable(final_slug_to_taxon.keys()))
    query = select(sort_columns(final_columns)).select_from(grouped_query)

    return Dataframe(query, final_slug_to_taxon, used_model_names, used_physical_sources)
示例#34
0
 def _get_column_accessor_for_taxon_and_model(
         self, model: HuskyModel,
         taxon_slug_expression: TaxonSlugExpression) -> ColumnClause:
     """
     Return the model's SQL accessor for the given taxon, wrapped as a raw literal column.

     :param model: model providing ``taxon_sql_accessor``
     :param taxon_slug_expression: taxon whose ``slug`` is resolved in the model
     :return: ``literal_column`` built from the accessor string
     """
     sql_accessor = model.taxon_sql_accessor(self.ctx, taxon_slug_expression.slug)
     return literal_column(sql_accessor)
示例#35
0
    def _build_query_window_aggregations(
        self,
        taxon_to_model: Dict[TaxonSlugExpression, HuskyModel],
        ordered_query_joins: Sequence[QueryJoins],
    ) -> Select:
        """
        Build the query for taxons that are aggregated with window functions.

        An inner query selects every projection taxon (sorted by its slug expression). Taxons whose
        aggregation type is one of ``_AGGREGATION_WINDOW_FUNCTIONS`` and that have aggregation params
        are wrapped in that window function, partitioned by ``get_partition_by_columns(model)``
        and ordered by their sort dimensions (NULLs last). All other taxons are selected as-is. Scope
        row filters are applied to the inner query. The taxon info map is rebuilt for the outer query.
        The outer query, selected from the inner one, is returned so GROUP BY can be applied safely.

        :param taxon_to_model: Map of taxon slugs (key) and models they are coming from (value)
        :param ordered_query_joins: List of joins
        """
        selectors = []
        ordered_projection = sorted(self.projection_taxons.items(), key=lambda item: str(item[0]))
        for slug_expression, taxon in ordered_projection:
            model = taxon_to_model[slug_expression]
            needs_window = (
                taxon.tel_metadata
                and taxon.tel_metadata.aggregation_definition
                and taxon.tel_metadata.aggregation_definition.params
                and taxon.tel_metadata_aggregation_type in self._AGGREGATION_WINDOW_FUNCTIONS
            )
            if not needs_window:
                # Taxons without a window aggregation are rendered "as-is".
                selected = literal_column(model.taxon_sql_accessor(self.ctx, taxon.slug))
            else:
                sort_params = cast(
                    AggregationParamsSortDimension,
                    taxon.tel_metadata.aggregation_definition.params)
                window_order = []
                for sort_dimension in sort_params.sort_dimensions:
                    sort_model = taxon_to_model[TaxonSlugExpression(sort_dimension.taxon)]
                    sort_accessor = sort_model.taxon_sql_accessor(self.ctx, sort_dimension.taxon)
                    direction = sort_dimension.order_by or TaxonOrderType.asc
                    window_order.append(nullslast(ORDER_BY_FUNCTIONS[direction](literal_column(sort_accessor))))

                window_function = self._AGGREGATION_WINDOW_FUNCTIONS[taxon.tel_metadata_aggregation_type]
                selected = window_function(
                    literal_column(model.taxon_sql_accessor(self.ctx, taxon.slug))
                ).over(
                    partition_by=self.get_partition_by_columns(model),
                    order_by=window_order,
                )

            selectors.append(selected.label(taxon.slug_safe_sql_identifier))

        # Inner query: window-aggregated selectors over the joins, restricted by the scope row filters.
        inner_query = select(selectors).select_from(self._build_from_joins(ordered_query_joins))
        inner_query = ScopeGuard.add_scope_row_filters(
            self.ctx, self.scope, inner_query, self.taxon_model_info_map)

        # The outer query reads from the inner query, so the taxon model info map must be rebuilt first.
        self._rebuild_taxon_info_map_inner_query()

        # Outer query: plain, safely quoted column references on which GROUP BY can be applied.
        return self._build_selectors(
            lambda _slug_expression, slug: safe_identifier(slug)
        ).select_from(inner_query)