from pandas import DataFrame
from sqlalchemy import column, select


# conn_popler_2, Maintable, and Rawtable are defined elsewhere in the project.
def datatype_records_to_subset_and_migrate(likechars):
    stmt_for_pkeys = conn_popler_2.execute(
        select(
            from_obj=Maintable,
            columns=[
                column('lter_proj_site'),
                column('samplingprotocol')
            ]).
        where(
            column('samplingprotocol').like(
                '%{}%'.format(likechars))
        )
    )
    data = DataFrame(stmt_for_pkeys.fetchall())
    data.columns = stmt_for_pkeys.keys()

    records_to_get = data['lter_proj_site'].values.tolist()

    stmt_for_records = conn_popler_2.execute(
        select([Rawtable]).
        where(column('lter_proj_site').in_(records_to_get)).
        order_by('sampleid')
    )
    data2 = DataFrame(stmt_for_records.fetchall())
    data2.columns = stmt_for_records.keys()
    data2.drop('individ', axis=1, inplace=True)
Example #2
    def test_legacy_typemap(self):
        table1 = table(
            "mytable",
            column("myid", Integer),
            column("name", String),
            column("description", String),
        )
        with testing.expect_deprecated(
            "The text.typemap parameter is deprecated"
        ):
            t = text(
                "select id, name from user",
                typemap=dict(id=Integer, name=String),
            )

        stmt = select([table1.c.myid]).select_from(
            table1.join(t, table1.c.myid == t.c.id)
        )
        compiled = stmt.compile()
        eq_(
            compiled._create_result_map(),
            {
                "myid": (
                    "myid",
                    (table1.c.myid, "myid", "myid"),
                    table1.c.myid.type,
                )
            },
        )
Example #3
 def test_unconsumed_names_kwargs(self):
     t = table("t", column("x"), column("y"))
     assert_raises_message(
         exc.CompileError,
         "Unconsumed column names: z",
         t.insert().values(x=5, z=5).compile,
     )
Example #4
    def test_no_table_needs_pl(self):
        Subset = self.classes.Subset

        selectable = select([column("x"), column("y"), column("z")]).alias()
        assert_raises_message(
            sa.exc.ArgumentError,
            "could not assemble any primary key columns",
            mapper,
            Subset,
            selectable,
        )
def downgrade():
    # FIXME: this adds extraneous commas
    return
    log = sa.table('log', sa.column('type', sa.String), sa.column('msg', sa.String))
    rows = op.get_bind().execute(log.select().where(log.c.type == 'kick')).fetchall()
    values = [{'old_msg': x.msg, 'msg': x.msg.replace(' ', ',', 1)} for x in rows]
    op.get_bind().execute(log.update().where(log.c.msg == sa.bindparam('old_msg')).values(msg=sa.bindparam('msg')), values)
Example #6
    def execute(self, connection, filter_values):
        max_date_query = sqlalchemy.select([
            sqlalchemy.func.max(sqlalchemy.column('completed_on')).label('completed_on'),
            sqlalchemy.column('case_id').label('case_id')
        ]).select_from(sqlalchemy.table(self.table_name))

        if self.filters:
            for filter in self.filters:
                max_date_query.append_whereclause(filter.build_expression())

        max_date_query.append_group_by(
            sqlalchemy.column('case_id')
        )

        max_date_subquery = sqlalchemy.alias(max_date_query, 'max_date')

        asha_table = self.get_asha_table_name()
        checklist_query = sqlalchemy.select()
        for column in self.columns:
            checklist_query.append_column(column.build_column())

        checklist_query = checklist_query.where(
            sqlalchemy.literal_column('"{}".case_id'.format(asha_table)) == max_date_subquery.c.case_id
        ).where(
            sqlalchemy.literal_column('"{}".completed_on'.format(asha_table)) == max_date_subquery.c.completed_on
        ).select_from(sqlalchemy.table(asha_table))

        return connection.execute(checklist_query, **filter_values).fetchall()
Example #7
def upgrade():
    ### commands auto generated by Alembic - please adjust! ###
    subsc = op.create_table(
        'subscription',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Integer(), sa.ForeignKey('user.id'), nullable=False),
        sa.Column('feed_id', sa.Integer(), sa.ForeignKey('feed.id'), nullable=False),
        sa.Column('name', sa.String(length=256), nullable=True),
        sa.Column('tags', sa.String(length=256), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )

    feed = sa.table(
        'feed',
        sa.column('id', sa.Integer()),
        sa.column('name', sa.String()))

    u2f = sa.table(
        'users_to_feeds',
        sa.column('user_id', sa.Integer()),
        sa.column('feed_id', sa.Integer()))

    values = sa.select(
        [u2f.c.user_id, u2f.c.feed_id, feed.c.name]
    ).select_from(
        u2f.join(feed, feed.c.id == u2f.c.feed_id)
    )

    op.execute(subsc.insert().from_select(
        ['user_id', 'feed_id', 'name'], values))

    op.drop_table('users_to_feeds')
Example #8
def filtra_por_partido(_partido):
    """retorna dados do partido passado por parâmetro"""
    _select = select([column('nome_doador'), column('nome_candidato'), column('partido'), column('valor')]).\
        select_from(dados_agrup_doadores).where(column('partido') == _partido)
    # return the filtered result and the count of that result

    return [dict(doador=x[0], candidato=x[1], partido=x[2], valor=x[3]) for x in _select.execute()],\
           list(_select.count().execute())[0][0]
def upgrade():
    log = sa.table('log', sa.column('type', sa.String), sa.column('msg', sa.String))
    rows = op.get_bind().execute(log.select().where(log.c.type == 'kick').where(log.c.msg.like('%,%'))).fetchall()
    rows = [x for x in rows if ',' in x.msg and x.msg.find(',') < x.msg.find(' ')]
    if not rows:
        return
    values = [{'old_msg': x.msg, 'msg': x.msg.replace(',', ' ', 1)} for x in rows]
    op.get_bind().execute(log.update().where(log.c.msg == sa.bindparam('old_msg')).values(msg=sa.bindparam('msg')), values)
Example #10
    def test_unconsumed_names_kwargs_w_keys(self):
        t = table("t", column("x"), column("y"))

        assert_raises_message(
            exc.CompileError,
            "Unconsumed column names: j",
            t.update().values(x=5, j=7).compile,
            column_keys=['j']
        )
def upgrade():
    source = sa.table('source', sa.column('id'), sa.column('updated', sa.DateTime))
    dt = sa.bindparam('dt', UPDATED)
    touch = sa.update(source, bind=op.get_bind())\
        .where(source.c.id == sa.bindparam('id_'))\
        .where(source.c.updated < dt)\
        .values(updated=dt)

    for id_ in IDS:
        touch.execute(id_=id_)
Example #12
def test_compile_with_one_unnamed_table():
    t = ibis.table([('a', 'string')])
    s = ibis.table([('b', 'string')], name='s')
    join = t.join(s, t.a == s.b)
    result = ibis.sqlite.compile(join)
    sqla_t = sa.table('t0', sa.column('a', sa.String)).alias('t0')
    sqla_s = sa.table('s', sa.column('b', sa.String)).alias('t1')
    sqla_join = sqla_t.join(sqla_s, sqla_t.c.a == sqla_s.c.b)
    expected = sa.select([sqla_t.c.a, sqla_s.c.b]).select_from(sqla_join)
    assert str(result) == str(expected)
Example #13
    def test_fewer_cols_than_sql_positional(self):
        c1, c2 = column('q'), column('p')
        stmt = text("select a, b, c, d from text1").columns(c1, c2)

        # no warning as this can be similar for non-positional
        result = testing.db.execute(stmt)
        row = result.first()

        eq_(row[c1], "a1")
        eq_(row["c"], "c1")
Example #14
    def test_dupe_col_obj(self):
        c1, c2, c3 = column('q'), column('p'), column('r')
        stmt = text("select a, b, c, d from text1").columns(c1, c2, c3, c2)

        assert_raises_message(
            exc.InvalidRequestError,
            "Duplicate column expression requested in "
            "textual SQL: <.*.ColumnClause.*; p>",
            testing.db.execute, stmt
        )
Example #15
    def test_unconsumed_names_values_dict(self):
        t = table("t", column("x"), column("y"))
        t2 = table("t2", column("q"), column("z"))

        assert_raises_message(
            exc.CompileError,
            "Unconsumed column names: j",
            t.update().values(x=5, j=7).values({t2.c.z: 5}).
            where(t.c.x == t2.c.q).compile,
        )
Example #16
    def test_compare_labels(self):
        is_true(column("q").label(None).compare(column("q").label(None)))

        is_false(column("q").label("foo").compare(column("q").label(None)))

        is_false(column("q").label(None).compare(column("q").label("foo")))

        is_false(column("q").label("foo").compare(column("q").label("bar")))

        is_true(column("q").label("foo").compare(column("q").label("foo")))
Example #17
    def test_functions_args_noname(self):
        class myfunc(FunctionElement):
            pass

        @compiles(myfunc)
        def visit_myfunc(element, compiler, **kw):
            return "myfunc%s" % (compiler.process(element.clause_expr, **kw),)

        self.assert_compile(myfunc(), "myfunc()")

        self.assert_compile(myfunc(column("x"), column("y")), "myfunc(x, y)")
Example #18
    def test_bindparam_name_no_consume_error(self):
        t = table("t", column("x"), column("y"))
        # bindparam names don't get counted
        i = t.insert().values(x=3 + bindparam("x2"))
        self.assert_compile(i, "INSERT INTO t (x) VALUES ((:param_1 + :x2))")

        # even if in the params list
        i = t.insert().values(x=3 + bindparam("x2"))
        self.assert_compile(
            i, "INSERT INTO t (x) VALUES ((:param_1 + :x2))", params={"x2": 1}
        )
Example #19
 def __table_args__(cls):
     return (
         Index(
             'recipient_address_in_group', 'address', 'group',
             unique=True, postgresql_where=column('group') != None
         ),
         Index(
             'recipient_address_without_group', 'address',
             unique=True, postgresql_where=column('group') == None
         ),
     )
Example #20
    def test_no_tables(self):
        Subset = self.classes.Subset

        selectable = select([column("x"), column("y"), column("z")]).alias()
        mapper(Subset, selectable, primary_key=[selectable.c.x])

        self.assert_compile(
            Session().query(Subset),
            "SELECT anon_1.x AS anon_1_x, anon_1.y AS anon_1_y, "
            "anon_1.z AS anon_1_z FROM (SELECT x, y, z) AS anon_1",
            use_default_dialect=True
        )
Example #21
    def test_binds_in_select(self):
        t = table("t", column("a"), column("b"), column("c"))

        @compiles(BindParameter)
        def gen_bind(element, compiler, **kw):
            return "BIND(%s)" % compiler.visit_bindparam(element, **kw)

        self.assert_compile(
            t.select().where(t.c.c == 5),
            "SELECT t.a, t.b, t.c FROM t WHERE t.c = BIND(:c_1)",
            use_default_dialect=True,
        )
Example #22
    def test_binds_in_dml(self):
        t = table("t", column("a"), column("b"), column("c"))

        @compiles(BindParameter)
        def gen_bind(element, compiler, **kw):
            return "BIND(%s)" % compiler.visit_bindparam(element, **kw)

        self.assert_compile(
            t.insert(),
            "INSERT INTO t (a, b) VALUES (BIND(:a), BIND(:b))",
            {"a": 1, "b": 2},
            use_default_dialect=True,
        )
Example #23
    def test_labels_no_collision(self):

        t = table("foo", column("id"), column("foo_id"))

        self.assert_compile(
            t.update().where(t.c.id == 5),
            "UPDATE foo SET id=:id, foo_id=:foo_id WHERE foo.id = :id_1",
        )

        self.assert_compile(
            t.update().where(t.c.id == bindparam(key=t.c.id._label)),
            "UPDATE foo SET id=:id, foo_id=:foo_id WHERE foo.id = :foo_id_1",
        )
Example #24
def upgrade():
    op.alter_column(
        'writeup_post_versions', 'threadpost_id',
        nullable=False)
    op.alter_column(
        'writeup_post_versions', 'writeuppost_id',
        nullable=True)
    op.alter_column(
        'writeup_post_versions', 'html',
        nullable=True)
    op.create_check_constraint(
        'writeup_post_versions_check_html', 'writeup_post_versions',
        sa.and_(sa.column('writeuppost_id') != sa.null(), sa.column('html') != sa.null()))
Example #25
    def test_select(self):
        t1 = table("t1", column("c1"), column("c2"))

        @compiles(Select, "sqlite")
        def compile_(element, compiler, **kw):
            return "OVERRIDE"

        s1 = select([t1])
        self.assert_compile(s1, "SELECT t1.c1, t1.c2 FROM t1")

        from sqlalchemy.dialects.sqlite import base as sqlite

        self.assert_compile(s1, "OVERRIDE", dialect=sqlite.dialect())
Example #26
File: db.py Project: agdsn/hades
def lock_table(connection: Connection, target_table: Table):
    """
    Lock a table using a PostgreSQL advisory lock

    The OID of the table in the pg_class relation is used as lock id.
    :param connection: DB connection
    :param target_table: Table object
    """
    logger.debug('Locking table "%s"', target_table.name)
    oid = connection.execute(select([column("oid")])
                             .select_from(table("pg_class"))
                             .where((column("relname") == target_table.name))
                             ).scalar()
    connection.execute(select([func.pg_advisory_xact_lock(oid)])).scalar()
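
A usage sketch for lock_table, assuming an application-defined Table object and a hypothetical DSN; the advisory lock is released automatically when the transaction opened by engine.begin() ends:

from sqlalchemy import Table, create_engine

def _lock_table_example(target_table: Table):
    engine = create_engine('postgresql:///hades')  # hypothetical DSN
    with engine.begin() as connection:
        lock_table(connection, target_table)
        # statements here run while the advisory lock on target_table is held
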
def get_data_with_metakey(id_number):
    stmt = conn.execute(
        select(column_objs).
        select_from(
            Rawtable.__table__.
            join(Taxatable.__table__).
            join(Maintable.__table__).
            join(Sitetable)).
        where(column('metarecordid') == id_number).
        order_by(column('sampleid'))
    )
    data = DataFrame(stmt.fetchall())
    data.columns = stmt.keys()
    return data
def upgrade():
    op.add_column('writeup_post_versions', sa.Column('edit_summary', sa.Unicode(length=200), nullable=True))
    from mimir.models import AwareDateTime
    t_wpv = sa.table(
        'writeup_post_versions',
        sa.column('edit_summary', sa.Unicode(200)),
        sa.column('extracted_at', AwareDateTime),
    )
    stmt = t_wpv.update().values(
        edit_summary="Extracted at " + sa.cast(t_wpv.c.extracted_at, sa.Unicode)
    )
    op.get_bind().execute(stmt)
    op.alter_column('writeup_post_versions', 'edit_summary', nullable=False)
    op.alter_column('writeup_post_versions', 'extracted_at', new_column_name='created_at')
Example #29
    def test_via_column(self):
        c1, c2, c3, c4 = column('q'), column('p'), column('r'), column('d')
        stmt = text("select a, b, c, d from text1").columns(c1, c2, c3, c4)

        result = testing.db.execute(stmt)
        row = result.first()

        eq_(row[c2], "b1")
        eq_(row[c4], "d1")
        eq_(row[1], "b1")
        eq_(row["b"], "b1")
        eq_(row.keys(), ["a", "b", "c", "d"])
        eq_(row["r"], "c1")
        eq_(row["d"], "d1")
Example #30
    def test_mssql_where_clause_n_prefix(self):
        dialect = pymssql.dialect()
        spec = MssqlEngineSpec
        str_col = column('col', type_=spec.get_sqla_column_type('VARCHAR(10)'))
        unicode_col = column('unicode_col', type_=spec.get_sqla_column_type('NTEXT'))
        tbl = table('tbl')
        sel = select([str_col, unicode_col]).\
            select_from(tbl).\
            where(str_col == 'abc').\
            where(unicode_col == 'abc')

        query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))
        query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'"  # noqa
        self.assertEqual(query, query_expected)
Example #31
 def test_mssql_time_expression_mixed_case_column_1y_grain(self):
     col = column('MixedCase')
     expr = MssqlEngineSpec.get_timestamp_expr(col, None, 'P1Y')
     result = str(expr.compile(dialect=mssql.dialect()))
     self.assertEqual(result,
                      'DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)')
Example #32
    def get_export_data(cls):
        if cls.__name__ == "Payment":
            # Export stats for each payment type separately
            return {}

        purchase_counts = (
            cls.query.outerjoin(cls.purchases)
            .group_by(cls.id)
            .with_entities(func.count(Ticket.id))
        )
        refund_counts = (
            cls.query.outerjoin(cls.refunds)
            .group_by(cls.id)
            .with_entities(func.count(Refund.id))
        )

        cls_version = version_class(cls)
        cls_transaction = transaction_class(cls)
        changes = cls.query.join(cls.versions).group_by(cls.id)
        change_counts = changes.with_entities(func.count(cls_version.id))
        first_changes = (
            changes.join(cls_version.transaction)
            .with_entities(func.min(cls_transaction.issued_at).label("created"))
            .from_self()
        )

        cls_ver_new = aliased(cls.versions)
        cls_ver_paid = aliased(cls.versions)
        cls_txn_new = aliased(cls_version.transaction)
        cls_txn_paid = aliased(cls_version.transaction)
        active_time = func.max(cls_txn_paid.issued_at) - func.max(cls_txn_new.issued_at)
        active_times = (
            cls.query.join(cls_ver_new, cls_ver_new.id == cls.id)
            .join(cls_ver_paid, cls_ver_paid.id == cls.id)
            .join(cls_txn_new, cls_txn_new.id == cls_ver_new.transaction_id)
            .join(cls_txn_paid, cls_txn_paid.id == cls_ver_paid.transaction_id)
            .filter(cls_ver_new.state == "new")
            .filter(cls_ver_paid.state == "paid")
            .with_entities(active_time.label("active_time"))
            .group_by(cls.id)
        )

        time_buckets = [timedelta(0), timedelta(minutes=1), timedelta(hours=1)] + [
            timedelta(d)
            for d in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 28, 60]
        ]

        data = {
            "public": {
                "payments": {
                    "counts": {
                        "purchases": bucketise(
                            purchase_counts, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20]
                        ),
                        "refunds": bucketise(refund_counts, [0, 1, 2, 3, 4]),
                        "changes": bucketise(change_counts, range(10)),
                        "created_week": export_intervals(
                            first_changes, column("created"), "week", "YYYY-MM-DD"
                        ),
                        "active_time": bucketise(
                            [r.active_time for r in active_times], time_buckets
                        ),
                        "amounts": bucketise(
                            cls.query.with_entities(cls.amount_int / 100),
                            [0, 10, 20, 30, 40, 50, 100, 150, 200],
                        ),
                    }
                }
            },
            "tables": ["payment", "payment_version"],
        }

        count_attrs = ["state", "reminder_sent", "currency"]
        data["public"]["payments"]["counts"].update(
            export_attr_counts(cls, count_attrs)
        )

        return data
Example #33
def downgrade():
    op.drop_index(op.f('ix_service_pair_service1'), table_name='service_pair')
    op.drop_index(op.f('ix_service_pair_service0'), table_name='service_pair')
    op.drop_index(op.f('ix_service_pair_direction1'), table_name='service_pair')
    op.drop_index(op.f('ix_service_pair_direction0'), table_name='service_pair')
    op.drop_table('service_pair')

    op.alter_column('special_period', 'date_start', existing_type=sa.DATE(), nullable=True)
    op.alter_column('special_period', 'date_end', existing_type=sa.DATE(), nullable=True)
    op.create_check_constraint('operating_period_check', 'operating_period', sa.column('date_start') <= sa.column('date_end'))
    op.create_check_constraint('special_period_check', 'special_period', sa.column('date_start') <= sa.column('date_end'))

    op.execute("DROP MATERIALIZED VIEW fts;")
    op.execute("""
        CREATE MATERIALIZED VIEW fts AS
        SELECT 'region' AS table_name,
                region.code AS code,
                region.name AS name,
                NULL AS short_ind,
                NULL AS street,
                NULL AS stop_type,
                NULL AS stop_area_ref,
                NULL AS locality_name,
                NULL AS district_name,
                NULL AS admin_area_ref,
                NULL AS admin_area_name,
                CAST(ARRAY[] AS TEXT[]) AS admin_areas,
                setweight(to_tsvector('english', region.name), 'A') AS vector
        FROM region
        WHERE region.code != 'GB'
        UNION ALL
        SELECT 'admin_area' AS table_name,
               admin_area.code AS code,
               admin_area.name AS name,
               NULL AS short_ind,
               NULL AS street,
               NULL AS stop_type,
               NULL AS stop_area_ref,
               NULL AS locality_name,
               NULL AS district_name,
               admin_area.code AS admin_area_ref,
               admin_area.name AS admin_area_name,
               ARRAY[admin_area.code] AS admin_areas,
               setweight(to_tsvector('english', admin_area.name), 'A') AS vector
        FROM admin_area
        WHERE admin_area.region_ref != 'GB'
        UNION ALL
        SELECT 'district' AS table_name,
               district.code AS code,
               district.name AS name,
               NULL AS short_ind,
               NULL AS street,
               NULL AS stop_type,
               NULL AS stop_area_ref,
               NULL AS locality_name,
               NULL AS district_name,
               admin_area.code AS admin_area_ref,
               admin_area.name AS admin_area_name,
               ARRAY[admin_area.code] AS admin_areas,
               setweight(to_tsvector('english', district.name), 'A') ||
               setweight(to_tsvector('english', admin_area.name), 'B') AS vector
        FROM district
             JOIN admin_area ON admin_area.code = district.admin_area_ref
        UNION ALL
        SELECT 'locality' AS table_name,
               locality.code AS code,
               locality.name AS name,
               NULL AS short_ind,
               NULL AS street,
               NULL AS stop_type,
               NULL AS stop_area_ref,
               NULL AS locality_name,
               district.name AS district_name,
               admin_area.code AS admin_area_ref,
               admin_area.name AS admin_area_name,
               ARRAY[admin_area.code] AS admin_areas,
               setweight(to_tsvector('english', locality.name), 'A') ||
               setweight(to_tsvector('english', coalesce(district.name, '')), 'B') ||
               setweight(to_tsvector('english', admin_area.name), 'B') AS vector
        FROM locality
             LEFT OUTER JOIN district ON district.code = locality.district_ref
             JOIN admin_area ON admin_area.code = locality.admin_area_ref
        WHERE EXISTS (
                  SELECT stop_point.atco_code
                  FROM stop_point
                  WHERE stop_point.locality_ref = locality.code
              )
        UNION ALL
        SELECT 'stop_area' AS table_name,
               stop_area.code AS code,
               stop_area.name AS name,
               CAST(count(stop_point.atco_code) AS TEXT) AS short_ind,
               NULL AS street,
               stop_area.stop_area_type AS stop_type,
               NULL AS stop_area_ref,
               locality.name AS locality_name,
               district.name AS district_name,
               admin_area.code AS admin_area_ref,
               admin_area.name AS admin_area_name,
               ARRAY[admin_area.code] AS admin_areas,
               setweight(to_tsvector('english', stop_area.name), 'A') ||
               setweight(to_tsvector('english', coalesce(locality.name, '')), 'C') ||
               setweight(to_tsvector('english', coalesce(district.name, '')), 'D') ||
               setweight(to_tsvector('english', admin_area.name), 'D') AS vector
        FROM stop_area
             LEFT OUTER JOIN stop_point ON stop_area.code = stop_point.stop_area_ref
             LEFT OUTER JOIN locality ON locality.code = stop_area.locality_ref
             LEFT OUTER JOIN district ON district.code = locality.district_ref
             JOIN admin_area ON admin_area.code = stop_area.admin_area_ref
        GROUP BY stop_area.code, locality.name, district.name, admin_area.code
        UNION ALL
        SELECT 'stop_point' AS table_name,
               stop_point.atco_code AS code,
               stop_point.name AS name,
               stop_point.short_ind AS short_ind,
               stop_point.street AS street,
               stop_point.stop_type AS stop_type,
               stop_point.stop_area_ref AS stop_area_ref,
               locality.name AS locality_name,
               district.name AS district_name,
               admin_area.code AS admin_area_ref,
               admin_area.name AS admin_area_name,
               ARRAY[admin_area.code] AS admin_areas,
               setweight(to_tsvector('english', stop_point.name), 'A') ||
               setweight(to_tsvector('english', stop_point.street), 'B') ||
               setweight(to_tsvector('english', locality.name), 'C') ||
               setweight(to_tsvector('english', coalesce(district.name, '')), 'D') ||
               setweight(to_tsvector('english', admin_area.name), 'D') AS vector
        FROM stop_point
             JOIN locality ON locality.code = stop_point.locality_ref
             LEFT OUTER JOIN district ON district.code = locality.district_ref
             JOIN admin_area ON admin_area.code = stop_point.admin_area_ref
        UNION ALL
        SELECT 'service' AS table_name,
               CAST(service.id AS TEXT) AS code,
               service.description AS name,
               service.line AS short_ind,
               NULL AS street,
               NULL AS stop_type,
               NULL AS stop_area_ref,
               NULL AS locality_name,
               NULL AS district_name,
               NULL AS admin_area_ref,
               NULL AS admin_area_name,
               array_agg(DISTINCT admin_area.code) AS admin_areas,
               setweight(to_tsvector('english', service.line), 'A') ||
               setweight(to_tsvector('english', service.description), 'A') ||
               setweight(to_tsvector('english', string_agg(DISTINCT operator.name, ' ')), 'B') ||
               setweight(to_tsvector('english', string_agg(DISTINCT locality.name, ' ')), 'C') ||
               setweight(to_tsvector('english', coalesce(string_agg(DISTINCT district.name, ' '), '')), 'D') ||
               setweight(to_tsvector('english', string_agg(DISTINCT admin_area.name, ' ')), 'D') AS vector
        FROM service
             JOIN journey_pattern ON journey_pattern.service_ref = service.id
             JOIN local_operator ON local_operator.code = journey_pattern.local_operator_ref AND
                                    local_operator.region_ref = journey_pattern.region_ref
             JOIN operator ON local_operator.operator_ref = operator.code
             JOIN journey_link ON journey_pattern.id = journey_link.pattern_ref
             JOIN stop_point ON journey_link.stop_point_ref = stop_point.atco_code
             JOIN locality ON stop_point.locality_ref = locality.code
             LEFT OUTER JOIN district ON locality.district_ref = district.code
             JOIN admin_area ON locality.admin_area_ref = admin_area.code
        GROUP BY service.id
        WITH NO DATA;
    """)
    op.create_index(op.f("ix_fts_table"), "fts", ["table_name"], unique=False)
    op.create_index(op.f("ix_fts_code"), "fts", ["code"], unique=False)
    op.create_index(op.f("ix_fts_unique"), "fts", ["table_name", "code"], unique=True)
    op.create_index(op.f("ix_fts_vector_gin"), "fts", ["vector"], unique=False, postgresql_using="gin")
    op.create_index(op.f("ix_fts_areas_gin"), "fts", ["admin_areas"], unique=False, postgresql_using="gin")

    op.drop_index(op.f('ix_stop_point_active'), table_name='stop_point')
    op.drop_index(op.f('ix_stop_area_active'), table_name='stop_point')
    op.drop_column('stop_point', 'active')
    op.drop_column('stop_area', 'active')

    op.drop_constraint(op.f('journey_pattern_region_ref_fkey'), 'journey_pattern', type_='foreignkey')
def upgrade():
    survey_group_table = op.create_table(
        'surveygroup', sa.Column('id', GUID(), nullable=False),
        sa.Column('title', sa.Text(), nullable=False),
        sa.Column('description', sa.Text()),
        sa.Column('created', sa.DateTime(), nullable=False),
        sa.Column('deleted', sa.Boolean(), nullable=False),
        sa.PrimaryKeyConstraint('id'))
    op.create_index('surveygroup_title_key',
                    'surveygroup', ['title'],
                    unique=True)

    op.create_table(
        'organisation_surveygroup',
        sa.Column('organisation_id', GUID, ForeignKey('organisation.id')),
        sa.Column('surveygroup_id', GUID, ForeignKey('surveygroup.id')),
        Index('organisation_surveygroup_organisation_id_index',
              'organisation_id'),
        Index('organisation_surveygroup_surveygroup_id_index',
              'surveygroup_id'),
    )

    op.create_table(
        'user_surveygroup',
        sa.Column('user_id', GUID, ForeignKey('appuser.id')),
        sa.Column('surveygroup_id', GUID, ForeignKey('surveygroup.id')),
        Index('user_surveygroup_organisation_id_index', 'user_id'),
        Index('user_surveygroup_surveygroup_id_index', 'surveygroup_id'),
    )

    op.create_table(
        'program_surveygroup',
        sa.Column('program_id', GUID, ForeignKey('program.id')),
        sa.Column('surveygroup_id', GUID, ForeignKey('surveygroup.id')),
        Index('program_surveygroup_program_id_index', 'program_id'),
        Index('program_surveygroup_surveygroup_id_index', 'surveygroup_id'),
    )

    op.create_table(
        'activity_surveygroup',
        sa.Column('activity_id', GUID, ForeignKey('activity.id')),
        sa.Column('surveygroup_id', GUID, ForeignKey('surveygroup.id')),
        Index('activity_surveygroup_activity_id_index', 'activity_id'),
        Index('activity_surveygroup_surveygroup_id_index', 'surveygroup_id'),
    )

    op.create_table(
        'id_map',
        sa.Column('old_id', GUID, nullable=False, primary_key=True),
        sa.Column('new_id', GUID, nullable=False),
    )

    ob_types = array([
        'surveygroup', 'organisation', 'user', 'program', 'survey', 'qnode',
        'measure', 'response_type', 'submission', 'rnode', 'response',
        'custom_query'
    ],
                     type_=TEXT)
    op.drop_constraint('activity_ob_type_check', 'activity', type_='check')
    op.create_check_constraint(
        'activity_ob_type_check', 'activity',
        cast(column('ob_type'), TEXT) == func.any(ob_types))
    op.drop_constraint('subscription_ob_type_check',
                       'subscription',
                       type_='check')
    op.create_check_constraint(
        'subscription_ob_type_check', 'subscription',
        cast(column('ob_type'), TEXT) == func.any(ob_types))

    roles = array([
        'super_admin', 'admin', 'author', 'authority', 'consultant',
        'org_admin', 'clerk'
    ])
    op.drop_constraint('appuser_role_check', 'appuser', type_='check')
    op.create_check_constraint('appuser_role_check', 'appuser',
                               cast(column('role'), TEXT) == func.any(roles))

    group_id = GUID.gen()
    op.bulk_insert(survey_group_table, [
        {
            'id': group_id,
            'title': "DEFAULT SURVEY GROUP",
            'created': datetime.datetime.now(),
            'deleted': False,
        },
    ])
    op.execute("""
        INSERT INTO organisation_surveygroup
        (organisation_id, surveygroup_id)
        SELECT organisation.id, '%s'
        FROM organisation
    """ % group_id)
    op.execute("""
        INSERT INTO user_surveygroup
        (user_id, surveygroup_id)
        SELECT appuser.id, '%s'
        FROM appuser
    """ % group_id)
    op.execute("""
        INSERT INTO program_surveygroup
        (program_id, surveygroup_id)
        SELECT program.id, '%s'
        FROM program
    """ % group_id)
    op.execute("""
        INSERT INTO activity_surveygroup
        (activity_id, surveygroup_id)
        SELECT activity.id, '%s'
        FROM activity
    """ % group_id)

    op.execute("GRANT SELECT ON organisation_surveygroup TO analyst")
    op.execute("GRANT SELECT ON activity_surveygroup TO analyst")
    op.execute("GRANT SELECT ON user_surveygroup TO analyst")
    op.execute("GRANT SELECT ON program_surveygroup TO analyst")
    op.execute("GRANT SELECT ON id_map TO analyst")
Example #35
    def get_run_groups(self, filters=None, cursor=None, limit=None):
        # The runs that would be returned by calling RunStorage.get_runs with the same arguments
        runs = self._runs_query(
            filters=filters, cursor=cursor, limit=limit, columns=['run_body', 'run_id']
        ).alias('runs')

        # Gets us the run_id and associated root_run_id for every run in storage that is a
        # descendant run of some root
        #
        # pseudosql:
        #   with all_descendant_runs as (
        #     select *
        #     from run_tags
        #     where key = @ROOT_RUN_ID_TAG
        #   )

        all_descendant_runs = (
            db.select([RunTagsTable])
            .where(RunTagsTable.c.key == ROOT_RUN_ID_TAG)
            .alias('all_descendant_runs')
        )

        # Augment the runs in our query, for those runs that are the descendant of some root run,
        # with the root_run_id
        #
        # pseudosql:
        #
        #   with runs_augmented as (
        #     select
        #       runs.run_id as run_id,
        #       all_descendant_runs.value as root_run_id
        #     from runs
        #     left outer join all_descendant_runs
        #       on all_descendant_runs.run_id = runs.run_id
        #   )

        runs_augmented = (
            db.select(
                [runs.c.run_id.label('run_id'), all_descendant_runs.c.value.label('root_run_id'),]
            )
            .select_from(
                runs.join(
                    all_descendant_runs,
                    all_descendant_runs.c.run_id == RunsTable.c.run_id,
                    isouter=True,
                )
            )
            .alias('runs_augmented')
        )

        # Get all the runs our query will return. This includes runs as well as their root runs.
        #
        # pseudosql:
        #
        #    with runs_and_root_runs as (
        #      select runs.run_id as run_id
        #      from runs, runs_augmented
        #      where
        #        runs.run_id = runs_augmented.run_id or
        #        runs.run_id = runs_augmented.root_run_id
        #    )

        runs_and_root_runs = (
            db.select([RunsTable.c.run_id.label('run_id')])
            .select_from(runs_augmented)
            .where(
                db.or_(
                    RunsTable.c.run_id == runs_augmented.c.run_id,
                    RunsTable.c.run_id == runs_augmented.c.root_run_id,
                )
            )
            .distinct(RunsTable.c.run_id)
        ).alias('runs_and_root_runs')

        # We count the descendants of all of the runs in our query that are roots so that
        # we can accurately display when a root run has more descendants than are returned by this
        # query and afford a drill-down. This might be an unnecessary complication, but the
        # alternative isn't obvious -- we could go and fetch *all* the runs in any group that we're
        # going to return in this query, and then append those.
        #
        # pseudosql:
        #
        #    select runs.run_body, count(all_descendant_runs.id) as child_counts
        #    from runs
        #    join runs_and_root_runs on runs.run_id = runs_and_root_runs.run_id
        #    left outer join all_descendant_runs
        #      on all_descendant_runs.value = runs_and_root_runs.run_id
        #    group by runs.run_body
        #    order by child_counts desc

        runs_and_root_runs_with_descendant_counts = (
            db.select(
                [
                    RunsTable.c.run_body,
                    db.func.count(all_descendant_runs.c.id).label('child_counts'),
                ]
            )
            .select_from(
                RunsTable.join(
                    runs_and_root_runs, RunsTable.c.run_id == runs_and_root_runs.c.run_id
                ).join(
                    all_descendant_runs,
                    all_descendant_runs.c.value == runs_and_root_runs.c.run_id,
                    isouter=True,
                )
            )
            .group_by(RunsTable.c.run_body)
            .order_by(db.desc(db.column('child_counts')))
        )

        with self.connect() as conn:
            res = conn.execute(runs_and_root_runs_with_descendant_counts).fetchall()

        # Postprocess: descendant runs get aggregated with their roots
        run_groups = defaultdict(lambda: {'runs': [], 'count': 0})
        for (run_body, count) in res:
            row = (run_body,)
            pipeline_run = self._row_to_run(row)
            root_run_id = pipeline_run.get_root_run_id()
            if root_run_id is not None:
                run_groups[root_run_id]['runs'].append(pipeline_run)
            else:
                run_groups[pipeline_run.run_id]['runs'].append(pipeline_run)
                run_groups[pipeline_run.run_id]['count'] = count + 1

        return run_groups
Example #36
 def _get_fields(cls, cols):
     return [column(c.get("name")) for c in cols]
Example #37
 def test_time_exp_mixd_case_col_1y(self):
     col = column("MixedCase")
     expr = MssqlEngineSpec.get_timestamp_expr(col, None, "P1Y")
     result = str(expr.compile(None, dialect=mssql.dialect()))
     self.assertEqual(result,
                      "DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)")
Example #38
File: postgres.py Project: olgermolla/ivre
    def insert_or_update_bulk(self,
                              specs,
                              getinfos=None,
                              separated_timestamps=True):
        """Like `.insert_or_update()`, but `specs` parameter has to be an
        iterable of `(timestamp, spec)` (if `separated_timestamps` is
        True) or `spec` (if it is False) values. This will perform
        PostgreSQL COPY FROM inserts with the major drawback that the
        `getinfos` parameter will be called (if it is not `None`) for
        each spec, even when the spec already exists in the database
        and the call was hence unnecessary.

        It's up to you to decide whether having bulk insert is worth
        it or if you want to go with the regular `.insert_or_update()`
        method.

        """
        more_to_read = True
        tmp = self.create_tmp_table(self.tables.passive)
        if config.DEBUG_DB:
            total_upserted = 0
            total_start_time = time.time()
        while more_to_read:
            if config.DEBUG_DB:
                start_time = time.time()
            with PassiveCSVFile(specs,
                                self.convert_ip,
                                tmp,
                                getinfos=getinfos,
                                separated_timestamps=separated_timestamps,
                                limit=config.POSTGRES_BATCH_SIZE) as fdesc:
                self.copy_from(fdesc, tmp.name)
                more_to_read = fdesc.more_to_read
                if config.DEBUG_DB:
                    count_upserted = fdesc.count
            insrt = postgresql.insert(self.tables.passive)
            self.db.execute(
                insrt.from_select(
                    [column(col) for col in [
                        'addr',
                        # sum / min / max
                        'count', 'firstseen', 'lastseen',
                        # grouped
                        'sensor', 'port', 'recontype', 'source', 'targetval',
                        'value', 'fullvalue', 'info', 'moreinfo'
                    ]],
                    select([tmp.columns['addr'],
                            func.sum_(tmp.columns['count']),
                            func.min_(tmp.columns['firstseen']),
                            func.max_(tmp.columns['lastseen'])] + [
                                tmp.columns[col] for col in [
                                    'sensor', 'port', 'recontype', 'source',
                                    'targetval', 'value', 'fullvalue', 'info',
                                    'moreinfo']])\
                    .group_by(*(tmp.columns[col] for col in [
                        'addr', 'sensor', 'port', 'recontype', 'source',
                        'targetval', 'value', 'fullvalue', 'info', 'moreinfo'
                    ]))
                )\
                .on_conflict_do_update(
                    index_elements=['addr', 'sensor', 'recontype', 'port',
                                    'source', 'value', 'targetval', 'info'],
                    set_={
                        'firstseen': func.least(
                            self.tables.passive.firstseen,
                            insrt.excluded.firstseen,
                        ),
                        'lastseen': func.greatest(
                            self.tables.passive.lastseen,
                            insrt.excluded.lastseen,
                        ),
                        'count':
                        self.tables.passive.count + insrt.excluded.count,
                    },
                )
            )
            self.db.execute(delete(tmp))
            if config.DEBUG_DB:
                stop_time = time.time()
                time_spent = stop_time - start_time
                total_upserted += count_upserted
                total_time_spent = stop_time - total_start_time
                utils.LOGGER.debug(
                    "DB:PERFORMANCE STATS %s upserts, %f s, %s/s\n"
                    "\ttotal: %s upserts, %f s, %s/s",
                    utils.num2readable(count_upserted),
                    time_spent,
                    utils.num2readable(count_upserted / time_spent),
                    utils.num2readable(total_upserted),
                    total_time_spent,
                    utils.num2readable(total_upserted / total_time_spent),
                )
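
    # A usage sketch (hypothetical names): given an initialised passive-DB
    # backend `db` and an iterable `specs` of (timestamp, spec) pairs,
    #
    #     db.insert_or_update_bulk(specs, getinfos=None,
    #                              separated_timestamps=True)
    #
    # loads the specs into a temporary table via COPY FROM and upserts them
    # into the passive table in batches of config.POSTGRES_BATCH_SIZE.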
Example #39
    def get_export_data(cls):
        if cls.__name__ == 'Payment':
            # Export stats for each payment type separately
            return {}

        purchase_counts = cls.query.outerjoin(cls.purchases).group_by(
            cls.id).with_entities(func.count(models.Ticket.id))
        refund_counts = cls.query.outerjoin(cls.refunds).group_by(
            cls.id).with_entities(func.count(Refund.id))

        cls_version = version_class(cls)
        cls_transaction = transaction_class(cls)
        changes = cls.query.join(cls.versions).group_by(cls.id)
        change_counts = changes.with_entities(func.count(cls_version.id))
        first_changes = changes.join(cls_version.transaction) \
                               .with_entities(func.min(cls_transaction.issued_at).label('created')) \
                               .from_self()

        cls_ver_new = aliased(cls.versions)
        cls_ver_paid = aliased(cls.versions)
        cls_txn_new = aliased(cls_version.transaction)
        cls_txn_paid = aliased(cls_version.transaction)
        active_time = func.max(cls_txn_paid.issued_at) - func.max(
            cls_txn_new.issued_at)
        active_times = cls.query \
            .join(cls_ver_new, cls_ver_new.id == cls.id) \
            .join(cls_ver_paid, cls_ver_paid.id == cls.id) \
            .join(cls_txn_new, cls_txn_new.id == cls_ver_new.transaction_id) \
            .join(cls_txn_paid, cls_txn_paid.id == cls_ver_paid.transaction_id) \
            .filter(cls_ver_new.state == 'new') \
            .filter(cls_ver_paid.state == 'paid') \
            .with_entities(active_time.label('active_time')) \
            .group_by(cls.id)

        time_buckets = [timedelta(0), timedelta(minutes=1), timedelta(hours=1)] + \
                       [timedelta(d) for d in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 28, 60]]

        data = {
            'public': {
                'payments': {
                    'counts': {
                        'purchases':
                        bucketise(purchase_counts,
                                  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20]),
                        'refunds':
                        bucketise(refund_counts, [0, 1, 2, 3, 4]),
                        'changes':
                        bucketise(change_counts, range(10)),
                        'created_week':
                        export_intervals(first_changes, column('created'),
                                         'week', 'YYYY-MM-DD'),
                        'active_time':
                        bucketise([r.active_time for r in active_times],
                                  time_buckets),
                        'amounts':
                        bucketise(
                            cls.query.with_entities(cls.amount_int / 100),
                            [0, 10, 20, 30, 40, 50, 100, 150, 200]),
                    },
                },
            },
            'tables': ['payment', 'payment_version'],
        }

        count_attrs = ['state', 'reminder_sent', 'currency']
        data['public']['payments']['counts'].update(
            export_attr_counts(cls, count_attrs))

        return data
Example #40
File: postgres.py Project: olgermolla/ivre
 def topvalues(self,
               field,
               flt=None,
               topnbr=10,
               sort=None,
               limit=None,
               skip=None,
               least=False):
     """
     This method makes use of the aggregation framework to produce
     top values for a given field or pseudo-field. Pseudo-fields are:
       - category / label / asnum / country / net[:mask]
       - port
       - port:open / :closed / :filtered / :<servicename>
       - portlist:open / :closed / :filtered
       - countports:open / :closed / :filtered
       - service / service:<portnbr>
       - product / product:<portnbr>
       - cpe / cpe.<part> / cpe:<cpe_spec> / cpe.<part>:<cpe_spec>
       - devicetype / devicetype:<portnbr>
       - script:<scriptid> / script:<port>:<scriptid>
         / script:host:<scriptid>
       - cert.* / smb.* / sshkey.*
       - httphdr / httphdr.{name,value} / httphdr:<name>
       - modbus.* / s7.* / enip.*
       - mongo.dbs.*
       - vulns.*
       - screenwords
       - file.* / file.*:scriptid
       - hop
     """
     if flt is None:
         flt = self.flt_empty
     base = flt.query(
         select([self.tables.scan.id
                 ]).select_from(flt.select_from)).cte("base")
     order = "count" if least else desc("count")
     outputproc = None
     if field == "port":
         field = self._topstructure(
             self.tables.port,
             [self.tables.port.protocol, self.tables.port.port],
             self.tables.port.state == "open")
     elif field == "ttl":
         field = self._topstructure(
             self.tables.port,
             [self.tables.port.state_reason_ttl],
             self.tables.port.state_reason_ttl != None,
             # noqa: E711 (BinaryExpression)
         )
     elif field == "ttlinit":
         field = self._topstructure(
             self.tables.port,
             [
                 func.least(
                     255,
                     func.power(
                         2,
                         func.ceil(
                             func.log(2,
                                      self.tables.port.state_reason_ttl))))
             ],
             self.tables.port.state_reason_ttl != None,
             # noqa: E711 (BinaryExpression)
         )
         outputproc = int
     elif field.startswith('port:'):
         info = field[5:]
         field = self._topstructure(
             self.tables.port,
             [self.tables.port.protocol, self.tables.port.port],
             (self.tables.port.state == info)
             if info in ['open', 'filtered', 'closed', 'open|filtered'] else
             (self.tables.port.service_name == info),
         )
     elif field.startswith('countports:'):
         info = field[11:]
         return (
             {"count": result[0], "_id": result[1]}
             for result in self.db.execute(
                 select([func.count().label("count"),
                         column('cnt')])
                 .select_from(
                     select([func.count().label('cnt')])
                     .select_from(self.tables.port)
                     .where(and_(
                         self.tables.port.state == info,
                         # self.tables.port.scan.in_(base),
                         exists(
                             select([1])\
                             .select_from(base)\
                             .where(
                                 self.tables.port.scan == base.c.id
                             )
                         ),
                     ))\
                     .group_by(self.tables.port.scan)\
                     .alias('cnt')
                 ).group_by('cnt').order_by(order).limit(topnbr)
             )
         )
     elif field.startswith('portlist:'):
          ### Two options for filtering:
          ###   -1- self.tables.port.scan.in_(base),
          ###   -2- exists(select([1])\
          ###       .select_from(base)\
          ###       .where(
          ###         self.tables.port.scan == base.c.id
          ###       )),
          ###
          ### From a few tests, option -1- is much faster when (base) is
          ### not very selective, while option -2- is slightly faster
          ### when (base) is highly selective.
          ###
          ### TODO: check whether the same holds for:
          ###  - countports:open
          ###  - all the others
         info = field[9:]
         return (
             {
                 "count":
                 result[0],
                 "_id": [(proto, int(port)) for proto, port in (
                     elt.split(',')
                     for elt in result[1][3:-3].split(')","('))]
             } for result in self.db.execute(
                 select([func.count().label("count"),
                         column('ports')]).
                 select_from(
                     select([
                         func.array_agg(
                             postgresql.aggregate_order_by(
                                 tuple_(self.tables.port.protocol, self.
                                        tables.port.port).label('a'),
                                 tuple_(
                                     self.tables.port.protocol, self.tables.
                                     port.port).label('a'))).label('ports'),
                     ]).where(
                         and_(
                             self.tables.port.state == info,
                             self.tables.port.scan.in_(
                                 base),
                             # exists(select([1])\
                             #        .select_from(base)\
                             #        .where(
                             #            self.tables.port.scan == base.c.id
                             #        )),
                         )).group_by(self.tables.port.scan).alias('ports')
                 ).group_by('ports').order_by(order).limit(topnbr)))
     elif field == "service":
         field = self._topstructure(self.tables.port,
                                    [self.tables.port.service_name],
                                    self.tables.port.state == "open")
     elif field.startswith("service:"):
         info = field[8:]
         if '/' in info:
             info = info.split('/', 1)
             field = self._topstructure(
                 self.tables.port,
                 [self.tables.port.service_name],
                 and_(self.tables.port.protocol == info[0],
                      self.tables.port.port == int(info[1])),
             )
         else:
             field = self._topstructure(self.tables.port,
                                        [self.tables.port.service_name],
                                        self.tables.port.port == int(info))
     elif field == "product":
         field = self._topstructure(
             self.tables.port,
             [
                 self.tables.port.service_name,
                 self.tables.port.service_product
             ],
             self.tables.port.state == "open",
         )
     elif field.startswith("product:"):
         info = field[8:]
         if info.isdigit():
             info = int(info)
             flt = self.flt_and(flt, self.searchport(info))
             field = self._topstructure(
                 self.tables.port,
                 [
                     self.tables.port.service_name,
                     self.tables.port.service_product
                 ],
                 and_(self.tables.port.state == "open",
                      self.tables.port.port == info),
             )
         elif info.startswith('tcp/') or info.startswith('udp/'):
             info = (info[:3], int(info[4:]))
             flt = self.flt_and(flt,
                                self.searchport(info[1], protocol=info[0]))
             field = self._topstructure(
                 self.tables.port,
                 [
                     self.tables.port.service_name,
                     self.tables.port.service_product
                 ],
                 and_(self.tables.port.state == "open",
                      self.tables.port.port == info[1],
                      self.tables.port.protocol == info[0]),
             )
         else:
             flt = self.flt_and(flt, self.searchservice(info))
             field = self._topstructure(
                 self.tables.port,
                 [
                     self.tables.port.service_name,
                     self.tables.port.service_product
                 ],
                 and_(self.tables.port.state == "open",
                      self.tables.port.service_name == info),
             )
     elif field == "devicetype":
         field = self._topstructure(self.tables.port,
                                    [self.tables.port.service_devicetype],
                                    self.tables.port.state == "open")
     elif field.startswith("devicetype:"):
         info = field[11:]
         if info.isdigit():
             info = int(info)
             flt = self.flt_and(flt, self.searchport(info))
             field = self._topstructure(
                 self.tables.port, [self.tables.port.service_devicetype],
                 and_(self.tables.port.state == "open",
                      self.tables.port.port == info))
         elif info.startswith('tcp/') or info.startswith('udp/'):
             info = (info[:3], int(info[4:]))
             flt = self.flt_and(flt,
                                self.searchport(info[1], protocol=info[0]))
             field = self._topstructure(
                 self.tables.port, [self.tables.port.service_devicetype],
                 and_(self.tables.port.state == "open",
                      self.tables.port.port == info[1],
                      self.tables.port.protocol == info[0]))
         else:
             flt = self.flt_and(flt, self.searchservice(info))
             field = self._topstructure(
                 self.tables.port, [self.tables.port.service_devicetype],
                 and_(self.tables.port.state == "open",
                      self.tables.port.service_name == info))
     elif field == "version":
         field = self._topstructure(
             self.tables.port,
             [
                 self.tables.port.service_name,
                 self.tables.port.service_product,
                 self.tables.port.service_version
             ],
             self.tables.port.state == "open",
         )
     elif field.startswith("version:"):
         info = field[8:]
         if info.isdigit():
             info = int(info)
             flt = self.flt_and(flt, self.searchport(info))
             field = self._topstructure(
                 self.tables.port,
                 [
                     self.tables.port.service_name,
                     self.tables.port.service_product,
                     self.tables.port.service_version
                 ],
                 and_(self.tables.port.state == "open",
                      self.tables.port.port == info),
             )
         elif info.startswith('tcp/') or info.startswith('udp/'):
             info = (info[:3], int(info[4:]))
             flt = self.flt_and(flt,
                                self.searchport(info[1], protocol=info[0]))
             field = self._topstructure(
                 self.tables.port,
                 [
                     self.tables.port.service_name,
                     self.tables.port.service_product,
                     self.tables.port.service_version
                 ],
                 and_(self.tables.port.state == "open",
                      self.tables.port.port == info[1],
                      self.tables.port.protocol == info[0]),
             )
         elif ':' in info:
             info = info.split(':', 1)
             flt = self.flt_and(
                 flt, self.searchproduct(info[1], service=info[0]))
             field = self._topstructure(
                 self.tables.port,
                 [
                     self.tables.port.service_name,
                     self.tables.port.service_product,
                     self.tables.port.service_version
                 ],
                 and_(self.tables.port.state == "open",
                      self.tables.port.service_name == info[0],
                      self.tables.port.service_product == info[1]),
             )
         else:
             flt = self.flt_and(flt, self.searchservice(info))
             field = self._topstructure(
                 self.tables.port,
                 [
                     self.tables.port.service_name,
                     self.tables.port.service_product,
                     self.tables.port.service_version
                 ],
                 and_(self.tables.port.state == "open",
                      self.tables.port.service_name == info),
             )
     elif field == "asnum":
         field = self._topstructure(self.tables.scan,
                                    [self.tables.scan.info["as_num"]])
     elif field == "as":
         field = self._topstructure(self.tables.scan, [
             self.tables.scan.info["as_num"],
             self.tables.scan.info["as_name"]
         ])
     elif field == "country":
         field = self._topstructure(self.tables.scan, [
             self.tables.scan.info["country_code"],
             self.tables.scan.info["country_name"]
         ])
     elif field == "city":
         field = self._topstructure(self.tables.scan, [
             self.tables.scan.info["country_code"],
             self.tables.scan.info["city"]
         ])
     elif field == "net" or field.startswith("net:"):
         info = field[4:]
         info = int(info) if info else 24
         field = self._topstructure(
             self.tables.scan,
             [func.set_masklen(text("scan.addr::cidr"), info)],
         )
     elif field == "script" or field.startswith("script:"):
         info = field[7:]
         if info:
             field = self._topstructure(self.tables.script,
                                        [self.tables.script.output],
                                        self.tables.script.name == info)
         else:
             field = self._topstructure(self.tables.script,
                                        [self.tables.script.name])
     elif field in ["category", "categories"]:
         field = self._topstructure(self.tables.category,
                                    [self.tables.category.name])
     elif field.startswith('cert.'):
         subfield = field[5:]
         field = self._topstructure(
             self.tables.script,
             [self.tables.script.data['ssl-cert'][subfield]],
             and_(self.tables.script.name == 'ssl-cert',
                  self.tables.script.data['ssl-cert'].has_key(
                      subfield)))  # noqa: W601 (BinaryExpression)
     elif field == "source":
         field = self._topstructure(self.tables.scan,
                                    [self.tables.scan.source])
     elif field == "domains":
         field = self._topstructure(
             self.tables.hostname,
             [func.unnest(self.tables.hostname.domains)])
     elif field.startswith("domains:"):
         level = int(field[8:]) - 1
         base1 = (select([
             func.unnest(self.tables.hostname.domains).label("domains")
         ]).where(
             exists(
                 select([1]).select_from(base).where(
                     self.tables.hostname.scan == base.c.id))).cte("base1"))
         return ({
             "count": result[1],
             "_id": result[0]
         } for result in self.db.execute(
             select([base1.c.domains,
                     func.count().label("count")]).where(
                         base1.c.domains.op('~')
                         ('^([^\\.]+\\.){%d}[^\\.]+$' %
                          level)).group_by(base1.c.domains).order_by(
                              order).limit(topnbr)))
     elif field == "hop":
         field = self._topstructure(self.tables.hop,
                                    [self.tables.hop.ipaddr])
     elif field.startswith('hop') and field[3] in ':>':
         ttl = int(field[4:])
         field = self._topstructure(
             self.tables.hop,
             [self.tables.hop.ipaddr],
             (self.tables.hop.ttl > ttl) if field[3] == '>' else
             (self.tables.hop.ttl == ttl),
         )
     elif field == 'file' or (field.startswith('file')
                              and field[4] in '.:'):
         if field.startswith('file:'):
             scripts = field[5:]
             if '.' in scripts:
                 scripts, field = scripts.split('.', 1)
             else:
                 field = 'filename'
             scripts = scripts.split(',')
             flt = (self.tables.script.name == scripts[0] if len(scripts)
                    == 1 else self.tables.script.name.in_(scripts))
         else:
             field = field[5:] or 'filename'
             flt = True
         field = self._topstructure(
             self.tables.script,
             [
                 func.jsonb_array_elements(
                     func.jsonb_array_elements(
                         self.tables.script.data['ls']['volumes']).op('->')
                     ('files')).op('->>')(field).label(field)
             ],
             and_(
                 flt,
                 self.tables.script.data.op('@>')(
                     '{"ls": {"volumes": [{"files": []}]}}'),
             ),
         )
     elif field.startswith('modbus.'):
         subfield = field[7:]
         field = self._topstructure(
             self.tables.script,
             [self.tables.script.data['modbus-discover'][subfield]],
             and_(
                 self.tables.script.name == 'modbus-discover',
                  self.tables.script.data['modbus-discover'].has_key(  # noqa: W601
                      subfield)),
          )
     elif field.startswith('s7.'):
         subfield = field[3:]
         field = self._topstructure(
             self.tables.script,
             [self.tables.script.data['s7-info'][subfield]],
             and_(self.tables.script.name == 's7-info',
                   self.tables.script.data['s7-info'].has_key(subfield)),  # noqa: W601
          )
     elif field == 'httphdr':
         flt = self.flt_and(flt, self.searchscript(name="http-headers"))
         field = self._topstructure(
             self.tables.script,
             [
                 column("hdr").op('->>')('name').label("name"),
                 column("hdr").op('->>')('value').label("value")
             ],
             self.tables.script.name == 'http-headers',
             [column("name"), column("value")],
             func.jsonb_array_elements(
                 self.tables.script.data['http-headers']).alias('hdr'),
         )
     elif field.startswith('httphdr.'):
         flt = self.flt_and(flt, self.searchscript(name="http-headers"))
         field = self._topstructure(
             self.tables.script,
             [column("hdr").op('->>')(field[8:]).label("topvalue")],
             self.tables.script.name == 'http-headers',
             [column("topvalue")],
             func.jsonb_array_elements(
                 self.tables.script.data['http-headers']).alias('hdr'),
         )
     elif field.startswith('httphdr:'):
         flt = self.flt_and(flt, self.searchhttphdr(name=field[8:].lower()))
         field = self._topstructure(
             self.tables.script,
             [column("hdr").op('->>')("value").label("value")],
             and_(self.tables.script.name == 'http-headers',
                  column("hdr").op('->>')("name") == field[8:].lower()),
             [column("value")],
             func.jsonb_array_elements(
                 self.tables.script.data['http-headers']).alias('hdr'),
         )
     else:
         raise NotImplementedError()
     s_from = {
         self.tables.script:
         join(self.tables.script, self.tables.port),
         self.tables.port:
         self.tables.port,
         self.tables.category:
         join(self.tables.association_scan_category, self.tables.category),
         self.tables.hostname:
         self.tables.hostname,
         self.tables.hop:
         join(self.tables.trace, self.tables.hop),
     }
     where_clause = {
         self.tables.script:
         self.tables.port.scan == base.c.id,
         self.tables.port:
         self.tables.port.scan == base.c.id,
         self.tables.category:
         self.tables.association_scan_category.scan == base.c.id,
         self.tables.hostname:
         self.tables.hostname.scan == base.c.id,
         self.tables.hop:
         self.tables.trace.scan == base.c.id
     }
     if field.base == self.tables.scan:
         req = flt.query(
             select([func.count().label("count")] +
                    field.fields).select_from(
                        self.tables.scan).group_by(*field.fields))
     else:
         req = (select([func.count().label("count")] +
                       field.fields).select_from(s_from[field.base]))
         if field.extraselectfrom is not None:
             req = req.select_from(field.extraselectfrom)
         req = (req.group_by(
             *(field.fields if field.group_by is None else field.group_by
               )).where(
                   exists(
                       select([1]).select_from(base).where(
                           where_clause[field.base]))))
     if field.where is not None:
         req = req.where(field.where)
     if outputproc is None:
         return ({
             "count": result[0],
             "_id": result[1:] if len(result) > 2 else result[1]
         } for result in self.db.execute(req.order_by(order).limit(topnbr)))
     else:
         return ({
             "count":
             result[0],
             "_id":
             outputproc(result[1:] if len(result) > 2 else result[1])
         } for result in self.db.execute(req.order_by(order).limit(topnbr)))

import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = '2a6c63397399'
down_revision = '9fd4589cc82c'
branch_labels = None
depends_on = None

# OLD/NEW values must be different
OLD_GROUP_USERS = 'user'
NEW_GROUP_USERS = 'users'
OLD_GROUP_ADMIN = 'admin'
NEW_GROUP_ADMIN = 'administrators'
OLD_USER_USERS = OLD_GROUP_USERS
OLD_USER_ADMIN = OLD_GROUP_ADMIN

users = sa.table(
    "users",
    sa.column("id", sa.Integer),
    sa.column("user_name", sa.String),
)
groups = sa.table("groups", sa.column("id", sa.Integer),
                  sa.column("group_name", sa.String),
                  sa.column("member_count", sa.Integer))
users_groups = sa.table(
    "users_groups",
    sa.column("user_id", sa.Integer),
    sa.column("group_id", sa.Integer),
)


def get_users_groups(db_session):
    """
    Fetch current db users and groups.
    """
Example #42
 def build_expression(self):
     assert len(self.parameter) == 2
     return sqlalchemy.column(self.column_name).between(
         sqlalchemy.bindparam(self.parameter[0]),
         sqlalchemy.bindparam(self.parameter[1]))
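For reference, a standalone sketch of what the BETWEEN expression built above compiles to; the column name ("age") and parameter names ("start", "end") are illustrative assumptions, not values taken from the original filter class.

import sqlalchemy

# Hypothetical stand-ins for self.column_name and self.parameter above.
expr = sqlalchemy.column("age").between(
    sqlalchemy.bindparam("start"), sqlalchemy.bindparam("end"))

# Renders roughly as: age BETWEEN :start AND :end
print(expr.compile())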

from alembic import op
import sqlalchemy as sa


def upgrade():
    op.add_column('patch_port', sa.Column('switch_room_id', sa.Integer(), nullable=True))

    host = sa.table('host', sa.column('room_id', sa.Integer),
                            sa.column('id', sa.Integer),
                            sa.column('name', sa.String))
    switch_port = sa.table('switch_port', sa.column('switch_id', sa.Integer),
                                          sa.column('id', sa.Integer))
    switch = sa.table('switch', sa.column('host_id', sa.Integer),
                                sa.column('name', sa.String))
    patch_port = sa.table('patch_port',
                          sa.column('switch_room_id', sa.Integer),
                          sa.column('switch_port_id', sa.Integer),
                          sa.column('id', sa.Integer))

    # Set patch_port.switch_room_id to patch_port.switch_port.switch.host.id
    op.execute(patch_port.update().values(
        switch_room_id=sa.select([host.c.room_id])
                         .select_from(patch_port.alias("patch_port_subselect")
                                      .join(switch_port, patch_port.c.switch_port_id == switch_port.c.id)
                                      .join(host, switch_port.c.switch_id == host.c.id))
                         .where(sa.literal_column('patch_port_subselect.id') == patch_port.c.id)
    ))

    op.alter_column('patch_port', 'switch_room_id', nullable=False)

    op.create_index(op.f('ix_patch_port_switch_room_id'), 'patch_port', ['switch_room_id'], unique=False)
    op.create_foreign_key("patch_port_switch_room_id_fkey", 'patch_port', 'room', ['switch_room_id'], ['id'])

    # Set switch.host.name to switch.name
    op.execute(host.update().values(
        name=sa.select([switch.c.name])
               .select_from(host.alias("host_subselect")
                            .join(switch, switch.c.host_id == host.c.id))
               .where(sa.literal_column('host_subselect.id') == switch.c.host_id)
    ))

    op.drop_column('switch', 'name')

    # Create patch_port_switch_in_switch_room function and trigger
    op.execute('''
        CREATE OR REPLACE FUNCTION patch_port_switch_in_switch_room() RETURNS trigger STABLE STRICT LANGUAGE plpgsql AS $$
        DECLARE
          v_patch_port patch_port;
          v_switch_port_switch_host_room_id integer;
        BEGIN
          v_patch_port := NEW;

          IF v_patch_port.switch_port_id IS NOT NULL THEN
              SELECT h.room_id INTO v_switch_port_switch_host_room_id FROM patch_port pp
                  JOIN switch_port sp ON pp.switch_port_id = sp.id
                  JOIN host h ON sp.switch_id = h.id
                  WHERE pp.id = v_patch_port.id;

              IF v_switch_port_switch_host_room_id <> v_patch_port.switch_room_id THEN
                RAISE EXCEPTION 'A patch-port can only be patched to a switch that is located in the switch-room of
                                  the patch-port';
              END IF;
          END IF;
          RETURN NULL;
        END;
        $$
    ''')
    op.execute('''
        CREATE CONSTRAINT TRIGGER patch_port_switch_in_switch_room_trigger
        AFTER INSERT OR UPDATE
        ON patch_port
        DEFERRABLE INITIALLY DEFERRED
        FOR EACH ROW EXECUTE PROCEDURE patch_port_switch_in_switch_room()
    ''')

    op.create_unique_constraint("switch_port_name_switch_id_key", 'switch_port', ['name', 'switch_id'])
Example #44
import datetime as dt

from alembic import op
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as sa_pg

# revision identifiers, used by Alembic.
revision = 'a435563da77e'
down_revision = '1466e2297214'
branch_labels = None
depends_on = None

bank_holiday = sa.table("bank_holiday_date",
                        sa.column("holiday_ref", sa.Integer),
                        sa.column("date", sa.Date))

date_f = "%Y-%m-%d"
bank_holiday_dates = [
    {
        "holiday_ref": 1,
        "date": "2020-01-01"
    },
    {
        "holiday_ref": 2,
        "date": "2020-01-02"
    },
    {
        "holiday_ref": 3,
        "date": "2020-04-10"
import sqlalchemy as sa
from alembic import op
from alembic.context import get_context  # noqa: F401
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.orm.session import sessionmaker

# revision identifiers, used by Alembic.
revision = "a395ef9d3fe6"
down_revision = "ae1a3c8c7860"
branch_labels = None
depends_on = None

Session = sessionmaker()

resources = sa.table("resources", sa.column("root_service_id", sa.Integer),
                     sa.column("resource_id", sa.Integer),
                     sa.column("parent_id", sa.Integer))


def upgrade():
    context = get_context()
    session = Session(bind=op.get_bind())

    # two following lines avoids double "DELETE" erroneous call when deleting group due to incorrect checks
    # https://stackoverflow.com/questions/28824401
    context.connection.engine.dialect.supports_sane_rowcount = False
    context.connection.engine.dialect.supports_sane_multi_rowcount = False

    if isinstance(context.connection.engine.dialect, PGDialect):
        op.add_column(
Example #46
import pandas
import sqlalchemy


def read_sql_table(engine,
                   table_name,
                   index_col=None,
                   columns=None,
                   select_from=None,
                   limit=None,
                   order_by=None,
                   where=None,
                   coerce_types=None,
                   raise_on_missing=True):
    """ Load a table from a SQL database.
    
    Parameters
    ----------
    engine : SQLAlchemy engine
        The SQL database to load from.
    
    table_name : str
        The name of the table to load.
    
    index_col : str, optional
        Column name to use as index for the returned data frame.
    
    columns : sequence of str, optional
        Columns to select from the table. By default, all columns are selected.

    select_from : str or SQLAlchemy clause, optional
        A FROM clause to use for the select statement. Defaults to the
        table name.
    
    limit : int, optional
        Limit the number of rows selected.
    
    order_by : str or SQLAlchemy clause, optional
        An ORDER BY clause to sort the selected rows.
    
    where : str or SQLAlchemy clause, optional
        A WHERE clause used to filter the selected rows.
    
    coerce_types : dict(str : dtype or Python type), optional
        Override pandas type inference for specific columns.

    raise_on_missing : bool, optional
        If True (the default), raise an error when the table does not exist;
        otherwise return None.

    Returns
    -------
    A pandas DataFrame, or None if the table is missing and `raise_on_missing`
    is False.
    """
    # Pandas does not expose many of these options, so we pull out some of
    # Pandas' internals.
    #
    # An alternative approach would be to use `pandas.read_sql_query` with an
    # appropriate (dialect-specific) query. However, this approach would not
    # utilize Pandas' logic for column type inference (performed by
    # `_harmonize_columns()` below), and would hence produce inferior results.

    from sqlalchemy.schema import MetaData
    from pandas.io.sql import SQLDatabase, SQLTable

    # From pandas.io.sql.read_sql_table
    # and  pandas.io.sql.SQLDatabase.read_table:
    meta = MetaData(engine)
    try:
        meta.reflect(only=[table_name])
    except sqlalchemy.exc.InvalidRequestError:
        if raise_on_missing:
            raise ValueError("Table %s not found" % table_name)
        else:
            return None

    pd_db = SQLDatabase(engine, meta=meta)
    pd_tbl = SQLTable(table_name, pd_db, index=None)

    # Adapted from pandas.io.SQLTable.read:
    if columns is not None and len(columns) > 0:
        if index_col is not None and index_col not in columns:
            columns = [index_col] + columns

        cols = [pd_tbl.table.c[n] for n in columns]
    else:
        cols = pd_tbl.table.c

    if pd_tbl.index is not None:
        [cols.insert(0, pd_tbl.table.c[idx]) for idx in pd_tbl.index[::-1]]

    # Strip the table name from each of the column names to allow for more
    # general FROM clauses.
    sql_select = sqlalchemy.select([
        sqlalchemy.column(str(c).replace('{}.'.format(table_name), '', 1))
        for c in cols
    ])

    if select_from is not None:
        sql_select = sql_select.select_from(select_from)
    else:
        sql_select = sql_select.select_from(sqlalchemy.table(table_name))

    if where is not None:
        if isinstance(where, basestring):
            where = sqlalchemy.text(where)
        sql_select = sql_select.where(where)
    if limit is not None:
        sql_select = sql_select.limit(limit)
    if order_by is not None:
        if isinstance(order_by, basestring):
            order_by = sqlalchemy.sql.column(order_by)
        sql_select = sql_select.order_by(order_by)

    result = pd_db.execute(sql_select)
    data = result.fetchall()
    column_names = result.keys()

    pd_tbl.frame = pandas.DataFrame.from_records(data,
                                                 index=index_col,
                                                 columns=column_names)

    # This line has caused issues with incorrect type inference -- add it
    # back with caution.
    # pd_tbl._harmonize_columns()

    # Added by me: coerce types
    if coerce_types:
        frame = pd_tbl.frame
        for col, dtype in coerce_types.iteritems():
            frame[col] = frame[col].astype(dtype, copy=False)

    return pd_tbl.frame
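A minimal usage sketch for read_sql_table under stated assumptions: the SQLite URL, table name and column names below are hypothetical, and the call otherwise follows the parameters documented in the docstring above.

import sqlalchemy

# Hypothetical database and table, purely for illustration.
engine = sqlalchemy.create_engine("sqlite:///example.db")
df = read_sql_table(engine, "measurements",
                    index_col="id",
                    columns=["id", "value"],
                    where=sqlalchemy.text("value > 0"),
                    limit=100)
if df is not None:
    print(df.head())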
 def test_oracle_sqla_column_name_length_exceeded(self):
     col = column('This_Is_32_Character_Column_Name')
     label = OracleEngineSpec.make_label_compatible(col.name)
     self.assertEqual(label.quote, True)
     label_expected = '3b26974078683be078219674eeb8f5'
     self.assertEqual(label, label_expected)
Example #48
    def get_compute_domain(
        self,
        domain_kwargs: Dict,
        domain_type: Union[str, "MetricDomainTypes"],
        accessor_keys: Optional[Iterable[str]] = None,
    ) -> Tuple["sa.sql.Selectable", dict, dict]:
        """Uses a given batch dictionary and domain kwargs to obtain a SqlAlchemy column object.

        Args:
            domain_kwargs (dict) - A dictionary consisting of the domain kwargs specifying which data to obtain
            domain_type (str or "MetricDomainTypes") - an Enum value indicating which metric domain the user would
            like to be using, or a corresponding string value representing it. String types include "identity", "column",
            "column_pair", "table" and "other". Enum types include capitalized versions of these from the class
            MetricDomainTypes.
            accessor_keys (str iterable) - keys that are part of the compute domain but should be ignored when describing
            the domain and simply transferred with their associated values into accessor_domain_kwargs.

        Returns:
            A tuple of (selectable, compute_domain_kwargs, accessor_domain_kwargs):
            the SqlAlchemy selectable to compute on, the compute domain kwargs used
            for the computation, and the accessor domain kwargs split off from them.
        """
        # Extracting value from enum if it is given for future computation
        domain_type = MetricDomainTypes(domain_type)
        batch_id = domain_kwargs.get("batch_id")
        if batch_id is None:
            # We allow no batch id specified if there is only one batch
            if self.active_batch_data:
                data_object = self.active_batch_data
            else:
                raise GreatExpectationsError(
                    "No batch is specified, but could not identify a loaded batch."
                )
        else:
            if batch_id in self.loaded_batch_data_dict:
                data_object = self.loaded_batch_data_dict[batch_id]
            else:
                raise GreatExpectationsError(
                    f"Unable to find batch with batch_id {batch_id}")

        compute_domain_kwargs = copy.deepcopy(domain_kwargs)
        accessor_domain_kwargs = dict()
        if "table" in domain_kwargs and domain_kwargs["table"] is not None:
            # TODO: Add logic to handle record_set_name once implemented
            # (i.e. multiple record sets (tables) in one batch)
            if domain_kwargs["table"] != data_object.selectable.name:
                selectable = sa.Table(
                    domain_kwargs["table"],
                    sa.MetaData(),
                    schema=data_object._schema_name,
                )
            else:
                selectable = data_object.selectable
        elif "query" in domain_kwargs:
            raise ValueError(
                "query is not currently supported by SqlAlchemyExecutionEngine"
            )
        else:
            selectable = data_object.selectable

        if ("row_condition" in domain_kwargs
                and domain_kwargs["row_condition"] is not None):
            condition_parser = domain_kwargs["condition_parser"]
            if condition_parser == "great_expectations__experimental__":
                parsed_condition = parse_condition_to_sqlalchemy(
                    domain_kwargs["row_condition"])
                selectable = sa.select("*",
                                       from_obj=selectable,
                                       whereclause=parsed_condition)

            else:
                raise GreatExpectationsError(
                    "SqlAlchemyExecutionEngine only supports the great_expectations condition_parser."
                )

        # Warning user if accessor keys are in any domain that is not of type table, will be ignored
        if (domain_type != MetricDomainTypes.TABLE
                and accessor_keys is not None and len(accessor_keys) > 0):
            logger.warning(
                "Accessor keys ignored since Metric Domain Type is not 'table'"
            )

        if domain_type == MetricDomainTypes.TABLE:
            if accessor_keys is not None and len(accessor_keys) > 0:
                for key in accessor_keys:
                    accessor_domain_kwargs[key] = compute_domain_kwargs.pop(
                        key)
            if len(domain_kwargs.keys()) > 0:
                for key in compute_domain_kwargs.keys():
                    # Warning user if kwarg not "normal"
                    if key not in [
                            "batch_id",
                            "table",
                            "row_condition",
                            "condition_parser",
                    ]:
                        logger.warning(
                            f"Unexpected key {key} found in domain_kwargs for domain type {domain_type.value}"
                        )
            return selectable, compute_domain_kwargs, accessor_domain_kwargs

        # If the user wants a column domain, check that a column was provided
        elif domain_type == MetricDomainTypes.COLUMN:
            if "column" in compute_domain_kwargs:
                # Checking if case-sensitive and using the appropriate name
                if self.active_batch_data.use_quoted_name:
                    accessor_domain_kwargs["column"] = quoted_name(
                        compute_domain_kwargs.pop("column"))
                else:
                    accessor_domain_kwargs[
                        "column"] = compute_domain_kwargs.pop("column")
            else:
                # If column not given
                raise GreatExpectationsError(
                    "Column not provided in compute_domain_kwargs")

        # Else, if column pair values requested
        elif domain_type == MetricDomainTypes.COLUMN_PAIR:
            # Ensuring column_A and column_B parameters provided
            if ("column_A" in compute_domain_kwargs
                    and "column_B" in compute_domain_kwargs):
                if self.active_batch_data.use_quoted_name:
                    # If case matters...
                    accessor_domain_kwargs["column_A"] = quoted_name(
                        compute_domain_kwargs.pop("column_A"))
                    accessor_domain_kwargs["column_B"] = quoted_name(
                        compute_domain_kwargs.pop("column_B"))
                else:
                    accessor_domain_kwargs[
                        "column_A"] = compute_domain_kwargs.pop("column_A")
                    accessor_domain_kwargs[
                        "column_B"] = compute_domain_kwargs.pop("column_B")
            else:
                raise GreatExpectationsError(
                    "column_A or column_B not found within compute_domain_kwargs"
                )

        # Else, if a multicolumn domain is requested, move the columns into the accessor kwargs
        elif domain_type == MetricDomainTypes.MULTICOLUMN:
            if "columns" in compute_domain_kwargs:
                # If columns exist
                accessor_domain_kwargs["columns"] = compute_domain_kwargs.pop(
                    "columns")

        # Filtering if identity
        elif domain_type == MetricDomainTypes.IDENTITY:
            # If we would like our data to become a single column
            if "column" in compute_domain_kwargs:
                if self.active_batch_data.use_quoted_name:
                    selectable = sa.select([
                        sa.column(quoted_name(compute_domain_kwargs["column"]))
                    ]).select_from(selectable)
                else:
                    selectable = sa.select([
                        sa.column(compute_domain_kwargs["column"])
                    ]).select_from(selectable)

            # If we would like our data to now become a column pair
            elif ("column_A"
                  in compute_domain_kwargs) and ("column_B"
                                                 in compute_domain_kwargs):
                if self.active_batch_data.use_quoted_name:
                    selectable = sa.select([
                        sa.column(
                            quoted_name(compute_domain_kwargs["column_A"])),
                        sa.column(
                            quoted_name(compute_domain_kwargs["column_B"])),
                    ]).select_from(selectable)
                else:
                    selectable = sa.select([
                        sa.column(compute_domain_kwargs["column_A"]),
                        sa.column(compute_domain_kwargs["column_B"]),
                    ]).select_from(selectable)
            else:
                # If we would like our data to become a multicolumn
                if "columns" in compute_domain_kwargs:
                    if self.active_batch_data.use_quoted_name:
                        # Building a list of column objects used for sql alchemy selection
                        to_select = [
                            sa.column(quoted_name(col))
                            for col in compute_domain_kwargs["columns"]
                        ]
                        selectable = sa.select(to_select).select_from(
                            selectable)
                    else:
                        to_select = [
                            sa.column(col)
                            for col in compute_domain_kwargs["columns"]
                        ]
                        selectable = sa.select(to_select).select_from(
                            selectable)

        # Letting selectable fall through
        return selectable, compute_domain_kwargs, accessor_domain_kwargs
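As a rough illustration of the interface described in the docstring above, a single-column domain request might look like the following; the engine instance, batch id and column name are assumptions, not values from the source.

# Hypothetical call against an already-configured SqlAlchemyExecutionEngine.
selectable, compute_kwargs, accessor_kwargs = execution_engine.get_compute_domain(
    domain_kwargs={"batch_id": "batch_1", "column": "passenger_count"},
    domain_type=MetricDomainTypes.COLUMN,
)
# selectable      -> the table (or filtered select) to compute on
# compute_kwargs  -> the domain kwargs minus the accessor parts
# accessor_kwargs -> {"column": "passenger_count"}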
Example #49
class TestReport(reports.Report):
    """
    Writes report to the database.
    """
    import sqlalchemy as sa
    SUPPORTED_EXPERIMENTS = ['pj-test']
    SCHEMA = [papi.Event.__table__]

    QUERY_TOTAL = \
        sa.sql.select([
            sa.column('project'),
            sa.column('domain'),
            sa.column('speedup'),
            sa.column('ohcov_0'),
            sa.column('ohcov_1'),
            sa.column('dyncov_0'),
            sa.column('dyncov_1'),
            sa.column('cachehits_0'),
            sa.column('cachehits_1'),
            sa.column('variants_0'),
            sa.column('variants_1'),
            sa.column('codegen_0'),
            sa.column('codegen_1'),
            sa.column('scops_0'),
            sa.column('scops_1'),
            sa.column('t_0'),
            sa.column('o_0'),
            sa.column('t_1'),
            sa.column('o_1')
        ]).\
        select_from(
            sa.func.pj_test_eval(sa.sql.bindparam('exp_ids'))
        )

    QUERY_REGION = \
        sa.sql.select([
            sa.column('project'),
            sa.column('region'),
            sa.column('cores'),
            sa.column('t_polly'),
            sa.column('t_polyjit'),
            sa.column('speedup')
        ]).\
        select_from(
            sa.func.pj_test_region_wise(sa.sql.bindparam('exp_ids'))
        )

    def report(self):
        print("I found the following matching experiment ids")
        print("  \n".join([str(x) for x in self.experiment_ids]))

        qry = TestReport.QUERY_TOTAL.unique_params(exp_ids=self.experiment_ids)
        yield ("complete",
               ('project', 'domain', 'speedup', 'ohcov_0', 'ohcov_1',
                'dyncov_0', 'dyncov_1', 'cachehits_0', 'cachehits_1',
                'variants_0', 'variants_1', 'codegen_0', 'codegen_1',
                'scops_0', 'scops_1', 't_0', 'o_0', 't_1', 'o_1'),
               self.session.execute(qry).fetchall())
        qry = TestReport.QUERY_REGION.unique_params(
            exp_ids=self.experiment_ids)
        yield ("regions", ('project', 'region', 'cores',
                           'T_Polly', 'T_PolyJIT', 'speedup'),
               self.session.execute(qry).fetchall())

    def generate(self):
        for name, header, data in self.report():
            fname = os.path.basename(self.out_path)

            fname = "{prefix}_{name}{ending}".format(
                prefix=os.path.splitext(fname)[0],
                ending=os.path.splitext(fname)[-1],
                name=name)
            with open(fname, 'w') as csv_out:
                print("Writing '{0}'".format(csv_out.name))
                csv_writer = csv.writer(csv_out)
                csv_writer.writerows([header])
                csv_writer.writerows(data)
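The two report queries above bind their experiment ids through unique_params at execution time; a small sketch of that pattern in isolation (the function name pj_test_eval comes from QUERY_TOTAL above, the id value is an assumption).

import sqlalchemy as sa

# Select from a table-valued SQL function; exp_ids is bound when the query runs.
qry = sa.sql.select([sa.column('project'), sa.column('speedup')]).select_from(
    sa.func.pj_test_eval(sa.sql.bindparam('exp_ids')))
qry = qry.unique_params(exp_ids=['hypothetical-experiment-id'])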
Example #50
 def test_oracle_time_expression_reserved_keyword_1m_grain(self):
     col = column("decimal")
     expr = OracleEngineSpec.get_timestamp_expr(col, None, "P1M")
     result = str(expr.compile(dialect=oracle.dialect()))
     self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
Example #51
def messages():
    if not db.engine.dialect.has_table(db.engine, "messages"):
        db.create_all()

    columns = Message.__table__.columns.keys()
    columns = [i for i in columns if not i.startswith('_')]

    columns_exclude = ['deleted']
    required_columns = [i for i in columns if i not in columns_exclude]

    draw = request.values.get('draw')
    if draw:
        start = request.values.get('start', '')
        length = request.values.get('length', '')
        _ = request.values.get('_')

        requested_params = parse_multi_form(
            request.values)  # columns, order, search

        start = int(start) if start.isdigit() else 0
        length = int(length) if length.isdigit() else limit_default

        s = db.session.query(Message).filter_by(deleted=False)
        total = s.count()

        column_params = requested_params.get('columns', {})
        column_names = dict([(i, column_params.get(i, {}).get('data'))
                             for i in column_params.keys()])
        column_searchables = dict([(i, column_params.get(i,
                                                         {}).get('searchable'))
                                   for i in column_params.keys()])
        column_searchables = [
            column_names.get(k) for k, v in column_searchables.items()
            if v == 'true'
        ]

        search_params = requested_params.get('search', {})
        search = search_params.get('value')
        regex = search_params.get('regex')

        criterion = or_(*[column(i).like("%{0}%".format(search)) for i in column_searchables]) \
            if search else None

        if criterion is None:
            filtered = total
        else:
            s = s.filter(criterion)
            filtered = s.count()

        order_params = requested_params.get('order', {})
        order = []
        for i in sorted(order_params.keys()):
            column_sort_dict = order_params.get(i)
            column_id = int(column_sort_dict.get('column', ''))
            sort_dir = column_sort_dict.get('dir', '')
            sort_column = column_names.get(column_id)
            if sort_column:
                c = desc(column(sort_column)
                         ) if sort_dir == 'desc' else column(sort_column)
                order.append(c)

        if order:
            s = s.order_by(*order)

        if start and start < filtered:
            s = s.offset(start)
        if length:
            s = s.limit(length)

        rows = s.all()
        rows = [
            dict(row_iter(required_columns, row, start + j))
            for j, row in enumerate(rows, 1)
        ]

        return render_ext(
            format="json",
            # Value comes from the request but is not used for output (considered safe)
            draw=draw,
            recordsTotal=total,
            recordsFiltered=filtered,
            # Request-derived data must be masked (potentially unsafe); masking is done in row_iter
            data=rows,
        )

    required_columns.remove('id')

    columns_dictionary = dict(
        name="Name",
        author="Author",
        message="Message",
        created="Created",
        updated="Updated",
    )
    names = [columns_dictionary.get(i) or i for i in required_columns] + ['']

    return render_template(
        "message/messages.html",
        title="Messages",
        required_columns=required_columns,
        names=names,
        seq=True,
        column_info={
            "url": ["id", "fa fa-info", "fa fa-info",
                    url_for('message_info')]
        },
        columns_extra=[
            [
                "id", "fa fa-edit", "", "fa fa-edit", "",
                url_for('message_edit')
            ],
        ],
    )
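The search criterion built inside messages() ORs a LIKE filter over every searchable column; a standalone sketch of the resulting clause, with hypothetical column names and search text.

from sqlalchemy import column, or_

searchable = ["name", "author", "message"]  # assumed searchable columns
search = "hello"
criterion = or_(*[column(c).like("%{0}%".format(search)) for c in searchable])

# Renders roughly as:
#   name LIKE :name_1 OR author LIKE :author_1 OR message LIKE :message_1
print(criterion.compile())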
Example #52
File: base.py  Project: fred2305/superset
 def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]:
     return [column(c["name"]) for c in cols]
Example #53
 def test_pg_time_expression_mixed_case_column_1y_grain(self):
     col = column("MixedCase")
     expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
     result = str(expr.compile(dialect=postgresql.dialect()))
     self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
Example #54
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '1466e2297214'
down_revision = '74069f6388f3'
branch_labels = None
depends_on = None


stop_area = sa.table(
    "stop_area",
    sa.column("code", sa.VARCHAR(12)),
    sa.column("active", sa.Boolean)
)
stop_point = sa.table(
    "stop_point",
    sa.column("atco_code", sa.VARCHAR(12)),
    sa.column("active", sa.Boolean)
)


def upgrade():
    op.create_foreign_key(op.f('journey_pattern_region_ref_fkey'), 'journey_pattern', 'region', ['region_ref'], ['code'], ondelete='CASCADE')

    op.add_column('stop_area', sa.Column('active', sa.Boolean(), nullable=True))
    op.add_column('stop_point', sa.Column('active', sa.Boolean(), nullable=True))
    op.execute(sa.update(stop_area).values(active=True))
Example #55
 def test_unicode_text_literal_binds(self):
     self.assert_compile(
         column("x", UnicodeText()) == "foo",
         "x = N'foo'",
         literal_binds=True,
     )
Example #56
 def test_string_text_literal_binds_explicit_unicode_right(self):
     self.assert_compile(
         column("x", String()) == util.u("foo"),
         "x = 'foo'",
         literal_binds=True,
     )
Example #57
 def test_string_text_literal_binds(self):
     self.assert_compile(
         column("x", String()) == "foo", "x = 'foo'", literal_binds=True
     )
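The three literal_binds tests above can be reproduced outside the test harness; a sketch using the default dialect (the N'foo' form in the UnicodeText case additionally assumes the mssql dialect).

from sqlalchemy import column, String

expr = column("x", String()) == "foo"
# literal_binds inlines the bound value into the emitted SQL.
print(expr.compile(compile_kwargs={"literal_binds": True}))  # x = 'foo'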
Example #58
 def build_expression(self):
     return sqlalchemy.column(self.column_name).is_distinct_from(
         sqlalchemy.bindparam(self.parameter))
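A sketch of what the is_distinct_from expression above compiles to on the default dialect; the column and parameter names are illustrative only.

import sqlalchemy

expr = sqlalchemy.column("status").is_distinct_from(
    sqlalchemy.bindparam("expected"))
# Renders roughly as: status IS DISTINCT FROM :expected
print(expr.compile())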
Example #59
    def get_alarms_count(self,
                         tenant_id,
                         query_parms=None,
                         offset=None,
                         limit=None):
        if not query_parms:
            query_parms = {}

        with self._db_engine.connect() as conn:
            parms = {}
            ad = self.ad
            am = self.am
            mdd = self.mdd
            mde = self.mde
            md = self.md
            a = self.a

            query_from = a.join(ad, ad.c.id == a.c.alarm_definition_id)

            parms['b_tenant_id'] = tenant_id

            group_by_columns = []

            if 'group_by' in query_parms:
                group_by_columns = query_parms['group_by']
                sub_group_by_columns = []
                metric_group_by = {
                    'metric_name', 'dimension_name', 'dimension_value'
                }.intersection(set(query_parms['group_by']))
                if metric_group_by:
                    sub_query_columns = [am.c.alarm_id]
                    if 'metric_name' in metric_group_by:
                        sub_group_by_columns.append(
                            mde.c.name.label('metric_name'))
                    if 'dimension_name' in metric_group_by:
                        sub_group_by_columns.append(
                            md.c.name.label('dimension_name'))
                    if 'dimension_value' in metric_group_by:
                        sub_group_by_columns.append(
                            md.c.value.label('dimension_value'))

                    sub_query_columns.extend(sub_group_by_columns)

                    sub_query_from = (mde.join(
                        mdd, mde.c.id == mdd.c.metric_definition_id).join(
                            md, mdd.c.metric_dimension_set_id ==
                            md.c.dimension_set_id).join(
                                am, am.c.metric_definition_dimensions_id ==
                                mdd.c.id))

                    sub_query = (select(sub_query_columns).select_from(
                        sub_query_from).distinct().alias('metrics'))

                    query_from = query_from.join(
                        sub_query, sub_query.c.alarm_id == a.c.id)

            query_columns = [func.count().label('count')]
            query_columns.extend([column(col) for col in group_by_columns])

            query = (select(query_columns).select_from(query_from).where(
                ad.c.tenant_id == bindparam('b_tenant_id')))

            parms['b_tenant_id'] = tenant_id

            if 'alarm_definition_id' in query_parms:
                parms['b_alarm_definition_id'] = query_parms[
                    'alarm_definition_id']
                query = query.where(
                    ad.c.id == bindparam('b_alarm_definition_id'))

            if 'state' in query_parms:
                parms['b_state'] = query_parms['state'] if six.PY3 else \
                    query_parms['state'].encode('utf8')
                query = query.where(a.c.state == bindparam('b_state'))

            if 'severity' in query_parms:
                severities = query_parms['severity'].split('|')
                query = query.where(
                    or_(ad.c.severity == bindparam('b_severity' + str(i))
                        for i in range(len(severities))))
                for i, s in enumerate(severities):
                    parms['b_severity' +
                          str(i)] = s if six.PY3 else s.encode('utf8')

            if 'lifecycle_state' in query_parms:
                parms['b_lifecycle_state'] = query_parms['lifecycle_state'] if six.PY3 else \
                    query_parms['lifecycle_state'].encode('utf8')
                query = query.where(
                    a.c.lifecycle_state == bindparam('b_lifecycle_state'))

            if 'link' in query_parms:
                parms['b_link'] = query_parms['link'] if six.PY3 else \
                    query_parms['link'].encode('utf8')
                query = query.where(a.c.link == bindparam('b_link'))

            if 'state_updated_start_time' in query_parms:
                date_str = query_parms['state_updated_start_time'] if six.PY3 \
                    else query_parms['state_updated_start_time'].encode('utf8')
                date_param = datetime.strptime(date_str,
                                               '%Y-%m-%dT%H:%M:%S.%fZ')
                parms['b_state_updated_at'] = date_param
                query = query.where(
                    a.c.state_updated_at >= bindparam('b_state_updated_at'))

            if 'metric_name' in query_parms:
                query = query.where(a.c.id.in_(self.get_a_am_query))
                parms['b_md_name'] = query_parms['metric_name'] if six.PY3 else \
                    query_parms['metric_name'].encode('utf8')

            if 'metric_dimensions' in query_parms:
                sub_query = select([a.c.id])
                sub_query_from = (a.join(am, am.c.alarm_id == a.c.id).join(
                    mdd, mdd.c.id == am.c.metric_definition_dimensions_id))

                sub_query_md_base = select([md.c.dimension_set_id
                                            ]).select_from(md)

                for i, metric_dimension in enumerate(
                        query_parms['metric_dimensions'].items()):
                    dimension_value = metric_dimension[1] if six.PY3 else \
                        metric_dimension[1].encode('utf8')

                    if '|' in dimension_value:
                        dimension_value = tuple(dimension_value.split('|'))

                    md_name = "b_md_name_{}".format(i)
                    md_value = "b_md_value_{}".format(i)

                    sub_query_md = (sub_query_md_base.where(
                        md.c.name == bindparam(md_name)))

                    if isinstance(dimension_value, tuple):
                        sub_query_md = (sub_query_md.where(
                            md.c.value.op('IN')(bindparam(md_value))))
                    else:
                        sub_query_md = (sub_query_md.where(
                            md.c.value == bindparam(md_value)))

                    sub_query_md = (sub_query_md.distinct().alias(
                        'md_{}'.format(i)))

                    sub_query_from = (sub_query_from.join(
                        sub_query_md, sub_query_md.c.dimension_set_id ==
                        mdd.c.metric_dimension_set_id))

                    parms[md_name] = metric_dimension[0] if six.PY3 else \
                        metric_dimension[0].encode('utf8')
                    parms[md_value] = dimension_value

                sub_query = (
                    sub_query.select_from(sub_query_from).distinct())
                query = query.where(a.c.id.in_(sub_query))

            if group_by_columns:
                query = (query.order_by(*group_by_columns).group_by(
                    *group_by_columns))

            if limit:
                query = query.limit(bindparam('b_limit'))
                parms['b_limit'] = limit + 1

            if offset:
                query = query.offset(bindparam('b_offset'))
                parms['b_offset'] = offset

            query = query.distinct()
            return [dict(row) for row in conn.execute(query, parms).fetchall()]
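A hedged sketch of how a repository exposing the method above might be queried; the tenant id and filter values are assumptions chosen only to exercise the documented query_parms keys.

# Hypothetical repository instance and filter values.
counts = repo.get_alarms_count(
    "some-tenant-id",
    query_parms={
        "state": "ALARM",
        "severity": "HIGH|CRITICAL",
        "group_by": ["metric_name"],
    },
    limit=10,
)
# Each row is a dict such as {"count": 3, "metric_name": "cpu.idle_perc"}.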
Example #60
 def test_pg_time_expression_lower_column_no_grain(self):
     col = column("lower_case")
     expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
     result = str(expr.compile(dialect=postgresql.dialect()))
     self.assertEqual(result, "lower_case")