Example #1
    def test_get_object(self):
        """
            Test the method _get_object() using a sync key
            Test scenario:
            Get the object with sync_key
        """

        obj_type = self.powerVCMapping.obj_type
        sync_key = self.powerVCMapping.sync_key

        self.aMox.StubOutWithMock(session, 'query')
        session.query(model.PowerVCMapping).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'filter_by')
        query.filter_by(
            obj_type=obj_type, sync_key=sync_key).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'one')
        query.one().AndReturn(self.powerVCMapping)

        self.aMox.ReplayAll()
        returnValue = self.powervcagentdb._get_object(
            obj_type=obj_type, sync_key=sync_key)
        self.aMox.VerifyAll()
        self.assertEqual(returnValue, self.powerVCMapping)
        self.aMox.UnsetStubs()
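
The mox record/replay/verify cycle used throughout these tests can be shown in isolation. This is a minimal sketch stubbing the standard library's time.time as a stand-in for the session/query stubs above.

import time
import mox

m = mox.Mox()
m.StubOutWithMock(time, 'time')
time.time().AndReturn(42.0)   # record phase: declare the expected call and its return value
m.ReplayAll()                 # switch from record mode to replay mode
assert time.time() == 42.0    # the code under test would make this call
m.VerifyAll()                 # fails if any recorded expectation was not met
m.UnsetStubs()                # restore the original time.time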
Example #2
    def get_scenarios(self, scenarios_filter=None, pagination=None):
        """Search for scenarios by filter.

        :param scenarios_filter: instance of :class:`ScenarioFilter
            <autostorage.core.scenario.param_spec.ScenarioFilter>`.
        :param pagination: instance of `Pagination <autostorage.core.param_spec.Pagination>`.
        :returns: list of instances of :class:`Scenario
            <autostorage.core.scenario.scenario.Scenario>`.
        """
        ids_query = Query(ScenarioRecord.scenario_id)

        if scenarios_filter:
            ids = scenarios_filter.scenario_ids
            if ids:
                ids_query = ids_query.filter(ScenarioRecord.scenario_id.in_(ids))
            else:
                return []

        if pagination:
            offset = pagination.page_index * pagination.items_per_page
            ids_query = ids_query.offset(offset).limit(pagination.items_per_page)

        with self.base.get_session() as session:
            bound_query = ids_query.with_session(session)
            return [ScenarioEntity(self.base, record[0]) for record in bound_query]
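
The pagination arithmetic above (offset = page_index * items_per_page) assumes a zero-based page_index. A small self-contained sketch with a hypothetical stand-in for autostorage's Pagination:

from dataclasses import dataclass

@dataclass
class Pagination:              # hypothetical stand-in, for illustration only
    page_index: int            # zero-based page number
    items_per_page: int

pagination = Pagination(page_index=2, items_per_page=25)
offset = pagination.page_index * pagination.items_per_page
assert offset == 50            # the third page starts after skipping 50 records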
Example #3
    def get_scenario_nodes(self, nodes_filter=None, pagination=None):
        """Search for scenario nodes by filter.

        :param nodes_filter: instance of :class:`ScenarioNodeFilter
            <autostorage.core.scenario.param_spec.ScenarioNodeFilter>`.
        :param pagination: instance of `Pagination <autostorage.core.param_spec.Pagination>`.
        :returns: list of instances of :class:`ScenarioNode
            <autostorage.core.scenario.node.ScenarioNode>`.
        """
        ids_query = Query(ScenarioNodeRecord.node_id)

        if nodes_filter:
            ids = nodes_filter.node_ids
            if ids:
                ids_query = ids_query.filter(ScenarioNodeRecord.node_id.in_(ids))
            else:
                return []

        if pagination:
            offset = pagination.page_index * pagination.items_per_page
            ids_query = ids_query.offset(offset).limit(pagination.items_per_page)

        with self.base.get_session() as session:
            bound_query = ids_query.with_session(session)
            return [ScenarioNodeEntity(self.base, record[0]) for record in bound_query]
Example #4
    def test_set_object_local_id(self):
        """
            Test the method _set_object_local_id(self, obj, local_id)
            Test scenario:
            Set the local_id of the specified object when the pvc_id is none
        """

        obj_id = self.powerVCMapping.id
        self.powerVCMapping.pvc_id = None
        self.powerVCMapping.local_id = None
        self.powerVCMapping.status = None

        self.aMox.StubOutWithMock(session, 'query')
        session.query(model.PowerVCMapping).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'filter_by')
        query.filter_by(id=obj_id).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'one')
        query.one().AndReturn(self.powerVCMapping)

        self.aMox.StubOutWithMock(session, 'merge')
        session.merge(self.powerVCMapping).AndReturn("")

        self.aMox.ReplayAll()
        self.powervcagentdb._set_object_local_id(self.powerVCMapping, 'test')
        self.aMox.VerifyAll()
        self.assertEqual(self.powerVCMapping.status, 'Creating')
        self.assertEqual(self.powerVCMapping.local_id, 'test')
        self.aMox.UnsetStubs()
Example #5
    def test_delete_existing_object(self):
        """
            Test the method _delete_object(self, obj) when the object exists
            Test scenario:
            When the data is in the database, the delete operation should
            complete successfully
        """

        self.aMox.StubOutWithMock(session, 'query')
        session.query(model.PowerVCMapping).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'filter_by')
        query.filter_by(id=self.powerVCMapping['id']).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'one')
        query.one().AndReturn(self.powerVCMapping)

        self.aMox.StubOutWithMock(session, 'begin')
        session.begin(subtransactions=True).AndReturn(transaction(None, None))

        self.aMox.StubOutWithMock(session, 'delete')
        returnValue = session.delete(self.powerVCMapping).AndReturn(True)

        self.aMox.ReplayAll()

        self.powervcagentdb._delete_object(self.powerVCMapping)

        self.aMox.VerifyAll()

        self.assertEqual(returnValue, True)

        self.aMox.UnsetStubs()
Example #6
    def _add_join(self, query: Query) -> Query:
        for key in self._foreings:
            planet_table = PLANET_TABLES[key]

            resonance_attr = getattr(self.resonance_cls, "%s_id" % key)
            query = query.outerjoin(planet_table, resonance_attr == planet_table.id)
            if self._load_related:
                options = contains_eager(getattr(self.resonance_cls, key), alias=planet_table)
                query = query.options(options)

        return query
Example #7
    def get(self, ident):
        nodeclass = self._find_nodeclass()
        if not nodeclass:
            return Query.get(self, ident)
        else:
            nodeclass = nodeclass[0]
        active_version = Query.get(self, ident)
        Transaction = versioning_manager.transaction_cls
        if active_version is None:
            ver_cls = version_class(nodeclass)
            return (self.session.query(ver_cls).join(Transaction, ver_cls.transaction_id == Transaction.id)
                    .join(Transaction.meta_relation)
                    .filter_by(key=u'alias_id', value=unicode(ident)).scalar())

        return active_version
Example #8
 def to_python(conv, value):
     model_ = model or conv.field.form.model
     if value not in ('', None):
         field = getattr(model_, conv.field.name)
         # XXX Using db.query() won't work properly with custom
         # auto-filtering query classes. But silently replacing the class is
         # not good either.
         query = Query(model_, session=conv.env.db)
         if issubclass(model_, WithState):
             states = WithState.PRIVATE, WithState.PUBLIC
             query = query.filter(model_.state.in_(states))
         item = query.filter(field==value).scalar()
         if item is not None and item != conv.field.form.item:
             return False
     return True
Example #9
def add_integer_filter(query: Query, ints: List[str], body_tables: List[AliasedClass]) -> Query:
    any_int = '*'

    for integer, table in zip(ints, body_tables):
        if integer != any_int:
            query = query.filter(eval("table.longitude_coeff %s" % integer))
    return query
Example #10
    def get_all(instance: Query, offset: int=None, limit: int=None, filters: list=()) -> list:
        """
            Gets all results of the given query instance

            :param instance: sqlalchemy queryable
            :param offset: Offset for the request
            :param limit: Limit for the request
            :param filters: Filter and OrderBy clauses
        """
        for expression in filters:
            if _is_ordering_expression(expression):
                instance = instance.order_by(expression)
            else:
                instance = instance.filter(expression)
        if offset is not None:
            instance = instance.offset(offset)
        if limit is not None:
            instance = instance.limit(limit)
        return instance.all()
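
The helper _is_ordering_expression is not shown here; one plausible implementation (an assumption, not this project's actual code) treats asc()/desc() unary wrappers as ordering clauses and everything else as a filter:

from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression

def _is_ordering_expression(expression) -> bool:
    # Model.name.asc() / .desc() produce UnaryExpression objects whose modifier
    # is asc_op or desc_op; plain comparisons such as Model.age > 18 do not.
    return isinstance(expression, UnaryExpression) and \
        expression.modifier in (operators.asc_op, operators.desc_op)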
Example #11
    def _apply_kwargs(instance: Query, **kwargs) -> Query:
        for expression in kwargs.pop('filters', []):
            if _is_ordering_expression(expression):
                instance = instance.order_by(expression)
            else:
                instance = instance.filter(expression)

        if 'offset' in kwargs:
            offset = kwargs.pop('offset')
            foffset = lambda instance: instance.offset(offset)
        else:
            foffset = lambda instance: instance

        if 'limit' in kwargs:
            limit = kwargs.pop('limit')
            flimit = lambda instance: instance.limit(limit)
        else:
            flimit = lambda instance: instance

        instance = instance.filter_by(**kwargs)
        instance = foffset(instance)
        instance = flimit(instance)
        return instance
Example #12
    def filter_query_for_content_label_as_path(
            self,
            query: Query,
            content_label_as_file: str,
            is_case_sensitive: bool = False,
    ) -> Query:
        """
        Apply normalised filters to find Content matching the given label.
        :param query: query to modify
        :param content_label_as_file: label in its FILE version; use
        Content.get_label_as_file() to obtain it.
        :param is_case_sensitive: whether the comparison is case sensitive
        :return: modified query
        """
        file_name, file_extension = os.path.splitext(content_label_as_file)

        label_filter = Content.label == content_label_as_file
        file_name_filter = Content.label == file_name
        file_extension_filter = Content.file_extension == file_extension

        if not is_case_sensitive:
            label_filter = func.lower(Content.label) == \
                           func.lower(content_label_as_file)
            file_name_filter = func.lower(Content.label) == \
                               func.lower(file_name)
            file_extension_filter = func.lower(Content.file_extension) == \
                                    func.lower(file_extension)

        return query.filter(or_(
            and_(
                Content.type == ContentType.File,
                file_name_filter,
                file_extension_filter,
            ),
            and_(
                Content.type == ContentType.Thread,
                file_name_filter,
                file_extension_filter,
            ),
            and_(
                Content.type == ContentType.Page,
                file_name_filter,
                file_extension_filter,
            ),
            and_(
                Content.type == ContentType.Folder,
                label_filter,
            ),
        ))
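
The name/extension split relies on os.path.splitext, which keeps the leading dot with the extension and returns an empty extension for dot-less labels such as folders:

import os

assert os.path.splitext('meeting-notes.txt') == ('meeting-notes', '.txt')
assert os.path.splitext('Projects') == ('Projects', '')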
Example #13
    def test_get_objects_with_status(self):
        """Test the method def _get_objects(self, obj_type, status)
           Test scenario:
           Get the object when the status is not None
        """

        self.aMox.StubOutWithMock(session, 'query')
        session.query(model.PowerVCMapping).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'filter_by')
        query.filter_by(obj_type=self.powerVCMapping.obj_type,
                        status=self.powerVCMapping.status).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'all')
        query.all().AndReturn(self.powerVCMapping)

        self.aMox.ReplayAll()
        returnValue = self.powervcagentdb._get_objects(
            obj_type=self.powerVCMapping.obj_type,
            status=self.powerVCMapping.status)
        self.aMox.VerifyAll()
        self.assertEqual(returnValue, self.powerVCMapping)

        self.aMox.UnsetStubs()
Example #14
    def get_scenario_view_states(self, states_filter=None, pagination=None):
        """Search for scenario view_states by filter.

        :param states_filter: instance of :class:`ScenarioStateFilter
            <autostorage.core.scenario.param_spec.ScenarioStateFilter>`.
        :param pagination: instance of `Pagination <autostorage.core.param_spec.Pagination>`.
        :returns: list with instances of :class:`ScenarioState
            <autostorage.core.scenario.scenario.ScenarioState>`.
        """
        ids_query = Query(ScenarioViewStateRecord)

        subquery = Query([
            ScenarioViewStateRecord.scenario_id,
            func.max(ScenarioViewStateRecord.changed).label('newest_change_date')
            ])

        if states_filter and states_filter.date:
            subquery = subquery.filter(ScenarioViewStateRecord.changed <= states_filter.date)

        subquery = subquery.group_by(ScenarioViewStateRecord.scenario_id).subquery()

        ids_query = ids_query.join(
            subquery,
            and_(
                ScenarioViewStateRecord.scenario_id == subquery.columns.scenario_id,
                ScenarioViewStateRecord.changed == subquery.columns.newest_change_date
                )
            )

        if pagination:
            offset = pagination.page_index * pagination.items_per_page
            ids_query = ids_query.offset(offset).limit(pagination.items_per_page)

        with self.base.get_session() as session:
            bound_query = ids_query.with_session(session)
            states = []
            for state_record in bound_query:
                scenario = ScenarioEntity(self.base, state_record.scenario_id)
                state = ScenarioViewStateEntity(
                    scenario=scenario,
                    name=state_record.name,
                    description=state_record.description,
                    date=state_record.changed
                    )
                states.append(state)

        return states
Example #15
def test_with_statement(rds_data_client, db_connection):
    with DataAPI(
        database=database,
        resource_arn=resource_arn,
        secret_arn=secret_arn,
        client=rds_data_client,
    ) as data_api:
        insert: Insert = Insert(Pets, {'name': 'dog'})

        result = data_api.execute(insert)
        assert result.number_of_records_updated == 1

        query = Query(Pets).filter(Pets.id == 1)
        result = data_api.execute(query)

        assert list(result) == [Record([1, 'dog'], [])]

        result = data_api.execute('select * from pets')
        assert result.one().dict() == {'id': 1, 'name': 'dog'}

        insert: Insert = Insert(Pets)
        data_api.batch_execute(
            insert,
            [
                {'id': 2, 'name': 'cat'},
                {'id': 3, 'name': 'snake'},
                {'id': 4, 'name': 'rabbit'},
            ],
        )

        result = data_api.execute('select * from pets')
        expected = [
            Record([1, 'dog'], ['id', 'name']),
            Record([2, 'cat'], ['id', 'name']),
            Record([3, 'snake'], ['id', 'name']),
            Record([4, 'rabbit'], ['id', 'name']),
        ]
        assert list(result) == expected

        for row, expected_row in zip(result, expected):
            assert row == expected_row
Example #16
    def apply(self, query: Query, value: Any) -> Query:
        filtered_query = query.filter(Database.allow_file_upload)

        datasource_access_databases = can_access_databases("datasource_access")

        if hasattr(g, "user"):
            allowed_schemas = [
                app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"](db, g.user)
                for db in datasource_access_databases
            ]

            if len(allowed_schemas):
                return filtered_query

        return filtered_query.filter(
            or_(
                # isnot(None) is compiled into SQL; a Python "is not None" test
                # on a SQL expression would always evaluate to True here.
                cast(Database.extra,
                     JSON)["schemas_allowed_for_file_upload"].isnot(None),
                cast(Database.extra, JSON)["schemas_allowed_for_file_upload"]
                != [],
            ))
Example #17
def search_query(quotes: Query, search: QuoteSearch, skip=0, limit=100):
    anywhere = search.anywhere
    title = search.title
    text = search.text
    quote_type = search.quote_type
    description = search.description
    author = search.author
    # public = search.public
    color = search.color
    tags = search.tags

    if title:
        quotes = quotes.filter(Quote.title.ilike(title))
    if text:
        quotes = quotes.filter(Quote.text.ilike(text))
    if quote_type:
        quotes = quotes.filter(Quote.type == quote_type)
    if description:
        quotes = quotes.filter(Quote.description.ilike(description))
    if author:
        quotes = quotes.filter(Quote.author.ilike(author))
    if color:
        quotes = quotes.filter(Quote.color == color.as_hex())

    if anywhere:
        logger.info("ANYWHERE called")
        quotes = quotes.filter(
            or_(Quote.title.ilike(anywhere),
                Quote.text.ilike(anywhere),
                Quote.description.ilike(anywhere),
                Quote.author.ilike(anywhere))
        )
    quotes = quotes.offset(skip).limit(limit).all()

    if tags:
        return get_multi_by_tag(quotes, tags)
    else:
        return quotes
Example #18
def test_global_settings_start_time(session):
    if session.bind.name == "postgresql":
        pytest.skip("Can't insert wrong datatype in postgres")
    factories.GlobalSettingsFactory(start_time="18:00:00")
    factories.GlobalSettingsFactory(start_time=None)
    wrong_start_time = factories.GlobalSettingsFactory(
        start_time="asdf18:00:00")

    check_start_time = QueryCheck(
        column=models.GlobalSetting.start_time,
        invalid=Query(models.GlobalSetting).filter(
            func.date(models.GlobalSetting.start_time) == None,
            models.GlobalSetting.start_time != None,
        ),
        message="GlobalSettings.start_time is an invalid, make sure it has the "
        "following format: 'HH:MM:SS'",
    )

    errors = check_start_time.get_invalid(session)
    assert len(errors) == 1
    assert errors[0].id == wrong_start_time.id
Example #19
    def bulk_delete_insert(self, delete_query: Query, rows: Iterable[dict], auto_commit=True) -> Tuple[int, int]:
        # TODO: do not perform transaction handling inside this method
        model = self.model

        try:
            deleted = delete_query.delete()
            inserted = 0

            for row in rows:
                obj = model(**row)
                self.db.add(obj)
                inserted += 1

            if auto_commit:
                self.db.commit()

        except Exception as e:
            self.db.rollback()
            raise

        return deleted, inserted
Example #20
def apply_states_filters(query: Query, start_day: dt, end_day: dt) -> Query:
    """Filter states by time range.

    Filters states that do not have an old state or new state (added / removed)
    Filters states that are in a continuous domain with a UOM.
    Filters states that do not have matching last_updated and last_changed.
    """
    return (
        query.filter(
            (States.last_updated > start_day) & (States.last_updated < end_day)
        )
        .outerjoin(OLD_STATE, (States.old_state_id == OLD_STATE.state_id))
        .where(_missing_state_matcher())
        .where(_not_continuous_entity_matcher())
        .where(
            (States.last_updated == States.last_changed) | States.last_changed.is_(None)
        )
        .outerjoin(
            StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
        )
    )
Example #21
def edit(
    id: str,
    db_engine: Engine,
    title: str = None,
    urgency: int = None,
    importance: int = None,
    tags: List[str] = None,
    anchor_folder: str = None,
    description: str = None,
):
    """Edit a task

    :param id: ID of the task to edit.
    :param db_engine: Engine for the tasks database.
    :param title: Update title of the task.
    :param urgency: Update urgency level[0-4] of the task.
    :param importance: Update importance level[0-4] of the task.
    :param tags: Set of tags to apply to the new task.
    :param anchor_folder: Anchor this task to a particular directory or file.
    :param description: Description of the task.
    """
    task: db.Task
    with db.session_scope(db_engine) as session:
        task = Query(db.Task, session).filter_by(id=id).one()
        if title:
            task.title = title
        if urgency:
            task.urgency = urgency
        if importance:
            task.importance = importance
        if tags:
            task.tags = tags
        if anchor_folder:
            task.folder = anchor_folder
        if description:
            task.description = description
        session.add(task)
Example #22
def delete_test_data(db_query: Query, retry: bool = True) -> int:
    """Delete database objects based on user query. Intended to be used only for integration tests cleanup.

    Args:
        db_query: sqlalchemy Query object that defines the rows to delete
        retry: Attempt retry in case of deadlock errors due to parallel runs of tests

    Returns:
        Number of deleted rows.

    """
    try:
        with Database(DatabaseType.internal).transaction_context() as session:
            return int(
                db_query.with_session(session).delete(
                    synchronize_session=False))  # type: ignore
    except DBAPIError as e:  # pragma: no cover
        if retry:  # Retry in case of deadlock error from the database when running tests in parallel
            logger.warning(f"Retrying error: {e}")
            return delete_test_data(db_query, retry=False)
        else:
            raise
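
A hedged usage sketch (the TestRecord model name is hypothetical): the query is built unbound and only bound to a session inside the transaction context.

deleted_rows = delete_test_data(
    Query(TestRecord).filter(TestRecord.name.like('integration_test_%')))
logger.info(f"Cleaned up {deleted_rows} test rows")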
Example #23
    def __init__(self,
                 model,
                 query=None,
                 _as_relation=None,
                 join_path=None,
                 aliased=None):
        """ Init a MongoDB-style query
        :param model: MongoModel
        :type model: mongosql.MongoModel
        :param query: Query to work with
        :type query: sqlalchemy.orm.Query
        :param _as_relation: Parent relationship.
            Internal argument used when working with deeper relations:
            is used as initial path for defaultload(_as_relation).lazyload(...).
        :type _as_relation: sqlalchemy.orm.relationships.RelationshipProperty
        """
        if query is None:
            query = Query([model.model])
        self.join_hook = None
        self._model = model
        # This workaround exists because using the alias as target_model
        # somehow overrides the Model class: the full test suite fails even
        # though each test passes in isolation. So we create the MongoModel
        # for the alias here instead of via "get_for".
        if aliased:
            self._model = MongoModel(aliased)
        self._query = query
        self.join_path = join_path or ()

        if join_path:
            self._as_relation = defaultload(*join_path)
        else:
            self._as_relation = defaultload(
                _as_relation) if _as_relation else Load(self._model.model)
        self.join_queries = []
        self.skip_or_limit = False
        self._order_by = None
        self._project = {}
        self._end_query = None
Example #24
    def get_structure(self, structure_filter=None):
        """Get structure of scenario by filter.

        :param structure_filter: instance of :class:`ScenarioStructureFilter
            <autostorage.core.scenario.param_spec.ScenarioStructureFilter>`.
        :returns: dictionary with tree_path as keys and node_ids as values.
        """
        subquery = Query([
            ScenarioStructureStateRecord.scenario_id,
            ScenarioStructureStateRecord.tree_path,
            func.max(ScenarioStructureStateRecord.changed).label('newest_change_date')
            ]).filter_by(
                scenario_id=self.__id
            )

        if structure_filter:
            if structure_filter.date:
                subquery = subquery.filter(
                    ScenarioStructureStateRecord.changed <= structure_filter.date)

            if structure_filter.tree_path:
                subquery = subquery.filter(ScenarioStructureStateRecord.tree_path.like(
                    "{0}-%".format(structure_filter.tree_path)
                    ))

        subquery = subquery.group_by(
            ScenarioStructureStateRecord.scenario_id,
            ScenarioStructureStateRecord.tree_path
            ).subquery()

        states_query = Query([
            ScenarioStructureStateRecord
            ]).join(
                subquery,
                and_(
                    ScenarioStructureStateRecord.scenario_id == subquery.columns.scenario_id,
                    ScenarioStructureStateRecord.tree_path == subquery.columns.tree_path,
                    ScenarioStructureStateRecord.changed == subquery.columns.newest_change_date
                )
            ).filter(
                ScenarioStructureStateRecord.enabled == True  # pylint: disable=singleton-comparison
            )

        with self.base.get_session() as session:
            bound_query = states_query.with_session(session)
            return {record.tree_path: record.node_id for record in bound_query}
Example #25
    def test_limit_with_filtered_join(self):
        m = models.User
        mq = m.mongoquery(Query([models.User]))
        mq = mq.query(
            limit=10,
            join={'articles': {
                'filter': {
                    'title': {
                        '$exists': True
                    }
                }
            }})
        q = mq.end()
        qs = q2sql(q)
        self._check_qs(
            """SELECT anon_1.u_id AS anon_1_u_id, anon_1.u_name AS anon_1_u_name, anon_1.u_tags AS anon_1_u_tags, anon_1.u_age AS anon_1_u_age, a_1.id AS a_1_id, a_1.uid AS a_1_uid, a_1.title AS a_1_title, a_1.theme AS a_1_theme, a_1.data AS a_1_data
FROM (SELECT u.id AS u_id, u.name AS u_name, u.tags AS u_tags, u.age AS u_age
FROM u
WHERE EXISTS (SELECT 1
FROM a
WHERE u.id = a.uid AND a.title IS NOT NULL)
 LIMIT 10) AS anon_1 JOIN a AS a_1 ON anon_1.u_id = a_1.uid
WHERE a_1.title IS NOT NULL""", qs)
Example #26
def potentially_limit_query_to_account_assets(
        query: Query, account_id: Optional[int]) -> Query:
    """Filter out all assets that are not in the current user's account.
    For admins and CLI users, no assets are filtered out, unless an account_id is set.

    :param account_id: if set, all assets that are not in the given account will be filtered out (only works for admins and CLI users). For querying public assets in particular, don't use this function.
    """
    if not running_as_cli() and not current_user.is_authenticated:
        raise Forbidden("Unauthenticated user cannot list assets.")
    user_is_admin = (running_as_cli() or current_user.has_role(ADMIN_ROLE)
                     or current_user.has_role(ADMIN_READER_ROLE))
    if account_id is None and user_is_admin:
        return query  # allow admins to query assets across all accounts
    if (account_id is not None and account_id != current_user.account_id
            and not user_is_admin):
        raise Forbidden("Non-admin cannot access assets from other accounts.")
    account_id_to_filter = (account_id if account_id is not None else
                            current_user.account_id)
    return query.filter(
        or_(
            GenericAsset.account_id == account_id_to_filter,
            GenericAsset.account_id == null(),
        ))
Example #27
    def query_constructor_filter_specifiable(
            self,
            transfer: CreditTransfer,
            base_query: Query,
            custom_filter: AggregationFilter) -> Query:
        """
        Constructs a filtered query for aggregation, where the last filter step can be provided by the user
        :param transfer:
        :param base_query:
        :param custom_filter:
        :return: An SQLAlchemy Query Object
        """

        filter_list = combine_filter_lists(
            [
                matching_sender_user_filter(transfer),
                not_rejected_filter(),
                after_time_period_filter(self.time_period_days),
                custom_filter(transfer)
            ]
        )

        return base_query.filter(*filter_list)
Example #28
def filter_str(
    query: Query,
    column: Column,
    negate: bool = False,
    oper: str = "eq",
    value: Optional[str] = None,
) -> Query:
    """ Update the query to filter based on string comparisons for a column """
    if value is None:
        return query

    if oper == "eq":
        expr = column == value
    elif oper == "starts":
        expr = column.startswith(value)
    elif oper == "ends":
        expr = column.endswith(value)
    elif oper == "contains":
        expr = column.contains(value)

    if negate:
        expr = not_(expr)
    return query.filter(expr)
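
A minimal sketch (hypothetical Pet model, SQLAlchemy 1.4+ import paths assumed) showing the WHERE clause each operator produces when passed through filter_str above:

from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import Query, declarative_base

Base = declarative_base()

class Pet(Base):                  # hypothetical model, for illustration only
    __tablename__ = 'pet'
    id = Column(Integer, primary_key=True)
    name = Column(String)

for oper in ('eq', 'starts', 'ends', 'contains'):
    q = filter_str(Query(Pet), Pet.name, oper=oper, value='do')
    print(oper, '->', q.whereclause)
# Roughly: pet.name = :name_1, pet.name LIKE :name_1 || '%',
# pet.name LIKE '%' || :name_1, and pet.name LIKE '%' || :name_1 || '%'.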
Example #29
def common_filter(query: Query, data_schema, start_timestamp=None, end_timestamp=None,
                  filters=None, order=None, limit=None):
    if start_timestamp:
        query = query.filter(data_schema.timestamp >= to_pd_timestamp(start_timestamp))
    if end_timestamp:
        query = query.filter(data_schema.timestamp <= to_pd_timestamp(end_timestamp))

    if filters:
        for filter in filters:
            query = query.filter(filter)
    if order is not None:
        query = query.order_by(order)
    else:
        query = query.order_by(data_schema.timestamp.asc())
    if limit:
        query = query.limit(limit)

    return query
Example #30
    def _filter_query_for_text_contents(
            self, q: Query, taskclass: Type[Task]) -> Optional[Query]:
        """
        Returns the query, filtered for the "text contents" filter.

        Args:
            q: the starting SQLAlchemy ORM Query
            taskclass: the task class

        Returns:
            a Query, potentially modified.
        """
        tf = self._filter  # task filter

        if not tf.text_contents:
            return q  # unmodified

        # task must contain ALL the strings in AT LEAST ONE text column
        textcols = taskclass.get_text_filter_columns()
        if not textcols:
            # Text filtering requested, but there are no text columns, so
            # by definition the filter must fail.
            return None
        clauses_over_text_phrases = []  # type: List[ColumnElement]
        # ... each e.g. "col1 LIKE '%paracetamol%' OR col2 LIKE '%paracetamol%'"  # noqa
        for textfilter in tf.text_contents:
            tf_lower = textfilter.lower()
            clauses_over_columns = []  # type: List[ColumnElement]
            # ... each e.g. "col1 LIKE '%paracetamol%'"
            for textcol in textcols:
                # Case-insensitive comparison:
                # https://groups.google.com/forum/#!topic/sqlalchemy/331XoToT4lk
                # https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/StringComparisonFilter  # noqa
                clauses_over_columns.append(
                    func.lower(textcol).contains(tf_lower, autoescape=True))
            clauses_over_text_phrases.append(or_(*clauses_over_columns))
        return q.filter(and_(*clauses_over_text_phrases))
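
The "ALL phrases in AT LEAST ONE column" rule is an AND of ORs. A stripped-down sketch with a hypothetical two-column table shows the shape of the generated criterion:

from sqlalchemy import Column, Integer, Text, and_, func, or_
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Note(Base):                 # hypothetical task-like table
    __tablename__ = 'note'
    id = Column(Integer, primary_key=True)
    comment = Column(Text)
    detail = Column(Text)

phrases = ['paracetamol', '50 mg']
clause = and_(*[
    or_(*[func.lower(col).contains(phrase.lower(), autoescape=True)
          for col in (Note.comment, Note.detail)])
    for phrase in phrases
])
# (lower(comment) LIKE '%paracetamol%' OR lower(detail) LIKE '%paracetamol%')
# AND (lower(comment) LIKE '%50 mg%' OR lower(detail) LIKE '%50 mg%')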
Example #31
def _apply_devices_context_union(
    query: Query,
    start_day: dt,
    end_day: dt,
    event_types: tuple[str, ...],
    json_quoted_device_ids: list[str],
) -> CompoundSelect:
    """Generate a CTE to find the device context ids and a query to find linked row."""
    devices_cte: CTE = _select_device_id_context_ids_sub_query(
        start_day,
        end_day,
        event_types,
        json_quoted_device_ids,
    ).cte()
    return query.union_all(
        apply_events_context_hints(
            select_events_context_only().select_from(devices_cte).outerjoin(
                Events,
                devices_cte.c.context_id == Events.context_id)).outerjoin(
                    EventData, (Events.data_id == EventData.data_id)),
        apply_states_context_hints(
            select_states_context_only().select_from(devices_cte).outerjoin(
                States, devices_cte.c.context_id == States.context_id)),
    )
Example #32
 def join_to_clusters(base_citation_query: Query) -> Tuple[Query, Alias, Alias]:
     citing_opinion, cited_opinion = aliased(Opinion), aliased(Opinion)
     citing_cluster, cited_cluster = aliased(Cluster), aliased(Cluster)
     return (
         (
             base_citation_query.join(
                 citing_opinion,
                 Citation.citing_opinion_id == citing_opinion.resource_id,
             )
             .join(
                 cited_opinion,
                 Citation.cited_opinion_id == cited_opinion.resource_id,
             )
             .join(
                 citing_cluster,
                 citing_opinion.cluster_id == citing_cluster.resource_id,
             )
             .join(
                 cited_cluster, cited_opinion.cluster_id == cited_cluster.resource_id
             )
         ),
         citing_cluster,
         cited_cluster,
     )
Example #33
 def _filter_unprocessed(
     self,
     env: Environment,
     query: Select,
 ) -> Select:
     if not self._filters.unprocessed_by_node_key:
         return query
     if self._filters.allow_cycle:
         # Only exclude blocks processed as INPUT
         filter_clause = and_(
             DataBlockLog.direction == Direction.INPUT,
             SnapLog.node_key == self._filters.unprocessed_by_node_key,
         )
     else:
         # No block cycles allowed
         # Exclude blocks processed as INPUT and blocks outputted
         filter_clause = SnapLog.node_key == self._filters.unprocessed_by_node_key
     already_processed_drs = (
         Query(DataBlockLog.data_block_id).join(SnapLog).filter(
             filter_clause).filter(
                 DataBlockLog.invalidated == False)  # noqa
         .distinct())
     return query.filter(
         not_(DataBlockMetadata.id.in_(already_processed_drs)))
Example #34
    def __init__(self, query: Query, page_id=1, page_size=20, page_info=None):
        if page_info:
            page_id = page_info.page_id
            page_size = page_info.page_size

        if page_id < 1 or page_size < 1:
            logger.info(
                f'page_id: {page_id}, page_size: {page_size}, page_info: {page_info}'
            )
            raise NotFound()

        total = fast_count(query)
        if total and total > 0:
            items = query.limit(page_size).offset(
                (page_id - 1) * page_size).all()
        else:
            total = 0
            page_id = 1
            items = []

        self.items = items
        self.page_id = page_id
        self.page_size = page_size
        self.total = total
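
The limit/offset arithmetic above in isolation: with a one-based page_id, (page_id - 1) * page_size rows are skipped before the requested page starts.

page_id, page_size = 3, 20
offset = (page_id - 1) * page_size
assert offset == 40    # page 3 starts after the first 40 items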
Example #35
def add_user_source_filter(
    cls: "ts.TimedValue", query: Query, user_source_ids: Union[int, List[int]]
) -> Query:
    """Add filter to the query to search only through user data from the specified user sources.

    We distinguish user sources (sources with source.type == "user") from other sources (source.type != "user").
    Data with a user source originates from a registered user. Data with e.g. a script source originates from a script.

    This filter doesn't affect the query over non-user type sources.
    It does so by ignoring user sources that are not in the given list of source_ids.
    """
    if user_source_ids is not None and not isinstance(user_source_ids, list):
        user_source_ids = [user_source_ids]  # ensure user_source_ids is a list
    if user_source_ids:
        ignorable_user_sources = (
            DataSource.query.filter(DataSource.type == "user")
            .filter(DataSource.id.notin_(user_source_ids))
            .all()
        )
        ignorable_user_source_ids = [
            user_source.id for user_source in ignorable_user_sources
        ]
        query = query.filter(cls.data_source_id.notin_(ignorable_user_source_ids))
    return query
Example #36
def tobs():
    session = Session(engine)

    first_date = dt.datetime.strptime(
        (Query(measurement).with_session(session).order_by(
            measurement.date.desc()).first().date), "%Y-%m-%d")

    year_before = first_date - dt.timedelta(days=365)

    most_active_list = session.query(
        measurement.station, func.count(measurement.station)).group_by(
            measurement.station).order_by(
                func.count(measurement.station).desc()).all()
    most_active = most_active_list[0][0]

    results = session.query(measurement.date, measurement.tobs).filter(
        measurement.station == most_active,
        measurement.date > year_before).all()

    session.close()

    results = list(results)

    return (jsonify(results))
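
The "year before the most recent date" computation above, isolated with an arbitrary example date in the same '%Y-%m-%d' format:

import datetime as dt

first_date = dt.datetime.strptime('2017-08-23', '%Y-%m-%d')
year_before = first_date - dt.timedelta(days=365)
assert year_before == dt.datetime(2016, 8, 23)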
Example #37
def paginated_update(
    query: Query,
    print_page_progress: Optional[Union[Callable[[int, int], None],
                                        bool]] = None,
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> Iterator[Any]:
    """
    Update models in small batches so we don't have to load everything in memory.
    """
    start = 0
    count = query.count()
    session: Session = inspect(query).session
    if print_page_progress is None or print_page_progress is True:
        print_page_progress = lambda current, total: print(
            f"    {current}/{total}", end="\r")
    while start < count:
        end = min(start + batch_size, count)
        for obj in query[start:end]:
            yield obj
            session.merge(obj)
        session.commit()
        if print_page_progress:
            print_page_progress(end, count)
        start += batch_size
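
A hedged usage sketch (Dashboard and migrate_metadata are hypothetical): each yielded object can be mutated in place, and the generator merges and commits batch by batch, so no explicit commit is needed afterwards.

for dashboard in paginated_update(session.query(Dashboard), batch_size=500):
    dashboard.json_metadata = migrate_metadata(dashboard.json_metadata)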
Example #38
def common_filter(
    query: Query,
    data_schema,
    start_timestamp=None,
    end_timestamp=None,
    filters=None,
    order=None,
    limit=None,
    time_field="timestamp",
):
    """
    build filter by the arguments

    :param query: sql query
    :param data_schema: data schema
    :param start_timestamp: start timestamp
    :param end_timestamp: end timestamp
    :param filters: sql filters
    :param order: sql order
    :param limit: sql limit size
    :param time_field: time field in columns
    :return: result query
    """
    assert data_schema is not None
    time_col = eval("data_schema.{}".format(time_field))

    if start_timestamp:
        query = query.filter(time_col >= to_pd_timestamp(start_timestamp))
    if end_timestamp:
        query = query.filter(time_col <= to_pd_timestamp(end_timestamp))

    if filters:
        for filter in filters:
            query = query.filter(filter)
    if order is not None:
        query = query.order_by(order)
    else:
        query = query.order_by(time_col.asc())
    if limit:
        query = query.limit(limit)

    return query
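
The eval() used to resolve the time column could be replaced by getattr, which avoids executing an arbitrary string; a minimal equivalent, assuming time_field names an existing column attribute:

time_col = getattr(data_schema, time_field)   # e.g. data_schema.timestamp; raises AttributeError if missing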
Example #39
    def test_3_queries(self):
        session = self.get_session()
        p = session.query(User).first()
        print(f"{p.name}")
        all_rec = [rec for rec in session.query(Channel).all()]
        self.assertEqual(len(all_rec), 3)
        #
        # Check the Join table is there
        #
        all_user = [rec for rec in session.query(UserChannel).all()]
        self.assertEqual(len(all_user), 3)

        # Query with no Join ... Bad result
        q = Query([User, Channel, UserChannel], session=session)
        self.assertEqual(q.count(), 18)

        q = Query([User, Channel, UserChannel], session=session). \
            filter(User.user_id == UserChannel.user_id). \
            filter(Channel.channel_id == UserChannel.channel_id)
        self.assertEqual(q.count(), 3)
Example #40
 def basic_query(session, block):
     query = Query([Flight], session)
     query = block(query)
     return query.filter(Flight.f_from > datetime.datetime.now() + datetime.timedelta(days=1)) \
         .order_by(Flight.f_price) \
         .all()
Example #41
 def already_exists(self, session) -> bool:
     return Query(func.count(Flight.id),
                  session=session).filter(Flight.f_id == self.f_id).filter(
                      Flight.f_price <= self.f_price).first()[0] == 1
Example #42
def delete_customizations_by_order_id(order_id: UUID):
    with db_session() as session:
        Query(OrderCustomization,
              session=session).filter_by(order_id=order_id).delete()
Example #43
 def order_query(self, query: Query):
     """Sort the query."""
     return query.order_by(Asset.name)
Example #44
###############################################################################
#                         ORM-level SQL construction
#     Query is the source of all SELECT statements generated by the ORM. It
# features a generative interface whereby successive calls return a new Query
# object, a copy of the former with additional criteria and options associated
# with it.
#     Query objects are normally generated initially via the query() method
# of Session.
#     query() takes a variable number of arguments: any combination of mapped
# classes, Mapper objects, orm-enabled descriptors, or AliasedClass objects.
###############################################################################

# basic query #############################################

# select
query = Query(User)

# insert
session.add(User(id=1))

# aggregation
(Query(User.department,
       func.sum(User.salary).label('salary'),
       func.count('*').label('total_number'))
 .group_by(User.department))

# case
Query(User.id,
      case([(User.salary < 1000, 1),
            (User.salary.between(1000,2000), 2)],
           else_=0).label('salary'))
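
The generative behaviour described in the comment block, in miniature (reusing the User model this example already assumes): each call returns a new Query and leaves the previous one untouched.

base = Query(User)
low_paid = base.filter(User.salary < 1000)
ordered = low_paid.order_by(User.id)
assert base is not low_paid and low_paid is not ordered   # three distinct Query objects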
Example #45
def apply_entities_hints(query: Query) -> Query:
    """Force mysql to use the right index on large selects."""
    return query.with_hint(States,
                           f"FORCE INDEX ({ENTITY_ID_LAST_UPDATED_INDEX})",
                           dialect_name="mysql")
Example #46
def get_size_by_id(size_id: UUID) -> Union[Dict, None]:
    with db_session() as session:
        # one_or_none() returns None when no row matches; one() would raise
        # NoResultFound, so the "not size" check below would never run.
        size: Size = Query(Size, session=session).filter_by(id=size_id).one_or_none()
        if not size:
            return None
        return size.to_dict()
Example #47
from sqlalchemy import Column, String, case, func
from sqlalchemy.orm import Query
from sqlalchemy.sql.functions import coalesce
from util import Base, session


class PrinterControl(Base):
    __tablename__ = "printercontrol"

    user_id = Column(String(10))
    printer_name = Column(String(4), nullable=False, primary_key=True)
    printer_description = Column(String(40), nullable=False)

user_name = "leea"
user_printer = Query([PrinterControl]).filter(PrinterControl.user_id == user_name)
s = session.query(PrinterControl).filter(case(
    [(user_printer.exists(), PrinterControl.user_id == user_name)],
    else_=(PrinterControl.user_id == None))
)

[print(s.printer_name) for s in s]


# Exploit the fact that an aggregate function returns NULL when no rows match
anonymous_printer = Query([func.min(PrinterControl.printer_name)])\
    .filter(PrinterControl.user_id == None).as_scalar()

s = session.query(coalesce(func.min(PrinterControl.printer_name), anonymous_printer))\
    .filter(PrinterControl.user_id == user_name)
Example #48
def get_order(order_id: UUID) -> Dict:
    with db_session() as session:
        order: Order = Query(Order,
                             session=session).filter_by(id=order_id).one()
        return order.to_dict()
Example #49
 def get_count(self, query: Query):
     """Calculate total item count based on query."""
     return query.count()