Example #1
def test_orm_order_by_bundle(dburl):
    Scorecard = Bundle(
        "scorecard",
        # CW: existential horror
        Book.score.label("popularity"),
        Book.popularity.label("score"),
    )

    with S(dburl, echo=ECHO) as s:
        q = s.query(Book).order_by(Scorecard, Book.id)
        check_paging_orm(q=q)
        q = s.query(Book, Scorecard).order_by(Book.id)
        check_paging_orm(q=q)
        q = s.query(Scorecard).order_by(Scorecard.c.popularity, Book.id)
        check_paging_orm(q=q)
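The ORM examples above and below query a mapped Book model with columns such as score and popularity. A minimal sketch of what such a model could look like, with column names inferred from the queries rather than taken from the project's actual test models:

from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

Base = declarative_base()


class Author(Base):
    __tablename__ = "authors"
    id = Column(Integer, primary_key=True)
    name = Column(String)


class Book(Base):
    __tablename__ = "books"
    id = Column(Integer, primary_key=True)
    name = Column(String)
    score = Column(Integer)
    popularity = Column(Integer)
    author_id = Column(Integer, ForeignKey("authors.id"))
    author = relationship(Author, backref="books")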
Example #2
def test_createdrop(tmpdir):
    sqlite_path = str(tmpdir / "testonly.db")

    urls = [
        "postgresql:///sqlbag_testonly", "mysql+pymysql:///sqlbag_testonly"
    ]

    for db_url in urls:
        drop_database(db_url)
        assert not drop_database(db_url)
        assert not exists(db_url)
        assert create_database(db_url)
        assert exists(db_url)

        if db_url.startswith("postgres"):
            assert create_database(db_url,
                                   template="template1",
                                   wipe_if_existing=True)
        else:
            assert create_database(db_url, wipe_if_existing=True)
        assert exists(db_url)
        assert drop_database(db_url)
        assert not exists(db_url)

    db_url = "sqlite://"  # in-memory special case

    assert exists(db_url)
    assert not create_database(db_url)
    assert exists(db_url)
    assert not drop_database(db_url)
    assert exists(db_url)

    db_url = "sqlite:///" + sqlite_path

    assert not database_exists(db_url)
    # selecting works because sqlite auto-creates
    assert database_exists(db_url, test_can_select=True)
    drop_database(db_url)
    create_database(db_url)
    assert exists(db_url)

    drop_database(db_url)
    assert not database_exists(db_url)
    assert database_exists(db_url, test_can_select=True)

    with temporary_database("sqlite") as dburi:
        with S(dburi) as s:
            s.execute("select 1")
Example #3
def test_orm_joined_inheritance(joined_inheritance_dburl):
    with S(joined_inheritance_dburl, echo=ECHO) as s:
        q = s.query(Animal).order_by(Animal.leg_count, Animal.id)
        check_paging_orm(q=q)

        q = s.query(Vertebrate).order_by(Vertebrate.vertebra_count, Animal.id)
        check_paging_orm(q=q)

        q = s.query(Mammal).order_by(Mammal.nipple_count, Mammal.leg_count, Mammal.id)
        check_paging_orm(q=q)

        # Mix up accessing columns at various hierarchy levels
        q = s.query(Mammal).order_by(
            Mammal.nipple_count, Mammal.leg_count, Vertebrate.vertebra_count, Animal.id
        )
        check_paging_orm(q=q)
Example #4
def test_orm_subquery(dburl):
    count = func.count().label("count")

    with S(dburl, echo=ECHO) as s:
        sq = (
            s.query(Author.id, count)
            .join(Author.books)
            .group_by(Author.id)
            .subquery("sq")
        )
        q = (
            s.query(Author)
            .join(sq, sq.c.id == Author.id)
            .with_entities(sq.c.count, Author)
            .order_by(desc(sq.c.count), Author.name, Author.id)
        )
        check_paging_orm(q=q)
Example #5
def dburl(request):
    count = 10
    data = []

    for x in range(count):
        b = Book(name='Book {}'.format(x), a=x, b=x % 2, c=count - x, d=99)

        if x == 1:
            b.a = None
            b.author = Author(name='Willy Shakespeare')

        data.append(b)

    with temporary_database(request.param) as dburl:
        with S(dburl) as s:
            Base.metadata.create_all(s.connection())
            s.add_all(data)
        yield dburl
Example #6
def do_migration(shelf_filename):
    REAL_DB_URL = os.environ['DATABASE_URL'].replace('reconnect=true', '')

    ATTEMPTING = get_shelved_result(shelf_filename)

    print('APPLYING PENDING FILE')

    with S(REAL_DB_URL) as s_real:
        for migration_sql in ATTEMPTING:
            if not DRY_RUN:
                print("EXECUTING:")
                print(migration_sql)
                raw_execute(s_real, migration_sql)
            else:
                print('DRY RUN, would apply:')
                print(migration_sql)
    print('SUCCESS: DATABASE UP TO DATE.')
    return 0
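do_migration depends on module-level names (DRY_RUN, get_shelved_result) that the snippet does not show. Purely hypothetical stand-ins, assuming the "shelf" is a standard-library shelve file holding the list of pending migration statements:

import shelve

DRY_RUN = True  # set to False to actually run the statements


def shelve_result(shelf_filename, statements):
    # Persist the migration statements chosen by test_migration (Example #9).
    with shelve.open(shelf_filename) as shelf:
        shelf["attempting"] = list(statements)


def get_shelved_result(shelf_filename):
    # Read back the statements that do_migration should apply.
    with shelve.open(shelf_filename) as shelf:
        return shelf["attempting"]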
Example #7
def test_long_identifiers(db):
    with S(db) as s:

        for i in range(50, 70):
            ident = "x" * i
            truncated = "x" * min(i, 63)

            func = FUNC.replace("ordinary_f", ident)

            s.execute(func)

            insp = get_inspector(s)

            expected_sig = '"public"."{}"(t text)'.format(truncated)

            f = insp.functions[expected_sig]

            assert f.full_definition is not None
Example #8
def test_core2(dburl):
    with S(dburl, echo=ECHO) as s:
        sel = select([Book.score]).order_by(Book.id)
        check_paging_core(sel, s)

        sel = (select([Book.score
                       ]).order_by(Author.id - Book.id,
                                   Book.id).where(Author.id == Book.author_id))
        check_paging_core(sel, s)

        sel = (select([Book.author_id,
                       func.count()]).group_by(Book.author_id).order_by(
                           func.sum(Book.popularity)))
        check_paging_core(sel, s)

        v = func.sum(func.coalesce(Book.a, 0)) + func.min(Book.b)
        sel = (select([Book.author_id, func.count(),
                       v]).group_by(Book.author_id).order_by(v))
        check_paging_core(sel, s)
Example #9
def test_migration(shelf_filename, cld_host, dbexport_host, username, password,
                   service_binding):
    PENDING = sql_from_folder(PENDING_FOLDER)

    with tempfolder() as tempf:
        outfile = os.path.join(tempf, 'schemadump.sql')
        do_schema_dump(outfile, cld_host, dbexport_host, username, password,
                       service_binding)

        for i in range(len(PENDING) + 1):
            ATTEMPTING = list(reversed(PENDING))[:i]
            ATTEMPTING.reverse()

            print("TESTING MIGRATION USING LAST {} MIGRATION FILES".format(i))

            with temporary_db() as dummy_db_url, temporary_db(
            ) as target_db_url:
                with S(dummy_db_url) as s_dummy:
                    load_sql_from_file(s_dummy, outfile)

                    try:
                        for migration_sql in ATTEMPTING:
                            raw_execute(s_dummy, migration_sql)
                    except DB_ERROR_TUPLE:
                        print(
                            'TRIED USING LAST {} PENDING FILES TO MIGRATE BUT THIS FAILED, MOVING TO NEXT'
                            .format(i))
                        continue

                load_from_app_model(target_db_url)

                if databases_are_equal(dummy_db_url, target_db_url):
                    print('SUCCESS WITH LAST {} PENDING FILES'.format(i))
                    shelve_result(shelf_filename, ATTEMPTING)
                    return 0
                else:
                    print(
                        'TRIED USING LAST {} PENDING FILES TO MIGRATE BUT THIS DOES NOT GIVE A CORRECT OUTCOME, MOVING TO NEXT'
                        .format(i))

        print('COULD NOT FIND A CORRECT MIGRATION PATH :(')
        return 1
Example #10
def joined_inheritance_dburl(request):
    with temporary_database(request.param, host="localhost") as dburl:
        with S(dburl) as s:
            JoinedInheritanceBase.metadata.create_all(s.connection())
            s.add_all(
                [
                    Mammal(
                        name="Human", vertebra_count=33, leg_count=2, nipple_count=2
                    ),
                    Mammal(name="Dog", vertebra_count=36, leg_count=4, nipple_count=10),
                    Invertebrate(name="Jellyfish"),
                    Invertebrate(name="Jellyfish"),
                    Arthropod(name="Spider", leg_count=8),
                    Arthropod(name="Ant", leg_count=6),
                    Arthropod(name="Scorpion", leg_count=8),
                    Arthropod(name="Beetle", leg_count=6),
                    Vertebrate(name="Snake", vertebra_count=300),
                ]
            )
        yield dburl
Example #11
def _dburl(request):
    count = 10
    data = []

    for x in range(count):
        b = Book(name="Book {}".format(x), a=x, b=x % 2, c=count - x, d=99)

        if x == 1:
            b.a = None
            b.author = Author(name="Willy Shakespeare", info="Old timer")

        data.append(b)

    for x in range(count):
        author = Author(name="Author {}".format(x),
                        info="Rank {}".format(count + 1 - x))
        abooks = []
        for y in range((2 * x) % 10):
            b = Book(
                name="Book {}-{}".format(x, y),
                a=x + y,
                b=(y * x) % 2,
                c=count - x,
                d=99 - y,
            )
            b.author = author
            if y % 4 != 0:
                b.prequel = abooks[(2 * y + 1) % len(abooks)]
            abooks.append(b)
            data.append(b)

    data += [
        Light(colour=Colour(i % 3), intensity=(i * 13) % 53, myint=i)
        for i in range(99)
    ]

    with temporary_database(request.param, host="localhost") as dburl:
        with S(dburl) as s:
            Base.metadata.create_all(s.connection())
            s.add_all(data)
        yield dburl
Example #12
def run(args):
    if args.connection_name:
        os.environ['TUNNELER_CONNECTION'] = args.connection_name

    try:
        import tasks
    except ImportError:
        import tunneler.tasks as tasks

    task_method = getattr(tasks, args.task_name)

    with connection_from_connection_name(args.connection_name) as REAL_DB_URL:
        print('connecting to: {}'.format(REAL_DB_URL))

        with S(REAL_DB_URL, echo=False) as s:
            s.execute('select 1')

        print('connection successful.')

        task_method(REAL_DB_URL)
        print('task complete, closing connection.')
Example #13
def test_inherit(db):
    with S(db) as s:
        i = get_inspector(s)

        if i.pg_version <= 9:
            return
        s.execute("""
CREATE TABLE entity_bindings (
    id BIGSERIAL,
    entity_type TEXT NOT NULL,
    entity_id BIGINT NOT NULL
);
CREATE TABLE entity_bindings_A (
    CONSTRAINT "entity_type must be A" CHECK("entity_type" = 'A'),
    UNIQUE("entity_id", "entity_type")
) INHERITS (entity_bindings)
;
CREATE TABLE entity_bindings_B (
    CONSTRAINT "entity_type must be B" CHECK("entity_type" = 'B'),
    UNIQUE("entity_id", "entity_type")
) INHERITS (entity_bindings)
;
CREATE TABLE entity_bindings_C (
    CONSTRAINT "entity_type must be C" CHECK("entity_type" = 'C'),
    UNIQUE("entity_id", "entity_type")
) INHERITS (entity_bindings)
;
        """)

        i = get_inspector(s)

        t = i.tables['"public"."entity_bindings"']
        tb = i.tables['"public"."entity_bindings_b"']

        assert tb.parent_table == '"public"."entity_bindings"'

        assert t.parent_table is None

        assert tb.is_inheritance_child_table is True
        assert tb.uses_partitioning is False
Example #14
def joined_inheritance_dburl(request):
    with temporary_database(request.param) as dburl:
        with S(dburl) as s:
            JoinedInheritanceBase.metadata.create_all(s.connection())
            s.add_all([
                Mammal(name='Human',
                       vertebra_count=33,
                       leg_count=2,
                       nipple_count=2),
                Mammal(name='Dog',
                       vertebra_count=36,
                       leg_count=4,
                       nipple_count=10),
                Invertebrate(name='Jellyfish'),
                Invertebrate(name='Jellyfish'),
                Arthropod(name='Spider', leg_count=8),
                Arthropod(name='Ant', leg_count=6),
                Arthropod(name='Scorpion', leg_count=8),
                Arthropod(name='Beetle', leg_count=6),
                Vertebrate(name='Snake', vertebra_count=300),
            ])
        yield dburl
Example #15
def test_kinds(db):
    with S(db) as s:
        s.execute(FUNC)

        i = get_inspector(s)
        f = i.functions['"public"."ordinary_f"(t text)']

        assert f.definition == "select\r\n1"
        assert f.kind == "f"

        if i.pg_version < 11:
            return
        s.execute(PROC)
        i = get_inspector(s)
        p = i.functions['"public"."proc"(a integer, b integer)']

        assert p.definition == "\nselect a, b;\n"
        assert p.kind == "p"

        assert (
            p.drop_statement ==
            'drop procedure if exists "public"."proc"(a integer, b integer);')
Example #16
def test_marker_and_bookmark_per_item(dburl):

    with S(dburl, echo=ECHO) as s:
        q = s.query(Book).order_by(Book.id)
        page = get_page(q, per_page=3)

        paging = page.paging
        assert len(page) == 3
        assert paging.get_marker_at(0) == ((1, ), False)
        assert paging.get_marker_at(1) == ((2, ), False)
        assert paging.get_marker_at(2) == ((3, ), False)

        paging_items = list(paging.items())
        assert len(paging_items) == 3
        for i, (key, book) in enumerate(paging_items):
            assert key == ((i + 1, ), False)
            assert book.id == i + 1

        assert paging.get_bookmark_at(0) == ">i:1"
        assert paging.get_bookmark_at(1) == ">i:2"
        assert paging.get_bookmark_at(2) == ">i:3"

        bookmark_items = list(paging.bookmark_items())
        assert len(bookmark_items) == 3
        for i, (key, book) in enumerate(bookmark_items):
            assert key == ">i:%d" % (i + 1)
            assert book.id == i + 1

        place, _ = paging.get_marker_at(2)
        page = get_page(q, per_page=3, before=place)

        paging = page.paging
        assert len(page) == 2
        assert paging.get_marker_at(0) == ((2, ), True)
        assert paging.get_marker_at(1) == ((1, ), True)

        assert paging.get_bookmark_at(0) == "<i:2"
        assert paging.get_bookmark_at(1) == "<i:1"
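Outside of a test, those serialized bookmarks are what you hand back to sqlakeyset to fetch the following page. A rough sketch of that round trip; the page= keyword and bookmark_next attribute reflect the sqlakeyset README as I recall it and should be checked against the installed version:

from sqlakeyset import get_page


def first_two_pages(session, per_page=3):
    q = session.query(Book).order_by(Book.id)
    page1 = get_page(q, per_page=per_page)
    bookmark = page1.paging.bookmark_next  # e.g. ">i:3", as asserted above
    page2 = get_page(q, per_page=per_page, page=bookmark)
    return page1, page2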
Example #17
def test_dynamic_timeout(db):
    def get_timeout():
        return -1

    for n in await_pg_notifications(
        db,
        ["hello", "hello2"],
        timeout=get_timeout,
        yield_on_timeout=True,
        notifications_as_list=True,
        handle_signals=SIGNALS_TO_HANDLE,
    ):
        if n is None:
            with S(db) as s:
                s.execute("notify hello, 'here is my message'")

        elif isinstance(n, int):
            sig = signal.Signals(n)
            assert sig.name == "SIGINT"
            assert n == signal.SIGINT
            break

        else:
            assert len(n) == 1
            _n = n[0]
            assert _n.channel == "hello"
            assert _n.payload == "here is my message"
            if sys.platform != 'win32':
                os.kill(os.getpid(), signal.SIGINT)
            else:
                break

    if sys.platform != 'win32':
        with raises(KeyboardInterrupt):
            for n in await_pg_notifications(
                db, "hello", timeout=0.1, yield_on_timeout=True
            ):
                os.kill(os.getpid(), signal.SIGINT)
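Stripped of the test harness and signal handling, the basic pgnotify listening loop looks roughly like this; the database URL and channel name are placeholders:

from pgnotify import await_pg_notifications

for notification in await_pg_notifications(
    "postgresql:///example_db", ["hello"], timeout=5, yield_on_timeout=True
):
    if notification is None:
        print("timed out, still listening")
    else:
        print(notification.channel, notification.payload)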
Example #18
def test_orm_recursive_cte(pg_only_dburl):
    with S(pg_only_dburl, echo=ECHO) as s:
        # Start with "origins": books that don't have prequels
        seed = s.query(Book.id.label("id"),
                       Book.id.label("origin")).filter(Book.prequel is None)

        # Recurse by picking up sequels
        sequel = aliased(Book, name="sequel")
        recursive = seed.cte(recursive=True)
        recursive = recursive.union(
            s.query(sequel.id, recursive.c.origin).filter(
                sequel.prequel_id == recursive.c.id))

        # Count total books per origin
        count = func.count().label("count")
        origin = recursive.c.origin.label("origin")
        sq = s.query(origin, count).group_by(origin).cte(recursive=False)

        # Join to full book table
        q = (s.query(sq.c.count, Book).filter(Book.id == sq.c.origin).order_by(
            sq.c.count.desc(), Book.id))

        check_paging_orm(q=q)
Example #19
def test_enum_deps(db):
    ENUM_DEP_SAMPLE = """\
create type e as enum('a', 'b', 'c');

create table t(id integer primary key, category e);

create view v as select * from t;

"""
    with S(db) as s:
        s.execute(ENUM_DEP_SAMPLE)

        i = get_inspector(s)

        e = '"public"."e"'
        t = '"public"."t"'
        v = '"public"."v"'

        assert e in i.enums

        assert i.enums[e].dependents == [t, v]
        assert e in i.selectables[t].dependent_on
        assert e in i.selectables[v].dependent_on
Example #20
def test_table_dependency_order(db):
    with S(db) as s:
        i = get_inspector(s)

        if i.pg_version <= 9:
            return
        s.execute(INHERITANCE)

        ii = get_inspector(s)

        dep_order = ii.dependency_order()

        assert list(ii.tables.keys()) == [
            '"public"."child"',
            '"public"."normal"',
            '"public"."parent"',
        ]

        assert dep_order == [
            '"public"."parent"',
            '"public"."normal"',
            '"public"."child"',
        ]
Example #21
def test_triggers(db):
    with S(db) as s:
        s.execute(BASE)
        i = get_inspector(s)

        i.triggers['"public"."view_on_table"."trigger_on_view"']
Example #22
def do_fixture_test(fixture_name):
    fixture_path = 'tests/FIXTURES/{}/'.format(fixture_name)
    EXPECTED = io.open(fixture_path + 'expected.sql').read().strip()

    with temporary_database() as d0, temporary_database() as d1:
        with S(d0) as s0, S(d1) as s1:
            load_sql_from_file(s0, fixture_path + 'a.sql')
            load_sql_from_file(s1, fixture_path + 'b.sql')

        args = parse_args([d0, d1])
        assert not args.unsafe

        out, err = outs()
        assert run(args, out=out, err=err) == 3

        assert out.getvalue() == ''
        assert err.getvalue() == '-- ERROR: destructive statements generated. Use the --unsafe flag to suppress this error.\n'

        args = parse_args(['--unsafe', d0, d1])
        assert args.unsafe

        out, err = outs()
        assert run(args, out=out, err=err) == 2
        assert err.getvalue() == ''
        assert out.getvalue().strip() == EXPECTED

        ADDITIONS = io.open(fixture_path + 'additions.sql').read().strip()
        EXPECTED2 = io.open(fixture_path + 'expected2.sql').read().strip()

        if ADDITIONS:
            with S(d0) as s0, S(d1) as s1:
                m = Migration(s0, s1)
                m.inspect_from()
                m.inspect_target()

                with raises(AttributeError):
                    m.changes.nonexist

                m.set_safety(False)

                m.add_sql(ADDITIONS)
                m.apply()
                m.add_all_changes()
                assert m.sql.strip() == EXPECTED2  # sql generated OK
                m.apply()

                # check for changes again and make sure none are pending
                m.add_all_changes()
                assert m.changes.i_from == m.changes.i_target
                assert not m.statements  # no further statements to apply
                assert m.sql == ''

                out, err = outs()
                assert run(args, out=out, err=err) == 0

        # test alternative parameters

        with S(d0) as s0, S(d1) as s1:
            m = Migration(
                get_inspector(s0),
                get_inspector(s1)
            )

        # test empty
        m = Migration(None, None)
        m.add_all_changes()

        with raises(AttributeError):
            m.s_from

        with raises(AttributeError):
            m.s_target

        args = parse_args(['--unsafe', 'EMPTY', 'EMPTY'])
        out, err = outs()
        assert run(args, out=out, err=err) == 0
Example #23
def creation_statements(db_url, **kwargs):
    with S(db_url, poolclass=NullPool) as s:
        m = Migration(None, s, **kwargs)
        m.set_safety(False)
        m.add_all_changes()
        return m.sql
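A possible way to call creation_statements, dumping the full creation script for an existing database to a file; the URL and filename are placeholders:

if __name__ == "__main__":
    ddl = creation_statements("postgresql:///example_db")
    with open("schema_creation.sql", "w") as f:
        f.write(ddl)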
Example #24
def test_rls(db):
    with S(db) as s:
        s.execute("""
CREATE TABLE t(id uuid, a text, b decimal);
        """)

        i = get_inspector(s)
        t = i.tables['"public"."t"']

        assert t.rowsecurity is False
        assert (t.alter_rls_statement ==
                'alter table "public"."t" disable row level security;')

        t.rowsecurity = True
        s.execute(t.alter_rls_statement)
        i = get_inspector(s)
        t = i.tables['"public"."t"']
        assert t.rowsecurity is True
        assert (t.alter_rls_statement ==
                'alter table "public"."t" enable row level security;')

        create_role(s, schemainspect_test_role)

        s.execute(f"""

CREATE TABLE accounts (manager text, company text, contact_email text);

ALTER TABLE accounts ENABLE ROW LEVEL SECURITY;

CREATE POLICY account_managers ON accounts TO {schemainspect_test_role}
    USING (manager = current_user);

create policy "insert_gamer"
on accounts
as permissive
for insert
to {schemainspect_test_role}
with check (manager = current_user);

        """)
        i = get_inspector(s)

        pname = '"public"."accounts"."account_managers"'
        t = i.rlspolicies[pname]
        assert t.name == "account_managers"
        assert t.schema == "public"
        assert t.table_name == "accounts"
        assert t.commandtype == "*"
        assert t.permissive is True
        assert t.roles == ["schemainspect_test_role"]
        assert t.qual == "(manager = (CURRENT_USER)::text)"
        assert t.withcheck is None

        assert (t.create_statement == """create policy "account_managers"
on "public"."accounts"
as permissive
for all
to schemainspect_test_role
using (manager = (CURRENT_USER)::text);
""")

        assert (t.drop_statement ==
                'drop policy "account_managers" on "public"."accounts";')

        s.execute(t.drop_statement)
        s.execute(t.create_statement)
        i = get_inspector(s)
        t = i.rlspolicies[pname]
        assert t.name == "account_managers"
        assert t.schema == "public"
        assert t.table_name == "accounts"
        assert t.commandtype == "*"
        assert t.permissive is True
        assert t.roles == ["schemainspect_test_role"]
        assert t.qual == "(manager = (CURRENT_USER)::text)"
        assert t.withcheck is None

        pname = '"public"."accounts"."insert_gamer"'
        t = i.rlspolicies[pname]
        assert t.name == "insert_gamer"
        assert t.schema == "public"
        assert t.table_name == "accounts"
        assert t.commandtype == "a"
        assert t.permissive is True
        assert t.roles == ["schemainspect_test_role"]
        assert t.withcheck == "(manager = (CURRENT_USER)::text)"
        assert t.qual is None

        assert (t.create_statement == """create policy "insert_gamer"
on "public"."accounts"
as permissive
for insert
to schemainspect_test_role
with check (manager = (CURRENT_USER)::text);
""")

        assert t.drop_statement == 'drop policy "insert_gamer" on "public"."accounts";'
Example #25
def transaction(self):
    with S(*self._args, **self._kwargs) as s:
        yield transaction(s)
Example #26
def test_postgres_inspect(db):
    with S(db) as s:
        setup_pg_schema(s)
        i = get_inspector(s)
        asserts_pg(i)
        assert i == i == get_inspector(s)
Example #27
def test_pendulum_for_time_types(db):
    t = pendulum.parse("2017-12-31 23:34:45", tz="Australia/Melbourne")
    i = relativedelta(days=1, seconds=200, microseconds=99)

    with S(db) as s:
        c = raw_connection(s)
        cu = c.cursor()

        cu.execute("""
            select
                null::timestamp,
                null::timestamptz,
                null::date,
                null::time,
                null::interval
        """)

        descriptions = cu.description
        oids = [x[1] for x in descriptions]

        use_pendulum_for_time_types()

        s.execute("""
            create temporary table dt(
                ts timestamp,
                tstz timestamptz,
                d date,
                t time,
                i interval)
        """)

        s.execute(
            """
            insert into dt(ts, tstz, d, t, i)
            values
            (:ts,
            :tstz,
            :d,
            :t,
            :i)
        """,
            {
                "ts": vanilla(t),
                "tstz": t.in_timezone("Australia/Sydney"),
                "d": t.date(),
                "t": t.time(),
                "i": i,
            },
        )

        out = list(s.execute("""select * from dt"""))[0]

        assert out.ts == naive(t.in_tz("UTC"))
        assert out.tstz == t.in_timezone("UTC")
        assert out.d == t.date()
        assert out.t == t.time()
        assert out.i == i

        result = s.execute("""
            select
                null::timestamp,
                null::timestamptz,
                null::date,
                null::time,
                null::interval
        """)

        out = list(result)[0]
        assert list(out) == [None, None, None, None, None]
Example #28
def raw(self, sql):
    with S(*self._args, **self._kwargs) as s:
        _results = raw_execute(s, sql)
        return _results
Example #29
def do_fixture_test(fixture_name,
                    schema=None,
                    create_extensions_only=False,
                    with_privileges=False):
    flags = ["--unsafe"]
    if schema:
        flags += ["--schema", schema]
    if create_extensions_only:
        flags += ["--create-extensions-only"]
    if with_privileges:
        flags += ["--with-privileges"]
    fixture_path = "tests/FIXTURES/{}/".format(fixture_name)
    EXPECTED = io.open(fixture_path + "expected.sql").read().strip()
    with temporary_database(host="localhost") as d0, temporary_database(
            host="localhost") as d1:
        with S(d0) as s0:
            create_role(s0, schemainspect_test_role)
        with S(d0) as s0, S(d1) as s1:
            load_sql_from_file(s0, fixture_path + "a.sql")
            load_sql_from_file(s1, fixture_path + "b.sql")
        args = parse_args([d0, d1])
        assert not args.unsafe
        assert args.schema is None
        out, err = outs()
        assert run(args, out=out, err=err) == 3
        assert out.getvalue() == ""

        DESTRUCTIVE = "-- ERROR: destructive statements generated. Use the --unsafe flag to suppress this error.\n"

        assert err.getvalue() == DESTRUCTIVE

        args = parse_args(flags + [d0, d1])
        assert args.unsafe
        assert args.schema == schema
        out, err = outs()
        assert run(args, out=out, err=err) == 2
        assert err.getvalue() == ""
        assert out.getvalue().strip() == EXPECTED
        ADDITIONS = io.open(fixture_path + "additions.sql").read().strip()
        EXPECTED2 = io.open(fixture_path + "expected2.sql").read().strip()

        with S(d0) as s0, S(d1) as s1:
            m = Migration(s0, s1, schema=schema)
            m.inspect_from()
            m.inspect_target()
            with raises(AttributeError):
                m.changes.nonexist
            m.set_safety(False)
            if ADDITIONS:
                m.add_sql(ADDITIONS)
            m.apply()

            if create_extensions_only:
                m.add_extension_changes(drops=False)
            else:
                m.add_all_changes(privileges=with_privileges)

            expected = EXPECTED2 if ADDITIONS else EXPECTED
            assert m.sql.strip() == expected  # sql generated OK

            m.apply()
            # check for changes again and make sure none are pending
            if create_extensions_only:
                m.add_extension_changes(drops=False)
                assert (m.changes.i_from.extensions.items() >=
                        m.changes.i_target.extensions.items())
            else:
                m.add_all_changes(privileges=with_privileges)
                assert m.changes.i_from == m.changes.i_target
            assert not m.statements  # no further statements to apply
            assert m.sql == ""
            out, err = outs()

        assert run(args, out=out, err=err) == 0
        # test alternative parameters
        with S(d0) as s0, S(d1) as s1:
            m = Migration(get_inspector(s0), get_inspector(s1))
        # test empty
        m = Migration(None, None)
        m.add_all_changes(privileges=with_privileges)
        with raises(AttributeError):
            m.s_from
        with raises(AttributeError):
            m.s_target
        args = parse_args(flags + ["EMPTY", "EMPTY"])
        out, err = outs()
        assert run(args, out=out, err=err) == 0
Example #30
def inspect(self):
    with S(*self._args, **self._kwargs) as s:
        i = get_inspector(s)
    return i
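Examples #25, #28 and #30 read like methods of one small connection-wrapper class. A hypothetical way such a wrapper could be assembled; the class name, the constructor, and yielding the session directly from transaction() are assumptions rather than the original code:

from contextlib import contextmanager

from schemainspect import get_inspector
from sqlbag import S, raw_execute


class DatabaseFacade:
    """Hypothetical holder of connection arguments for the S() context manager."""

    def __init__(self, *args, **kwargs):
        self._args = args
        self._kwargs = kwargs

    @contextmanager
    def transaction(self):
        # S() already commits on clean exit; yield the session itself here
        # (the original snippet yields transaction(s) from an unshown helper).
        with S(*self._args, **self._kwargs) as s:
            yield s

    def raw(self, sql):
        with S(*self._args, **self._kwargs) as s:
            return raw_execute(s, sql)

    def inspect(self):
        with S(*self._args, **self._kwargs) as s:
            return get_inspector(s)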