Example No. 1
import pytest

def test_group_concat_sqlite_one_arg(db_session):
    """
    It should use SQLite's default arguments (comma delimiter)
    """
    from sqlalchemy import literal_column
    from occams_datastore.utils.sql import group_concat

    if db_session.bind.url.drivername != 'sqlite':
        pytest.skip('Not using SQLite')

    data = (
        db_session.query(
            literal_column("'myitem'").label('name'),
            literal_column("'foo'").label('value'))
        .union(
            db_session.query(
                literal_column("'myitem'").label('name'),
                literal_column("'bar'").label('value')))
        .subquery())

    query = (
        db_session.query(group_concat(data.c.value))
        .select_from(data)
        .group_by(data.c.name))

    result, = query.one()
    assert sorted(['foo', 'bar']) == sorted(result.split(','))
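
For orientation, a dialect-aware group_concat in the spirit of the occams_datastore helper used above could be built on SQLAlchemy's compiler extension. The real occams_datastore.utils.sql.group_concat is not shown in these examples, so everything below is a hypothetical sketch, not the actual implementation:

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction

class group_concat(GenericFunction):
    """Hypothetical stand-in for occams_datastore.utils.sql.group_concat."""
    name = 'group_concat'

@compiles(group_concat, 'sqlite')
def _group_concat_sqlite(element, compiler, **kw):
    # SQLite's GROUP_CONCAT defaults to a comma delimiter, so a single
    # argument is enough (which is what Example No. 1 relies on).
    return 'GROUP_CONCAT(%s)' % compiler.process(element.clauses, **kw)

@compiles(group_concat, 'postgresql')
def _group_concat_postgresql(element, compiler, **kw):
    # PostgreSQL's STRING_AGG has no default delimiter, so reject a single
    # argument (the behavior Example No. 5 tests for).
    if len(list(element.clauses)) < 2:
        raise TypeError('group_concat requires a delimiter on PostgreSQL')
    return 'STRING_AGG(%s)' % compiler.process(element.clauses, **kw)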
Example No. 2
    def test_tuple_containment(self):

        for test, exp in [
            ([("a", "b")], True),
            ([("a", "c")], False),
            ([("f", "q"), ("a", "b")], True),
            ([("f", "q"), ("a", "c")], False),
        ]:
            eq_(
                testing.db.execute(
                    select(
                        [
                            tuple_(
                                literal_column("'a'"), literal_column("'b'")
                            ).in_(
                                [
                                    tuple_(
                                        *[
                                            literal_column("'%s'" % letter)
                                            for letter in elem
                                        ]
                                    )
                                    for elem in test
                                ]
                            )
                        ]
                    )
                ).scalar(),
                exp,
            )
Example No. 3
    def test_tuple_containment(self):

        for test, exp in [
            ([('a', 'b')], True),
            ([('a', 'c')], False),
            ([('f', 'q'), ('a', 'b')], True),
            ([('f', 'q'), ('a', 'c')], False)
        ]:
            eq_(
                testing.db.execute(
                    select([
                        tuple_(
                            literal_column("'a'"),
                            literal_column("'b'")
                        ).
                        in_([
                            tuple_(*[
                                literal_column("'%s'" % letter)
                                for letter in elem
                            ]) for elem in test
                        ])
                    ])
                ).scalar(),
                exp
            )
Example No. 4
    def test_row_case_sensitive_unoptimized(self):
        ins_db = engines.testing_engine(options={"case_sensitive": True})
        row = ins_db.execute(
            select([
                literal_column("1").label("case_insensitive"),
                literal_column("2").label("CaseSensitive"),
                text("3 AS screw_up_the_cols")
            ])
        ).first()

        eq_(
            list(row.keys()),
            ["case_insensitive", "CaseSensitive", "screw_up_the_cols"])

        in_("case_insensitive", row._keymap)
        in_("CaseSensitive", row._keymap)
        not_in_("casesensitive", row._keymap)

        eq_(row["case_insensitive"], 1)
        eq_(row["CaseSensitive"], 2)
        eq_(row["screw_up_the_cols"], 3)

        assert_raises(KeyError, lambda: row["Case_insensitive"])
        assert_raises(KeyError, lambda: row["casesensitive"])
        assert_raises(KeyError, lambda: row["screw_UP_the_cols"])
Example No. 5
import pytest

def test_group_concat_postgresql_invalid_args(db_session):
    """
    It should require at least two arguments when using PostgreSQL
    """
    from sqlalchemy import literal_column
    from occams_datastore.utils.sql import group_concat

    if db_session.bind.url.drivername != 'postgresql':
        pytest.skip('Not using PostgreSQL')

    data = (
        db_session.query(
            literal_column("'myitem'").label('name'),
            literal_column("'foo'").label('value'))
        .union(
            db_session.query(
                literal_column("'myitem'").label('name'),
                literal_column("'bar'").label('value')))
        .subquery())

    query = (
        db_session.query(group_concat(data.c.value))
        .select_from(data)
        .group_by(data.c.name))

    with pytest.raises(TypeError):
        result, = query.one()
Example No. 6
    def test_percent_sign_round_trip(self):
        """test that the DBAPI accommodates for escaped / nonescaped
        percent signs in a way that matches the compiler

        """
        m = self.metadata
        t = Table('t', m, Column('data', String(50)))
        t.create(config.db)
        with config.db.begin() as conn:
            conn.execute(t.insert(), dict(data="some % value"))
            conn.execute(t.insert(), dict(data="some %% other value"))

            eq_(
                conn.scalar(
                    select([t.c.data]).where(
                        t.c.data == literal_column("'some % value'"))
                ),
                "some % value"
            )

            eq_(
                conn.scalar(
                    select([t.c.data]).where(
                        t.c.data == literal_column("'some %% other value'"))
                ), "some %% other value"
            )
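
The point being tested: text inside literal_column passes through the compiler untouched, so any percent signs in it are the caller's responsibility, while values sent as bound parameters are escaped according to the DBAPI paramstyle. A sketch on the default (named-parameter) dialect:

from sqlalchemy import column, literal_column

print(column('data') == 'some % value')
# data = :data_1            (value travels as a bound parameter)
print(column('data') == literal_column("'some % value'"))
# data = 'some % value'     (rendered verbatim into the SQL string)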
Example No. 7
    def test_row_case_sensitive(self):
        row = testing.db.execute(
            select([
                literal_column("1").label("case_insensitive"),
                literal_column("2").label("CaseSensitive")
            ])
        ).first()

        eq_(list(row.keys()), ["case_insensitive", "CaseSensitive"])

        in_("case_insensitive", row._keymap)
        in_("CaseSensitive", row._keymap)
        not_in_("casesensitive", row._keymap)

        eq_(row["case_insensitive"], 1)
        eq_(row["CaseSensitive"], 2)

        assert_raises(
            KeyError,
            lambda: row["Case_insensitive"]
        )
        assert_raises(
            KeyError,
            lambda: row["casesensitive"]
        )
Example No. 8
    def test_text_doesnt_explode(self):

        for s in [
            select(
                [
                    case(
                        [
                            (
                                info_table.c.info == 'pk_4_data',
                                text("'yes'"))],
                        else_=text("'no'"))
                ]).order_by(info_table.c.info),

            select(
                [
                    case(
                        [
                            (
                                info_table.c.info == 'pk_4_data',
                                literal_column("'yes'"))],
                        else_=literal_column("'no'")
                    )]
            ).order_by(info_table.c.info),

        ]:
            if testing.against("firebird"):
                eq_(s.execute().fetchall(), [
                    ('no ', ), ('no ', ), ('no ', ), ('yes', ),
                    ('no ', ), ('no ', ),
                ])
            else:
                eq_(s.execute().fetchall(), [
                    ('no', ), ('no', ), ('no', ), ('yes', ),
                    ('no', ), ('no', ),
                ])
Example No. 9
    def execute(self, connection, filter_values):
        max_date_query = sqlalchemy.select([
            sqlalchemy.func.max(sqlalchemy.column('completed_on')).label('completed_on'),
            sqlalchemy.column('case_id').label('case_id')
        ]).select_from(sqlalchemy.table(self.table_name))

        if self.filters:
            for filter in self.filters:
                max_date_query.append_whereclause(filter.build_expression())

        max_date_query.append_group_by(
            sqlalchemy.column('case_id')
        )

        max_date_subquery = sqlalchemy.alias(max_date_query, 'max_date')

        asha_table = self.get_asha_table_name()
        checklist_query = sqlalchemy.select()
        for column in self.columns:
            checklist_query.append_column(column.build_column())

        checklist_query = checklist_query.where(
            sqlalchemy.literal_column('"{}".case_id'.format(asha_table)) == max_date_subquery.c.case_id
        ).where(
            sqlalchemy.literal_column('"{}".completed_on'.format(asha_table)) == max_date_subquery.c.completed_on
        ).select_from(sqlalchemy.table(asha_table))

        return connection.execute(checklist_query, **filter_values).fetchall()
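
Here literal_column serves a different purpose: referencing a column on a table that is only known by name at runtime, with the quoting written out by hand. A tiny standalone sketch (the table name is illustrative):

import sqlalchemy

ref = sqlalchemy.literal_column('"some_asha_table".case_id')
print(ref == sqlalchemy.column('case_id'))
# "some_asha_table".case_id = case_id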
Example No. 10
 def on_insert_also(self):
     Foo2 = sql.t.Foo2
     n_column = sqlalchemy.literal_column('1').label('n')
     return (Foo2.insert().values(n=sqlalchemy.literal_column('new.id'),
                                  foo=sqlalchemy.literal_column('new.description')),
             sql.InsertFromSelect(Foo2, sqlalchemy.select([n_column])),
             "select 42",
             )
Example No. 11
 def test_select_composition_seven(self):
     self.assert_compile(
         select([
             literal_column('col1'),
             literal_column('col2')
         ], from_obj=table('tablename')).alias('myalias'),
         "SELECT col1, col2 FROM tablename"
     )
Example No. 12
 def test_select_composition_seven(self):
     self.assert_compile(
         select(
             [literal_column("col1"), literal_column("col2")],
             from_obj=table("tablename"),
         ).alias("myalias"),
         "SELECT col1, col2 FROM tablename",
     )
Example No. 13
 def _update_current_rev(self, old, new):
     if old == new:
         return
     if new is None:
         self.impl._exec(self._version.delete())
     elif old is None:
         self.impl._exec(self._version.insert().values(version_num=literal_column("'%s'" % new)))
     else:
         self.impl._exec(self._version.update().values(version_num=literal_column("'%s'" % new)))
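
In both branches, literal_column renders the version string inline instead of as a bound parameter, which keeps the emitted SQL self-contained (handy for offline migration output). A standalone sketch with an illustrative table and version string:

from sqlalchemy import table, column, literal_column

version = table('alembic_version', column('version_num'))
stmt = version.update().values(version_num=literal_column("'abc123'"))
print(stmt)  # UPDATE alembic_version SET version_num='abc123'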
Example No. 14
    def data(self,
             use_choice_labels=False,
             expand_collections=False,
             ignore_private=True):
        session = self.db_session
        query = (
            session.query(
                models.Patient.id.label('id'),
                models.Site.name.label('site'),
                models.Patient.pid.label('pid'))
            .join(models.Site))

        # BBB 2014-02-20 (Marco): AEH needs Early Test
        EarlyTest = aliased(models.Enrollment)
        subquery = (
            session.query(EarlyTest.patient_id, EarlyTest.reference_number)
            .filter(EarlyTest.study.has(
                models.Study.code.in_([literal_column("'ET'"),
                                       literal_column("'LTW'"),
                                       literal_column("'CVCT'")])))
            .subquery())
        query = (
            query
            .outerjoin(subquery, subquery.c.patient_id == models.Patient.id)
            .add_column(subquery.c.reference_number.label('early_id')))

        # Add every known reference number
        for reftype in self.reftypes:
            query = query.add_column(
                session.query(
                    group_concat(
                        models.PatientReference.reference_number, ';'))
                .filter(
                    models.PatientReference.patient_id == models.Patient.id)
                .filter(
                    models.PatientReference.reference_type_id == reftype.id)
                .group_by(models.PatientReference.patient_id)
                .correlate(models.Patient)
                .as_scalar()
                .label(reftype.name))

        CreateUser = aliased(datastore.User)
        ModifyUser = aliased(datastore.User)

        query = (
            query
            .join(CreateUser, models.Patient.create_user)
            .join(ModifyUser, models.Patient.modify_user)
            .add_columns(
                models.Patient.create_date,
                CreateUser.key.label('create_user'),
                models.Patient.modify_date,
                ModifyUser.key.label('modify_user'))
            .order_by(models.Patient.id))

        return query
Example No. 15
def search_query(cls, tokens, weight_func=None, include_misses=False, ordered=True):

    # Read the searchable columns from the table (strings)
    columns = cls.__searchable_columns__

    # Convert the columns from strings into column objects
    columns = [getattr(cls, c) for c in columns]

    # The model name that can be used to match search result to model
    cls_name = literal_column("'{}'".format(cls.__name__))

    # Filter out id: tokens for later
    ids, tokens = process_id_option(tokens)

    # If there are still tokens left after id: token filtering
    if tokens:
        # Generate the search weight expression from the
        # searchable columns, tokens and patterns
        if not weight_func:
            weight_func = weight_expression

        weight = weight_func(columns, tokens)

    # If the search expression only included "special" tokens like id:
    else:
        weight = literal_column("1")

    # Create an array of stringified detail columns
    details = getattr(cls, "__search_detail_columns__", None)
    if details:
        details = [cast(getattr(cls, d), Unicode) for d in details]
    else:
        details = [literal_column("NULL")]

    # Create a query object
    query = db.session.query(
        cls_name.label("model"),
        cls.id.label("id"),
        cls.name.label("name"),
        array(details).label("details"),
        weight.label("weight"),
    )

    # Filter out specific ids (optional)
    if ids:
        query = query.filter(cls.id.in_(ids))

    # Filter out results that don't match the patterns at all (optional)
    if not include_misses:
        query = query.filter(weight > 0)

    # Order by weight (optional)
    if ordered:
        query = query.order_by(desc(weight))

    return query
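
The cls_name column is what lets rows from several models be merged (for example via a UNION) and still be attributed to their source model. A minimal sketch of that tagging pattern, with illustrative names (whitespace in the printed SQL trimmed):

from sqlalchemy import column, literal_column, select, table

users = table('users', column('id'), column('name'))
tag = literal_column("'User'").label('model')
print(select([tag, users.c.id, users.c.name]))
# SELECT 'User' AS model, users.id, users.name FROM users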
Example No. 16
    def timeseries(self, agg_unit, start, end, geom=None, column_filters=None):
        # Reading this blog post
        # http://no0p.github.io/postgresql/2014/05/08/timeseries-tips-pg.html
        # inspired this implementation.
        t = self.point_table

        # Special case for the 'quarter' unit of aggregation.
        step = '3 months' if agg_unit == 'quarter' else '1 ' + agg_unit

        # Create a CTE to represent every time bucket in the timeseries
        # with a default count of 0
        day_generator = func.generate_series(func.date_trunc(agg_unit, start),
                                             func.date_trunc(agg_unit, end),
                                             step)
        defaults = select([sa.literal_column("0").label('count'),
                           day_generator.label('time_bucket')])\
            .alias('defaults')

        where_filters = [t.c.point_date >= start, t.c.point_date <= end]
        if column_filters is not None:
            # Column filters has to be iterable here, because the '+' operator
            # behaves differently for SQLAlchemy conditions. Instead of
            # combining the conditions together, it would try to build
            # something like :param1 + <column_filters> as a new condition.
            where_filters += [column_filters]

        # Create a CTE that grabs the number of records contained in each time
        # bucket. Will only have rows for buckets with records.
        actuals = select([func.count(t.c.hash).label('count'),
                          func.date_trunc(agg_unit, t.c.point_date).
                         label('time_bucket')])\
            .where(sa.and_(*where_filters))\
            .group_by('time_bucket')

        # Also filter by geometry if requested
        if geom:
            contains = func.ST_Within(t.c.geom, func.ST_GeomFromGeoJSON(geom))
            actuals = actuals.where(contains)

        # Need to alias to make it usable in a subexpression
        actuals = actuals.alias('actuals')

        # Outer join the default and observed values
        # to create the timeseries select statement.
        # If no observed value in a bucket, use the default.
        name = sa.literal_column("'{}'".format(self.dataset_name))\
            .label('dataset_name')
        bucket = defaults.c.time_bucket.label('time_bucket')
        count = func.coalesce(actuals.c.count, defaults.c.count).label('count')
        ts = select([name, bucket, count]).\
            select_from(defaults.outerjoin(actuals, actuals.c.time_bucket == defaults.c.time_bucket))

        return ts
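
The core of the method is the zero-fill pattern: a generated calendar of buckets outer-joined to the observed counts, with COALESCE picking the real count when one exists. A pared-down sketch of the same idea (the table, columns, and date range are illustrative, and generate_series assumes PostgreSQL):

import sqlalchemy as sa
from sqlalchemy import func, select

events = sa.table('events', sa.column('point_date'))

# One row per month with a default count of 0.
defaults = select([
    sa.literal_column("0").label('count'),
    func.generate_series(sa.text("'2020-01-01'"), sa.text("'2020-06-01'"),
                         '1 month').label('time_bucket'),
]).alias('defaults')

# Real counts, only for months that actually have rows.
actuals = select([
    func.count().label('count'),
    func.date_trunc('month', events.c.point_date).label('time_bucket'),
]).group_by('time_bucket').alias('actuals')

# Outer join and coalesce: observed count when present, otherwise 0.
ts = select([
    defaults.c.time_bucket,
    func.coalesce(actuals.c.count, defaults.c.count).label('count'),
]).select_from(defaults.outerjoin(
    actuals, actuals.c.time_bucket == defaults.c.time_bucket))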
Example No. 17
def get_next_bus(mc, db, stop_id):
    now_datetime = datetime.datetime.utcnow().replace(tzinfo=UTC)
    now_datetime = LONDON.normalize(now_datetime.astimezone(LONDON))
    now_day = now_datetime.weekday()
    now_time = now_datetime.time()

    mc_key = "V6:USERSTOP:" +  str(stop_id) + "USERTIME:" + now_datetime.strftime("%w%H%M")

    bus = mc.get(mc_key)
    if not bus:
        today_query = db.query(DepartureTimeDeref, literal_column("0").label("days_future")).\
                                      filter_by(bus_stop_id=stop_id).\
                                      filter(DepartureTimeDeref.time >= now_time).\
                                      filter(DepartureTimeDeref.valid_days.contains(cast([DAYS[now_day]], postgresql.ARRAY(String)))).\
                                      join(DepartureTimeDeref.timetable).\
                                      filter(Timetable.valid_from <= now_datetime.date()).\
                                      filter(Timetable.valid_to >= now_datetime.date())

        single_day_delta = datetime.timedelta(days=1)
        tomorrow_query = db.query(DepartureTimeDeref, literal_column("1").label("days_future")).\
                                      filter_by(bus_stop_id=stop_id).\
                                      filter(DepartureTimeDeref.valid_days.contains(cast([DAYS[(now_day + 1) % 7]], postgresql.ARRAY(String)))).\
                                      join(DepartureTimeDeref.timetable).\
                                      filter(Timetable.valid_from <= now_datetime.date() + single_day_delta).\
                                      filter(Timetable.valid_to >= now_datetime.date() + single_day_delta)


        bus = today_query.union_all(tomorrow_query).\
                            options(joinedload(DepartureTimeDeref.timetable, Timetable.route)).\
                            order_by("days_future").\
                            order_by(DepartureTimeDeref.time).\
                            first()

        if bus:
            bus = {'departure': bus[0].to_JSON(), 'days_future': int(bus[1])}

        mc.set(mc_key, bus, 30*24*60*60)

    if bus:
        departure = bus['departure']
        # Create a day delta, is this departure time today or tomorrow?
        day_delta = datetime.timedelta(days=bus['days_future'])
        departure_day = now_datetime.date() + day_delta

        departure_dt = datetime.datetime.combine(departure_day, departure['time'])
        # Add timezone information. pytz will handle DST correctly
        departure_dt = LONDON.localize(departure_dt)

        departure['time'] = departure_dt

        bus = departure

    return bus
Example No. 18
    def timeseries(self, agg_unit, start, end, geom=None, column_filters=None):
        # Reading this blog post
        # http://no0p.github.io/postgresql/2014/05/08/timeseries-tips-pg.html
        # inspired this implementation.
        t = self.point_table

        if agg_unit == 'quarter':
            step = '3 months'
        else:
            step = '1 ' + agg_unit
        # Create a CTE to represent every time bucket in the timeseries
        # with a default count of 0
        day_generator = func.generate_series(func.date_trunc(agg_unit, start),
                                             func.date_trunc(agg_unit, end),
                                             step)
        defaults = select([sa.literal_column("0").label('count'),
                           day_generator.label('time_bucket')])\
            .alias('defaults')

        # Create a CTE that grabs the number of records
        # contained in each time bucket.
        # Will only have rows for buckets with records.
        where_filters = [t.c.point_date >= start,
                         t.c.point_date <= end]
        if column_filters:
            where_filters += column_filters

        actuals = select([func.count(t.c.hash).label('count'),
                          func.date_trunc(agg_unit, t.c.point_date).
                         label('time_bucket')])\
            .where(sa.and_(*where_filters))\
            .group_by('time_bucket')

        # Also filter by geometry if requested
        if geom:
            contains = func.ST_Within(t.c.geom, func.ST_GeomFromGeoJSON(geom))
            actuals = actuals.where(contains)

        # Need to alias to make it usable in a subexpression
        actuals = actuals.alias('actuals')

        # Outer join the default and observed values
        # to create the timeseries select statement.
        # If no observed value in a bucket, use the default.
        name = sa.literal_column("'{}'".format(self.dataset_name))\
            .label('dataset_name')
        bucket = defaults.c.time_bucket.label('time_bucket')
        count = func.coalesce(actuals.c.count, defaults.c.count).label('count')
        ts = select([name, bucket, count]).\
            select_from(defaults.outerjoin(actuals, actuals.c.time_bucket == defaults.c.time_bucket))

        return ts
Example No. 19
 def _update_current_rev(self, old, new):
     if old == new:  # pragma: no cover
         return
     if new is None:  # pragma: no cover
         self.impl._exec(Version.__table__.delete().where(package=self.pkg_name))
     elif old is None:
         self.impl._exec(
             Version.__table__.insert().values(package=self.pkg_name, version_num=sqla.literal_column("'%s'" % new))
         )
     else:
         self.impl._exec(
             Version.__table__.update().values(package=self.pkg_name, version_num=sqla.literal_column("'%s'" % new))
         )
Example No. 20
 def test_select_composition_six(self):
     # test that "auto-labeling of subquery columns"
     # doesn't interfere with literal columns,
     # and that exported columns don't get quoted
     self.assert_compile(
         select([
             literal_column("column1 AS foobar"),
             literal_column("column2 AS hoho"), table1.c.myid],
             from_obj=[table1]).select(),
         "SELECT column1 AS foobar, column2 AS hoho, myid FROM "
         "(SELECT column1 AS foobar, column2 AS hoho, "
         "mytable.myid AS myid FROM mytable)"
     )
Example No. 21
 def test_select_composition_one(self):
     self.assert_compile(select(
         [
             literal_column("foobar(a)"),
             literal_column("pk_foo_bar(syslaal)")
         ],
         text("a = 12"),
         from_obj=[
             text("foobar left outer join lala on foobar.foo = lala.foo")
         ]
     ),
         "SELECT foobar(a), pk_foo_bar(syslaal) FROM foobar "
         "left outer join lala on foobar.foo = lala.foo WHERE a = 12"
     )
Example No. 22
    def test_query_one(self):
        q = Session.query(User).\
                filter(User.name == 'ed').\
                    options(joinedload(User.addresses))

        q2 = serializer.loads(
                    serializer.dumps(q, -1),
                            users.metadata, Session)
        def go():
            eq_(q2.all(), [
                    User(name='ed', addresses=[Address(id=2),
                    Address(id=3), Address(id=4)])])

        self.assert_sql_count(testing.db, go, 1)

        eq_(q2.join(User.addresses).filter(Address.email
            == '*****@*****.**').value(func.count(literal_column('*'))), 1)
        u1 = Session.query(User).get(8)
        q = Session.query(Address).filter(Address.user
                == u1).order_by(desc(Address.email))
        q2 = serializer.loads(serializer.dumps(q, -1), users.metadata,
                              Session)
        eq_(q2.all(), [Address(email='*****@*****.**'),
            Address(email='*****@*****.**'),
            Address(email='*****@*****.**')])
Example No. 23
 def test_endswith_literal_mysql(self):
     self.assert_compile(
         column('x').endswith(literal_column('y')),
         "x LIKE concat('%%', y)",
         checkparams={},
         dialect=mysql.dialect()
     )
Example No. 24
def migrate_claims(migrate_engine, metadata, buildrequests, objects,
                   buildrequest_claims):

    # First, ensure there is an object row for each master
    null_id = sa.null().label('id')
    if migrate_engine.dialect.name == 'postgresql':
        # postgres needs NULL cast to an integer:
        null_id = sa.cast(null_id, sa.INTEGER)
    new_objects = sa.select([
        null_id,
        buildrequests.c.claimed_by_name.label("name"),
        sa.literal_column("'BuildMaster'").label("class_name"),
    ],
        whereclause=buildrequests.c.claimed_by_name != NULL,
        distinct=True)

    # this doesn't seem to work without str() -- verified in sqla 0.6.0 - 0.7.1
    migrate_engine.execute(
        str(sautils.InsertFromSelect(objects, new_objects)))

    # now make a buildrequest_claims row for each claimed build request
    join = buildrequests.join(objects,
                              (buildrequests.c.claimed_by_name == objects.c.name)
                              # (have to use sa.text because str, below, doesn't work
                              # with placeholders)
                              & (objects.c.class_name == sa.text("'BuildMaster'")))
    claims = sa.select([
        buildrequests.c.id.label('brid'),
        objects.c.id.label('objectid'),
        buildrequests.c.claimed_at,
    ], from_obj=[join],
        whereclause=buildrequests.c.claimed_by_name != NULL)
    migrate_engine.execute(
        str(sautils.InsertFromSelect(buildrequest_claims, claims)))
Example No. 25
 def test_startswith_literal_mysql(self):
     self.assert_compile(
         column("x").startswith(literal_column("y")),
         "x LIKE concat(y, '%%')",
         checkparams={},
         dialect=mysql.dialect(),
     )
Example No. 26
    def list_notifications(self, tenant_id, sort_by, offset, limit):

        rows = []

        with self._db_engine.connect() as conn:
            nm = self.nm

            select_nm_query = (select([nm])
                               .where(nm.c.tenant_id == bindparam('b_tenant_id')))

            parms = {'b_tenant_id': tenant_id}

            if sort_by is not None:
                order_columns = [literal_column(col) for col in sort_by]
                if 'id' not in sort_by:
                    order_columns.append(nm.c.id)
            else:
                order_columns = [nm.c.id]

            select_nm_query = select_nm_query.order_by(*order_columns)

            select_nm_query = (select_nm_query
                               .order_by(nm.c.id)
                               .limit(bindparam('b_limit')))

            parms['b_limit'] = limit + 1

            if offset:
                select_nm_query = select_nm_query.offset(bindparam('b_offset'))
                parms['b_offset'] = offset

            rows = conn.execute(select_nm_query, parms).fetchall()

        return [dict(row) for row in rows]
Example No. 27
def update_market_history():
    session = DBSession()
    with transaction.manager:
        last_tick = session.query(func.max(MarketHistoryElement.ticks)).scalar()
        current_date = datetime.datetime.now(tzlocal())
        current_tick = int((current_date - MarketHistoryElement.START_ERA).total_seconds() / MarketHistoryElement.TICK_SECONDS_LENGTH)
        assert last_tick <= current_tick
        if last_tick == current_tick:
            logger.debug("Skipping update to market history: tick %d already saved.", current_tick)
            return
        origin_select = session.\
            query(Item.data_id,
                  literal_column(str(current_tick)),
                  Item.buy_count,
                  Item.buy_price,
                  Item.sell_count,
                  Item.sell_price).\
            filter(Item.buy_count > 0, Item.sell_count > 0)
        i = insert(MarketHistoryElement).from_select([
            MarketHistoryElement.item_id,
            MarketHistoryElement.ticks,
            MarketHistoryElement.buy_count,
            MarketHistoryElement.buy_price,
            MarketHistoryElement.sell_count,
            MarketHistoryElement.sell_price
        ], origin_select)
        logger.debug("Executing market history insert...")
        i.execute()
        logger.debug("Saved market data for tick %d.", current_tick)
Example No. 28
def migrate_claims(migrate_engine, metadata, buildrequests, objects, buildrequest_claims):

    # First, ensure there is an object row for each master
    new_objects = sa.select(
        [
            sa.null().label("id"),
            buildrequests.c.claimed_by_name.label("name"),
            sa.literal_column("'BuildMaster'").label("class_name"),
        ],
        whereclause=buildrequests.c.claimed_by_name != None,
        distinct=True,
    )

    # this doesn't seem to work without str() -- verified in sqla 0.6.0 - 0.7.1
    migrate_engine.execute(str(sautils.InsertFromSelect(objects, new_objects)))

    # now make a buildrequest_claims row for each claimed build request
    join = buildrequests.join(
        objects,
        (buildrequests.c.claimed_by_name == objects.c.name)
        # (have to use sa.text because str, below, doesn't work
        # with placeholders)
        & (objects.c.class_name == sa.text("'BuildMaster'")),
    )
    claims = sa.select(
        [buildrequests.c.id.label("brid"), objects.c.id.label("objectid"), buildrequests.c.claimed_at],
        from_obj=[join],
        whereclause=buildrequests.c.claimed_by_name != None,
    )
    migrate_engine.execute(str(sautils.InsertFromSelect(buildrequest_claims, claims)))
Example No. 29
 def test_contains_literal_concat(self):
     self.assert_compile(
         column('x').contains(literal_column('y')),
         "x LIKE concat(concat('%%', y), '%%')",
         checkparams={},
         dialect=mysql.dialect()
     )
Example No. 30
def copr_cleanup(session, older_than):
    session.log.debug('Cleaning up old copr projects')
    interval = "now() - '{} month'::interval".format(older_than)
    to_delete_ids = session.db.query(CoprRebuildRequest.id)\
        .filter(CoprRebuildRequest.timestamp < literal_column(interval))\
        .all_flat()
    if to_delete_ids:
        rebuilds = session.db.query(CoprRebuild)\
            .filter(CoprRebuild.request_id.in_(to_delete_ids))\
            .filter(CoprRebuild.copr_build_id != None)\
            .all()
        for rebuild in rebuilds:
            try:
                copr_client.delete_project(
                    username=get_config('copr.copr_owner'),
                    projectname=rebuild.copr_name,
                )
            except CoprException as e:
                if 'does not exist' not in str(e):
                    session.log.warn("Cannot delete copr project {}: {}"
                                     .format(rebuild.copr_name, e))
                    if rebuild.request_id in to_delete_ids:
                        to_delete_ids.remove(rebuild.request_id)
    if to_delete_ids:
        session.db.query(CoprRebuildRequest)\
            .filter(CoprRebuildRequest.id.in_(to_delete_ids))\
            .delete()
        for request_id in to_delete_ids:
            shutil.rmtree(get_request_cachedir(request_id), ignore_errors=True)
        session.log_user_action(
            "Cleanup: Deleted {} copr requests"
            .format(len(to_delete_ids))
        )
Example No. 31
 def test_startswith_sqlexpr(self):
     col = self.tables.some_table.c.data
     self._test(col.startswith(literal_column("'ab%c'")),
                {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
Example No. 32
 def test_startswith_literal(self):
     self.assert_compile(column('x').startswith(literal_column('y')),
                         "x LIKE y || '%%'",
                         checkparams={})
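
The contrast with the plain-string form is the whole point of these compile tests; a sketch on the default (named-parameter) dialect, where percent signs are not doubled:

from sqlalchemy import column, literal_column

print(column('x').startswith('y'))
# x LIKE :x_1 || '%'   -- a plain string becomes a bound parameter
print(column('x').startswith(literal_column('y')))
# x LIKE y || '%'      -- literal_column is treated as raw SQL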
Example No. 33
 def test_endswith_literal(self):
     self.assert_compile(column('x').endswith(literal_column('y')),
                         "x LIKE '%%' || y",
                         checkparams={})
Example No. 34
 def test_pg_time_expression_literal_1y_grain(self):
     col = literal_column("COALESCE(a, b)")
     expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
     result = str(expr.compile(dialect=postgresql.dialect()))
     self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")
Example No. 35
 def _modify_query(query):
     m1(query.column_descriptions[0]["entity"])
     query = query.enable_assertions(False).filter(
         literal_column("1") == 1
     )
     return query
Example No. 36
    def elevation(self, oid, segments=None, **params):
        if segments is not None and segments.isdigit():
            segments = int(segments)
            if segments > 500 or segments <= 0:
                segments = 500
        else:
            segments = 100

        ret = OrderedDict()
        ret['id'] = oid

        r = cherrypy.request.app.config['DB']['map'].tables.routes.data
        gen = sa.select([sa.func.generate_series(0, segments).label('i')
                         ]).alias()
        field = sa.func.ST_LineInterpolatePoint(r.c.geom,
                                                gen.c.i / float(segments))
        field = sa.func.ST_Collect(field)

        sel = sa.select([field]).where(r.c.id == oid)\
                .where(r.c.geom.ST_GeometryType() == 'ST_LineString')

        res = cherrypy.request.db.execute(sel).first()

        if res is not None and res[0] is not None:
            geom = to_shape(res[0])
            xcoord, ycoord = zip(*((p.x, p.y) for p in geom))
            geomlen = LineString(geom).length
            pos = [geomlen * i / float(segments) for i in range(segments)]
            compute_elevation(((xcoord, ycoord, pos), ), geom.bounds, ret)
            return ret

        # special treatment for multilinestrings
        sel = sa.select([r.c.geom,
                         sa.literal_column("""ST_Length2dSpheroid(ST_MakeLine(ARRAY[ST_Points(ST_Transform(geom,4326))]),
                             'SPHEROID[\"WGS 84\",6378137,298.257223563,AUTHORITY["EPSG",\"7030\"]]')"""),
                         r.c.geom.ST_NPoints()])\
                .where(r.c.id == oid)

        res = cherrypy.request.db.execute(sel).first()

        if res is not None and res[0] is not None:
            geom = to_shape(res[0])

            if res[2] > 10000:
                geom = geom.simplify(res[2] / 500, preserve_topology=False)
            elif res[2] > 4000:
                geom = geom.simplify(res[2] / 1000, preserve_topology=False)

            segments = []

            for seg in geom:
                p = seg.coords[0]
                xcoords = array('d', [p[0]])
                ycoords = array('d', [p[1]])
                pos = array('d')
                if segments:
                    prev = segments[-1]
                    pos.append(prev[2][-1] + \
                            Point(prev[0][-1], prev[1][-1]).distance(Point(*p)))
                else:
                    pos.append(0.0)
                for p in seg.coords[1:]:
                    pos.append(
                        pos[-1] +
                        Point(xcoords[-1], ycoords[-1]).distance(Point(*p)))
                    xcoords.append(p[0])
                    ycoords.append(p[1])

                segments.append((xcoords, ycoords, pos))

            compute_elevation(segments, geom.bounds, ret)

            ret['length'] = float(res[1])
            return ret

        raise cherrypy.NotFound()
Example No. 37
 def test_contains_literal_concat(self):
     self.assert_compile(column('x').contains(literal_column('y')),
                         "x LIKE concat(concat('%%', y), '%%')",
                         checkparams={},
                         dialect=mysql.dialect())
Example No. 38
 def test_endswith_sqlexpr(self):
     col = self.tables.some_table.c.data
     self._test(col.endswith(literal_column("'e%fg'")),
                {1, 2, 3, 4, 5, 6, 7, 8, 9})
Example No. 39
    def test_insert_w_newlines(self, connection):
        from psycopg2 import extras

        t = self.tables.data

        ins = (t.insert().inline().values(
            id=bindparam("id"),
            x=select(literal_column("5")).select_from(
                self.tables.data).scalar_subquery(),
            y=bindparam("y"),
            z=bindparam("z"),
        ))
        # compiled SQL has a newline in it
        eq_(
            str(ins.compile(testing.db)),
            "INSERT INTO data (id, x, y, z) VALUES (%(id)s, "
            "(SELECT 5 \nFROM data), %(y)s, %(z)s)",
        )
        meth = extras.execute_values
        with mock.patch.object(extras, "execute_values",
                               side_effect=meth) as mock_exec:

            connection.execute(
                ins,
                [
                    {
                        "id": 1,
                        "y": "y1",
                        "z": 1
                    },
                    {
                        "id": 2,
                        "y": "y2",
                        "z": 2
                    },
                    {
                        "id": 3,
                        "y": "y3",
                        "z": 3
                    },
                ],
            )

        eq_(
            mock_exec.mock_calls,
            [
                mock.call(
                    mock.ANY,
                    "INSERT INTO data (id, x, y, z) VALUES %s",
                    (
                        {
                            "id": 1,
                            "y": "y1",
                            "z": 1
                        },
                        {
                            "id": 2,
                            "y": "y2",
                            "z": 2
                        },
                        {
                            "id": 3,
                            "y": "y3",
                            "z": 3
                        },
                    ),
                    template="(%(id)s, (SELECT 5 \nFROM data), %(y)s, %(z)s)",
                    fetch=False,
                    page_size=connection.dialect.executemany_values_page_size,
                )
            ],
        )
Example No. 40
    def define_tables(cls, metadata):
        default_generator = cls.default_generator = {"x": 50}

        def mydefault():
            default_generator["x"] += 1
            return default_generator["x"]

        def myupdate_with_ctx(ctx):
            conn = ctx.connection
            return conn.execute(sa.select(sa.text("13"))).scalar()

        def mydefault_using_connection(ctx):
            conn = ctx.connection
            return conn.execute(sa.select(sa.text("12"))).scalar()

        use_function_defaults = testing.against("postgresql", "mssql")
        is_oracle = testing.against("oracle")

        class MyClass(object):
            @classmethod
            def gen_default(cls, ctx):
                return "hi"

        class MyType(TypeDecorator):
            impl = String(50)

            def process_bind_param(self, value, dialect):
                if value is not None:
                    value = "BIND" + value
                return value

        cls.f = 6
        cls.f2 = 11
        with testing.db.connect() as conn:
            currenttime = cls.currenttime = func.current_date(type_=sa.Date)
            if is_oracle:
                ts = conn.scalar(
                    sa.select(
                        func.trunc(
                            func.current_timestamp(),
                            sa.literal_column("'DAY'"),
                            type_=sa.Date,
                        )
                    )
                )
                currenttime = cls.currenttime = func.trunc(
                    currenttime, sa.literal_column("'DAY'"), type_=sa.Date
                )
                def1 = currenttime
                def2 = func.trunc(
                    sa.text("current_timestamp"),
                    sa.literal_column("'DAY'"),
                    type_=sa.Date,
                )

                deftype = sa.Date
            elif use_function_defaults:
                def1 = currenttime
                deftype = sa.Date
                if testing.against("mssql"):
                    def2 = sa.text("getdate()")
                else:
                    def2 = sa.text("current_date")
                ts = conn.scalar(func.current_date())
            else:
                def1 = def2 = "3"
                ts = 3
                deftype = Integer

            cls.ts = ts

        Table(
            "default_test",
            metadata,
            # python function
            Column("col1", Integer, primary_key=True, default=mydefault),
            # python literal
            Column(
                "col2",
                String(20),
                default="imthedefault",
                onupdate="im the update",
            ),
            # preexecute expression
            Column(
                "col3",
                Integer,
                default=func.length("abcdef"),
                onupdate=func.length("abcdefghijk"),
            ),
            # SQL-side default from sql expression
            Column("col4", deftype, server_default=def1),
            # SQL-side default from literal expression
            Column("col5", deftype, server_default=def2),
            # preexecute + update timestamp
            Column("col6", sa.Date, default=currenttime, onupdate=currenttime),
            Column("boolcol1", sa.Boolean, default=True),
            Column("boolcol2", sa.Boolean, default=False),
            # python function which uses ExecutionContext
            Column(
                "col7",
                Integer,
                default=mydefault_using_connection,
                onupdate=myupdate_with_ctx,
            ),
            # python builtin
            Column(
                "col8",
                sa.Date,
                default=datetime.date.today,
                onupdate=datetime.date.today,
            ),
            # combo
            Column("col9", String(20), default="py", server_default="ddl"),
            # python method w/ context
            Column("col10", String(20), default=MyClass.gen_default),
            # fixed default w/ type that has bound processor
            Column("col11", MyType(), default="foo"),
        )
Example No. 41
 def test_select_composition_seven(self):
     self.assert_compile(
         select([literal_column('col1'),
                 literal_column('col2')],
                from_obj=table('tablename')).alias('myalias'),
         "SELECT col1, col2 FROM tablename")
Example No. 42
 def test_contains_literal(self):
     self.assert_compile(column('x').contains(literal_column('y')),
                         "x LIKE '%%' || y || '%%'",
                         checkparams={})
Example No. 43
def connection_block(field: ASTNode, parent_name: typing.Optional[str]) -> Alias:
    return_type = field.return_type
    sqla_model = return_type.sqla_model

    block_name = secure_random_string()
    if parent_name is None:
        join_conditions = [True]
    else:
        join_conditions = to_join_clause(field, parent_name)

    filter_conditions = to_conditions_clause(field)
    limit = to_limit(field)
    has_total = check_has_total(field)

    is_page_after = "after" in field.args
    is_page_before = "before" in field.args

    totalCount_alias = field.get_subfield_alias(["totalCount"])

    edges_alias = field.get_subfield_alias(["edges"])
    node_alias = field.get_subfield_alias(["edges", "node"])
    cursor_alias = field.get_subfield_alias(["edges", "cursor"])

    pageInfo_alias = field.get_subfield_alias(["pageInfo"])
    hasNextPage_alias = field.get_subfield_alias(["pageInfo", "hasNextPage"])
    hasPreviousPage_alias = field.get_subfield_alias(["pageInfo", "hasPreviousPage"])
    startCursor_alias = field.get_subfield_alias(["pageInfo", "startCursor"])
    endCursor_alias = field.get_subfield_alias(["pageInfo", "endCursor"])

    # Apply Filters
    core_model = sqla_model.__table__
    core_model_ref = (
        select(core_model.c)
        .select_from(core_model)
        .where(
            and_(
                # Join clause
                *join_conditions,
                # Conditions
                *filter_conditions,
            )
        )
    ).alias(block_name)

    new_edge_node_selects = []
    new_relation_selects = []

    for subfield in get_edge_node_fields(field):
        # Does anything other than NodeID go here?
        if subfield.return_type == ID:
            # elem = select([to_node_id_sql(sqla_model, core_model_ref)]).label(subfield.alias)
            elem = to_node_id_sql(sqla_model, core_model_ref).label(subfield.alias)
            new_edge_node_selects.append(elem)
        elif isinstance(subfield.return_type, (ScalarType, CompositeType, EnumType)):
            col_name = field_name_to_column(sqla_model, subfield.name).name
            elem = core_model_ref.c[col_name].label(subfield.alias)
            new_edge_node_selects.append(elem)
        else:
            elem = build_relationship(subfield, block_name)
            new_relation_selects.append(elem)

    # Setup Pagination
    args = field.args
    after_cursor = args.get("after", None)
    before_cursor = args.get("before", None)
    first = args.get("first", None)
    last = args.get("last", None)

    if first is not None and last is not None:
        raise ValueError('only one of "first" and "last" may be provided')

    if after_cursor or before_cursor:
        local_table_name = get_table_name(field.return_type.sqla_model)
        cursor_table_name = before_cursor.table_name if before_cursor else after_cursor.table_name
        cursor_values = before_cursor.values if before_cursor else after_cursor.values

        if after_cursor is not None and before_cursor is not None:
            raise ValueError('only one of "before" and "after" may be provided')

        if after_cursor is not None and last is not None:
            raise ValueError('"after" is not compatible with "last". Use "first"')

        if before_cursor is not None and first is not None:
            raise ValueError('"before" is not compatible with "first". Use "last"')

        if cursor_table_name != local_table_name:
            raise ValueError("Invalid cursor for entity type")

        pkey_cols = get_primary_key_columns(sqla_model)

        pagination_clause = tuple_(*[core_model_ref.c[col.name] for col in pkey_cols]).op(
            ">" if after_cursor is not None else "<"
        )(tuple_(*[cursor_values[col.name] for col in pkey_cols]))
    else:
        pagination_clause = True

    order_clause = [asc(core_model_ref.c[col.name]) for col in get_primary_key_columns(sqla_model)]
    reverse_order_clause = [desc(core_model_ref.c[col.name]) for col in get_primary_key_columns(sqla_model)]

    total_block = (
        select([func.count(ONE).label("total_count")]).select_from(core_model_ref.alias()).where(has_total)
    ).alias(block_name + "_total")

    node_id_sql = to_node_id_sql(sqla_model, core_model_ref)
    cursor_sql = to_cursor_sql(sqla_model, core_model_ref)

    # Select the right stuff
    p1_block = (
        select(
            [
                *new_edge_node_selects,
                *new_relation_selects,
                # For internal Use
                node_id_sql.label("_nodeId"),
                cursor_sql.label("_cursor"),
                # For internal Use
                func.row_number().over().label("_row_num"),
            ]
        )
        .select_from(core_model_ref)
        .where(pagination_clause)
        .order_by(*(reverse_order_clause if is_page_before else order_clause), *order_clause)
        .limit(cast(limit + 1, Integer()))
    ).alias(block_name + "_p1")

    # Drop maybe extra row
    p2_block = (
        select([*p1_block.c, (func.max(p1_block.c._row_num).over() > limit).label("_has_next_page")])
        .select_from(p1_block)
        .limit(limit)
    ).alias(block_name + "_p2")

    ordering = desc(literal_column("_row_num")) if is_page_before else asc(literal_column("_row_num"))

    p3_block = (select(p2_block.c).select_from(p2_block).order_by(ordering)).alias(block_name + "_p3")

    final = (
        select(
            [
                func.jsonb_build_object(
                    literal_string(totalCount_alias),
                    func.coalesce(func.min(total_block.c.total_count), ZERO) if has_total else None,
                    literal_string(pageInfo_alias),
                    func.jsonb_build_object(
                        literal_string(hasNextPage_alias),
                        func.coalesce(func.array_agg(p3_block.c._has_next_page)[ONE], FALSE),
                        literal_string(hasPreviousPage_alias),
                        TRUE if is_page_after else FALSE,
                        literal_string(startCursor_alias),
                        func.array_agg(p3_block.c._nodeId)[ONE],
                        literal_string(endCursor_alias),
                        func.array_agg(p3_block.c._nodeId)[func.array_upper(func.array_agg(p3_block.c._nodeId), ONE)],
                    ),
                    literal_string(edges_alias),
                    func.coalesce(
                        func.jsonb_agg(
                            func.jsonb_build_object(
                                literal_string(cursor_alias),
                                p3_block.c._nodeId,
                                literal_string(node_alias),
                                func.cast(func.row_to_json(literal_column(p3_block.name)), JSONB()),
                            )
                        ),
                        func.cast(literal("[]"), JSONB()),
                    ),
                ).label("ret_json")
            ]
        )
        .select_from(p3_block)
        .select_from(total_block if has_total else select([1]).alias())
    ).alias()

    return final
Example No. 44
    def test_insert_modified_by_event(self, connection):
        from psycopg2 import extras

        t = self.tables.data

        ins = (
            t.insert()
            .inline()
            .values(
                id=bindparam("id"),
                x=select(literal_column("5"))
                .select_from(self.tables.data)
                .scalar_subquery(),
                y=bindparam("y"),
                z=bindparam("z"),
            )
        )
        # compiled SQL has a newline in it
        eq_(
            str(ins.compile(testing.db)),
            "INSERT INTO data (id, x, y, z) VALUES (%(id)s, "
            "(SELECT 5 \nFROM data), %(y)s, %(z)s)",
        )
        meth = extras.execute_batch
        with mock.patch.object(
            extras, "execute_values"
        ) as mock_values, mock.patch.object(
            extras, "execute_batch", side_effect=meth
        ) as mock_batch:

            # create an event hook that will change the statement to
            # something else, meaning the dialect has to detect that
            # insert_single_values_expr is no longer useful
            @event.listens_for(
                connection, "before_cursor_execute", retval=True
            )
            def before_cursor_execute(
                conn, cursor, statement, parameters, context, executemany
            ):
                statement = (
                    "INSERT INTO data (id, y, z) VALUES "
                    "(%(id)s, %(y)s, %(z)s)"
                )
                return statement, parameters

            connection.execute(
                ins,
                [
                    {"id": 1, "y": "y1", "z": 1},
                    {"id": 2, "y": "y2", "z": 2},
                    {"id": 3, "y": "y3", "z": 3},
                ],
            )

        eq_(mock_values.mock_calls, [])

        if connection.dialect.executemany_mode & EXECUTEMANY_BATCH:
            eq_(
                mock_batch.mock_calls,
                [
                    mock.call(
                        mock.ANY,
                        "INSERT INTO data (id, y, z) VALUES "
                        "(%(id)s, %(y)s, %(z)s)",
                        (
                            {"id": 1, "y": "y1", "z": 1},
                            {"id": 2, "y": "y2", "z": 2},
                            {"id": 3, "y": "y3", "z": 3},
                        ),
                    )
                ],
            )
        else:
            eq_(mock_batch.mock_calls, [])
Example No. 45
def check_has_total(field: ASTNode) -> bool:
    "Check if 'totalCount' is requested in the query result set"
    return any(x.name == "totalCount" for x in field.fields)


def get_edge_node_fields(field):
    """Returns connection.edge.node fields"""
    for cfield in field.fields:
        if cfield.name == "edges":
            for edge_field in cfield.fields:
                if edge_field.name == "node":
                    return edge_field.fields
    return []


ONE = literal_column("1")
ZERO = literal_column("0")
TRUE = literal_column("true")
FALSE = literal_column("false")


def connection_block(field: ASTNode, parent_name: typing.Optional[str]) -> Alias:
    return_type = field.return_type
    sqla_model = return_type.sqla_model

    block_name = secure_random_string()
    if parent_name is None:
        join_conditions = [True]
    else:
        join_conditions = to_join_clause(field, parent_name)
Example No. 46
    def _insert_version(self, version):
        assert version not in self.heads
        self.heads.add(version)

        self.context.impl._exec(self.context._version.insert().values(
            version_num=literal_column("'%s'" % version)))
Example No. 47
    def setup_class(cls):
        global t, f, f2, ts, currenttime, metadata, default_generator

        db = testing.db
        metadata = MetaData(db)
        default_generator = {'x': 50}

        def mydefault():
            default_generator['x'] += 1
            return default_generator['x']

        def myupdate_with_ctx(ctx):
            conn = ctx.connection
            return conn.execute(sa.select([sa.text('13')])).scalar()

        def mydefault_using_connection(ctx):
            conn = ctx.connection
            try:
                return conn.execute(sa.select([sa.text('12')])).scalar()
            finally:
                # ensure a "close()" on this connection does nothing,
                # since it's a "branched" connection
                conn.close()

        use_function_defaults = testing.against('postgresql', 'mssql')
        is_oracle = testing.against('oracle')

        class MyClass(object):
            @classmethod
            def gen_default(cls, ctx):
                return "hi"

        # select "count(1)" returns different results on different DBs also
        # correct for "current_date" compatible as column default, value
        # differences
        currenttime = func.current_date(type_=sa.Date, bind=db)
        if is_oracle:
            ts = db.scalar(
                sa.select([
                    func.trunc(func.sysdate(),
                               sa.literal_column("'DAY'"),
                               type_=sa.Date).label('today')
                ]))
            assert isinstance(
                ts, datetime.date) and not isinstance(ts, datetime.datetime)
            f = sa.select([func.length('abcdef')], bind=db).scalar()
            f2 = sa.select([func.length('abcdefghijk')], bind=db).scalar()
            # TODO: engine propagation across nested functions not working
            currenttime = func.trunc(currenttime,
                                     sa.literal_column("'DAY'"),
                                     bind=db,
                                     type_=sa.Date)
            def1 = currenttime
            def2 = func.trunc(sa.text("sysdate"),
                              sa.literal_column("'DAY'"),
                              type_=sa.Date)

            deftype = sa.Date
        elif use_function_defaults:
            f = sa.select([func.length('abcdef')], bind=db).scalar()
            f2 = sa.select([func.length('abcdefghijk')], bind=db).scalar()
            def1 = currenttime
            deftype = sa.Date
            if testing.against('mssql'):
                def2 = sa.text("getdate()")
            else:
                def2 = sa.text("current_date")
            ts = db.scalar(func.current_date())
        else:
            f = len('abcdef')
            f2 = len('abcdefghijk')
            def1 = def2 = "3"
            ts = 3
            deftype = Integer

        t = Table(
            'default_test1',
            metadata,
            # python function
            Column('col1', Integer, primary_key=True, default=mydefault),

            # python literal
            Column('col2',
                   String(20),
                   default="imthedefault",
                   onupdate="im the update"),

            # preexecute expression
            Column('col3',
                   Integer,
                   default=func.length('abcdef'),
                   onupdate=func.length('abcdefghijk')),

            # SQL-side default from sql expression
            Column('col4', deftype, server_default=def1),

            # SQL-side default from literal expression
            Column('col5', deftype, server_default=def2),

            # preexecute + update timestamp
            Column('col6', sa.Date, default=currenttime, onupdate=currenttime),
            Column('boolcol1', sa.Boolean, default=True),
            Column('boolcol2', sa.Boolean, default=False),

            # python function which uses ExecutionContext
            Column('col7',
                   Integer,
                   default=mydefault_using_connection,
                   onupdate=myupdate_with_ctx),

            # python builtin
            Column('col8',
                   sa.Date,
                   default=datetime.date.today,
                   onupdate=datetime.date.today),
            # combo
            Column('col9', String(20), default='py', server_default='ddl'),

            # python method w/ context
            Column('col10', String(20), default=MyClass.gen_default))

        t.create()
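
The table above mixes client-side defaults (computed by SQLAlchemy at INSERT time) with server_default values (rendered into the CREATE TABLE DDL); a minimal sketch of that distinction, with an illustrative table name:

import sqlalchemy as sa

metadata = sa.MetaData()
demo = sa.Table(
    "default_demo", metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    # client-side: SQLAlchemy supplies 'py' in the INSERT statement itself
    sa.Column("py_side", sa.String(20), default="py"),
    # server-side: emitted as DEFAULT 'ddl' in the CREATE TABLE statement
    sa.Column("db_side", sa.String(20), server_default="ddl"),
)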
Example #48
0
def literal_string(text):
    # Render ``text`` as a single-quoted SQL string literal.
    # Note: single quotes inside ``text`` are not escaped.
    return literal_column(f"'{text}'")
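
A quick usage sketch of the helper above, in the 1.x select style used throughout these examples:

import sqlalchemy as sa
from sqlalchemy import literal_column

def literal_string(text):
    return literal_column(f"'{text}'")

stmt = sa.select([literal_string("hello").label("greeting")])
print(stmt)  # SELECT 'hello' AS greeting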
Example #49
0
        ops.Power: fixed_arity(sa.func._ibis_sqlite_power, 2),
        ops.Exp: unary(sa.func._ibis_sqlite_exp),
        ops.Ln: unary(sa.func._ibis_sqlite_ln),
        ops.Log: _log,
        ops.Log10: unary(sa.func._ibis_sqlite_log10),
        ops.Log2: unary(sa.func._ibis_sqlite_log2),
        ops.Floor: unary(sa.func._ibis_sqlite_floor),
        ops.Ceil: unary(sa.func._ibis_sqlite_ceil),
        ops.Sign: unary(sa.func._ibis_sqlite_sign),
        ops.FloorDivide: fixed_arity(sa.func._ibis_sqlite_floordiv, 2),
        ops.Modulus: fixed_arity(sa.func._ibis_sqlite_mod, 2),
        ops.Variance: _variance_reduction('_ibis_sqlite_var'),
        ops.StandardDev: toolz.compose(
            sa.func._ibis_sqlite_sqrt, _variance_reduction('_ibis_sqlite_var')
        ),
        ops.RowID: lambda t, expr: sa.literal_column('rowid'),
    }
)


def add_operation(op, translation_func):
    _operation_registry[op] = translation_func


class SQLiteExprTranslator(alch.AlchemyExprTranslator):

    _registry = _operation_registry
    _rewrites = alch.AlchemyExprTranslator._rewrites.copy()
    _type_map = alch.AlchemyExprTranslator._type_map.copy()
    _type_map.update({dt.Double: sa.types.REAL, dt.Float: sa.types.REAL})
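
add_operation simply swaps an entry into the shared registry; a sketch of registering one more translation the same way (routing ops.Abs to an _ibis_sqlite_abs UDF is an illustrative assumption, not part of the registry shown):

# hypothetical: translate ibis's Abs via a custom SQLite UDF
add_operation(ops.Abs, unary(sa.func._ibis_sqlite_abs))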
Example #50
0
        conn.scalar(raw_sql)  # $ SPURIOUS: getSql=raw_sql
    except sqlalchemy.exc.ObjectNotExecutableError:
        pass
    scalar_result = conn.scalar(text_sql)  # $ getSql=text_sql
    assert scalar_result == "FOO"
    scalar_result = conn.scalar(statement=text_sql)  # $ getSql=text_sql
    assert scalar_result == "FOO"

    # This is a contrived example
    select = sqlalchemy.select(
        sqlalchemy.text("'BAR'"))  # $ constructedSql="'BAR'"
    result = conn.execute(select)  # $ getSql=select
    assert result.fetchall() == [("BAR", )]

    # This is a contrived example
    select = sqlalchemy.select(sqlalchemy.literal_column("'BAZ'"))
    result = conn.execute(select)  # $ getSql=select
    assert result.fetchall() == [("BAZ", )]

with future_engine.connect() as conn:
    result = conn.execute(text_sql)  # $ getSql=text_sql
    assert result.fetchall() == [("FOO", )]

# `begin` returns a new Connection object with a transaction begun.
print("v2.0 engine.begin")
with engine.begin() as conn:
    result = conn.execute(text_sql)  # $ getSql=text_sql
    assert result.fetchall() == [("FOO", )]

# construction by object
conn = sqlalchemy.future.Connection(engine)
Example #51
0
    def test_endswith_literal_mysql(self):
        self.assert_compile(
            column('x').endswith(literal_column('y')),
            "x LIKE concat('%%', y)",
            checkparams={},
            dialect=mysql.dialect())
Example #52
0
async def test_operations_on_group_classifiers(
    pg_engine: Engine, classifiers_bundle: Dict
):
    # NOTE: mostly for TDD
    async with pg_engine.acquire() as conn:

        # creates a group
        stmt = (
            groups.insert()
            .values(**random_group(name="MyGroup"))
            .returning(groups.c.gid)
        )
        gid = await conn.scalar(stmt)

        # adds classifiers to a group
        stmt = (
            group_classifiers.insert()
            .values(bundle=classifiers_bundle, gid=gid)
            .returning(literal_column("*"))
        )
        result = await conn.execute(stmt)
        row = await result.first()

        assert row
        assert row[group_classifiers.c.gid] == gid
        assert row[group_classifiers.c.bundle] == classifiers_bundle

        # get bundle in one query
        bundle = await conn.scalar(
            sa.select([group_classifiers.c.bundle]).where(
                group_classifiers.c.gid == gid
            )
        )
        assert bundle
        assert classifiers_bundle == bundle

        # Cannot add more than one classifier's bundle to the same group
        # pylint: disable=no-member
        with pytest.raises(psycopg2.errors.UniqueViolation):
            await conn.execute(group_classifiers.insert().values(bundle={}, gid=gid))

        # deleting a group deletes the classifier
        await conn.execute(groups.delete().where(groups.c.gid == gid))

        groups_count = await conn.scalar(sa.select([func.count(groups.c.gid)]))
        classifiers_count = await conn.scalar(
            sa.select([func.count()]).select_from(group_classifiers)
        )

        assert (
            groups_count == 1
        ), "There should be only the Everyone group in the database!"
        assert classifiers_count <= groups_count
        assert classifiers_count == 0

        # no bundle
        bundle = await conn.scalar(
            sa.select([group_classifiers.c.bundle]).where(
                group_classifiers.c.gid == gid
            )
        )
        assert bundle is None
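
The .returning(literal_column("*")) call above asks PostgreSQL to send back the entire inserted row; a standalone sketch of the pattern (table and columns are illustrative):

import sqlalchemy as sa
from sqlalchemy import literal_column

metadata = sa.MetaData()
groups = sa.Table(
    "groups", metadata,
    sa.Column("gid", sa.Integer, primary_key=True),
    sa.Column("name", sa.String),
)

# RETURNING * yields every column of the row just inserted
stmt = groups.insert().values(name="MyGroup").returning(literal_column("*"))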
Example #53
0
import pytest
import sqlalchemy as sa

from db_mock import DBMockSync
from db_mock.utils import _normalize

QUERIES = [
    "select * from users",
    sa.select([sa.literal_column("*")]).select_from(sa.table("users")),
]


@pytest.fixture(params=QUERIES)
def query(request):
    return request.param


def test_normalize(query):
    import sqlalchemy.engine.default as default

    assert _normalize(query, default.DefaultDialect()) == ("""SELECT *
FROM users""")


def test_blocking(query):
    mock = (DBMockSync().expect_stmt(query).expect_returns([
        (1, "User 1"), (2, "User 2")
    ]).expect_stmt(
        "update users set password = '******' where id = 1;").expect_rowcount(
            1).expect_stmt("select count(*) from users").expect_returns([
                (5, )
Example #54
0
def create_aux_pfamA_pfamseqid():
    from sqlalchemy.orm import aliased

    tablename = "pfamjoinpfamseq"
    m = MetaData(engine)
    if not engine.dialect.has_table(engine, tablename):
        # t.drop(engine) # to delete/drop table
        t = Table(tablename, m,
                  Column('pfamseq_id', String(40), primary_key=True),
                  Column('pfamA_acc', String(7), primary_key=True, index=True),
                  Column('pfamseq_acc', String(40)),
                  Column('has_pdb', Boolean, unique=False, default=False))
        t.create(engine)

        pfams = db.query(PfamA.pfamA_acc).all()
        for pf in pfams:
            # for pf in ["PF00451", "PF00452", "PF00453"]:
            print(pf[0])
            pfam = pf[0]
            # pfam = pf

            query = db.query(
                concat(Pfamseq.pfamseq_id, '/',
                       cast(PfamARegFullSignificant.seq_start,
                            types.Unicode), '-',
                       cast(PfamARegFullSignificant.seq_end,
                            types.Unicode)), PfamARegFullSignificant.pfamA_acc,
                PfamARegFullSignificant.pfamseq_acc)

            query = query.join(
                PfamARegFullSignificant,
                and_(
                    PfamARegFullSignificant.in_full == 1,
                    Pfamseq.pfamseq_acc == PfamARegFullSignificant.pfamseq_acc,
                    PfamARegFullSignificant.pfamA_acc == pfam))

            ### add column with pdb
            subquery2 = db.query(PdbPfamAReg.pfamseq_acc)
            subquery2 = subquery2.filter(
                PdbPfamAReg.pfamA_acc == pfam).distinct().subquery()
            query_pdb = query.filter(
                PfamARegFullSignificant.pfamseq_acc == subquery2.c.pfamseq_acc)
            subquery_pdb = query_pdb.subquery()

            query = query.filter(
                PfamARegFullSignificant.pfamseq_acc.notin_(subquery2))

            # print query.add_columns(literal_column("0").label("has_pdb")).distinct().all()
            # print "################"
            # print "################"
            # print query_pdb.add_columns(literal_column("1").label("has_pdb")).distinct().all()

            query = query.add_columns(
                literal_column("0").label("has_pdb")).distinct()
            query_pdb = query_pdb.add_columns(
                literal_column("1").label("has_pdb")).distinct()
            query_union = query.union(query_pdb)
            if query_union.first():
                # values() takes a list of rows for a multi-row VALUES insert
                engine.execute(t.insert().values(query_union.all()))
        return ["Successfully created pfamjoinpfamseq"]

    else:
        pfamjoinpfamseq = Table(tablename, m, autoload=True)
        mapper(Pfamjoinpfamseq, pfamjoinpfamseq)
        return ["pfamjoinpfamseq table exists, creating mapper"]
Example #55
0
def read_all(
    isActive=None,
    isFavourite=None,
    isSandbox=None,
    namesonly=None,
    page=None,
    page_size=None,
    sort=None,
):
    """
    This function responds to a request for /api/solutions
    with the complete lists of solutions

    :return:        json string of list of solutions
    """
    logger.debug("solution.read_all")
    logger.debug(
        "Parameters: isActive: %s, isFavourite: %s, isSandbox: %s, namesonly: %s, page: %s, page_size: %s, sort: %s",
        isActive,
        isFavourite,
        isSandbox,
        namesonly,
        page,
        page_size,
        sort,
    )
    with db_session() as dbs:
        # pre-process sort instructions
        if sort is None:
            solution_query = dbs.query(Solution).order_by(Solution.id)
        else:
            try:
                sort_inst = [si.split(":") for si in sort]
                orderby_arr = []
                for si in sort_inst:
                    si1 = si[0]
                    if len(si) > 1:
                        si2 = si[1]
                    else:
                        si2 = "asc"
                    orderby_arr.append(f"{si1} {si2}")
                # print("orderby: {}".format(orderby_arr))
                solution_query = dbs.query(Solution).order_by(
                    literal_column(", ".join(orderby_arr)))
            except SQLAlchemyError as e:
                logger.warning("Exception: %s", e)
                solution_query = dbs.query(Solution).order_by(Solution.id)

        # Create the list of solutions from our data
        solution_query = solution_query.filter(
            (isActive is None or Solution.isActive == isActive),
            (isFavourite is None or Solution.isFavourite == isFavourite),
            (isSandbox is None or Solution.isSandbox == isSandbox),
        )

        # do limit and offset last
        if page is None or page_size is None:
            solutions = solution_query.all()
        else:
            solutions = solution_query.limit(page_size).offset(page * page_size)

        if namesonly is True:
            # Serialize the data for the response
            schema = SolutionNamesOnlySchema(many=True)
            data = schema.dump(solutions)
        else:
            for sol in solutions:
                sol = solution_extension.expand_solution(sol, dbsession=dbs)
            schema = ExtendedSolutionSchema(many=True)
            data = schema.dump(solutions)

        logger.debug("read_all: %s", data)
        return data, 200
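
The sort handling above splices user-supplied "column:direction" strings into the query through literal_column. A stricter variant would resolve each name against the mapped model first; a sketch, assuming the same Solution model (build_order_by is an illustrative helper, not part of the API above):

from sqlalchemy import asc, desc

def build_order_by(model, sort_inst):
    # Resolve each "column:direction" pair against the model's attributes;
    # unknown column names raise AttributeError instead of reaching the SQL.
    clauses = []
    for si in sort_inst:
        col = getattr(model, si[0])
        direction = desc if len(si) > 1 and si[1].lower() == "desc" else asc
        clauses.append(direction(col))
    return clauses

# usage: dbs.query(Solution).order_by(*build_order_by(Solution, sort_inst))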