def test_printers_roundtrip(src, lineno, statement):
    """Check that *statement* survives a print-and-reparse round trip.

    The statement is parsed, printed back with both ``RawStream`` and
    ``IndentedStream``, and each printed text is parsed again. After
    ``_remove_stmt_len_and_location`` has been applied to every AST, the
    reparsed trees must compare equal to the original one.

    Arguments:
      src: name of the file the statement comes from (used in messages).
      lineno (int): line number of the statement inside *src*.
      statement (str): the SQL text under test.

    Raises:
      RuntimeError: if the original or a printed statement cannot be parsed;
        the parser's exception is chained as ``__cause__``.
      AssertionError: if a reparsed AST differs from the original.
    """
    try:
        orig_ast = parse_sql(statement)
    except Exception as err:
        raise RuntimeError(
            "%s:%d:Could not parse %r" % (src, lineno, statement)) from err

    _remove_stmt_len_and_location(orig_ast)

    serialized = RawStream()(Node(orig_ast))
    try:
        serialized_ast = parse_sql(serialized)
    except Exception as err:
        raise RuntimeError(
            "%s:%d:Could not reparse %r" % (src, lineno, serialized)) from err
    _remove_stmt_len_and_location(serialized_ast)
    assert orig_ast == serialized_ast, "%s:%d:%r != %r" % (
        src, lineno, statement, serialized)

    indented = IndentedStream()(Node(orig_ast))
    try:
        indented_ast = parse_sql(indented)
    except Exception as err:
        raise RuntimeError(
            "%s:%d:Could not reparse %r" % (src, lineno, indented)) from err
    _remove_stmt_len_and_location(indented_ast)
    assert orig_ast == indented_ast, "%s:%d:%r != %r" % (
        src, lineno, statement, indented)

    # Run ``pytest -s tests/`` to see the following output
    print()
    print(indented)
Beispiel #2
0
def _emit_ast(ast: "Node") -> str:
    """Return *ast* rendered as indented (pretty-printed) SQL text."""
    # The import is kept only for its side effect: it is required to
    # instantiate all pglast node printers before any stream can use them.
    # noinspection PyUnresolvedReferences
    from pglast import printers  # noqa

    rendered = IndentedStream()(ast)
    return str(rendered)
Beispiel #3
0
 def string_statements(self):
     """
     Lazily deparse ``self.statements`` into SQL text.

     Returns:
         generator: yields, for each `pglast.Node` in ``self.statements``,
         the SQL string representation produced by a fresh
         ``IndentedStream``.
     """
     yield from (IndentedStream()(node) for node in self.statements)
Beispiel #4
0
def test_prettification(sample):
    """Check that ``IndentedStream`` pretty-prints *sample* as expected.

    *sample* is split on a line holding only ``=``. The text before it is
    the SQL to prettify, and the text after it is the expected output.
    That expected part may be followed by a line holding only ``:`` and a
    Python literal, whose items are passed to ``IndentedStream`` as
    keyword options.
    """
    sections = sample.split('\n=\n')
    sql_text = sections[0].strip()
    expectation, *extra = sections[1].split('\n:\n')
    expected = expectation.strip()
    options = literal_eval(extra[0]) if len(extra) == 1 else {}
    printer = IndentedStream(**options)
    prettified = printer(sql_text)
    assert expected == prettified, "%r != %r" % (expected, prettified)
def test_prettification(src, lineno, case):
    """Check that ``IndentedStream`` pretty-prints *case* as expected.

    *case* is split on a line holding only ``=``. The text before it is the
    SQL to prettify, and the text after it is the expected output. That
    expected part may be followed by a line holding only ``:`` and a Python
    literal, whose items are passed to ``IndentedStream`` as keyword
    options. In the expected text, a line ending with the three characters
    backslash, ``n``, backslash has them removed (the line break is kept).

    *src* and *lineno* are only used to locate the case in the failure
    message.
    """
    sections = case.split('\n=\n')
    sql_text = sections[0].strip()
    expectation, *extra = sections[1].split('\n:\n')
    expected = expectation.strip().replace('\\n\\\n', '\n')
    options = literal_eval(extra[0]) if len(extra) == 1 else {}
    printer = IndentedStream(**options)
    prettified = printer(sql_text)
    assert expected == prettified, "%s:%d:%r != %r" % (src, lineno, expected, prettified)
Beispiel #6
0
def post_data_exec(target_dsn, statements, njobs=4):
    """
    Execute post-data statements in parallel on the target database.

    Each statement accepted by ``should_apply_post_data_stmt`` is rendered
    with ``IndentedStream(expression_level=1)`` and handed, together with
    *target_dsn*, to ``post_data_task`` on a pool of *njobs* worker
    processes. The call returns once every submitted task has finished.

    Arguments:
      target_dsn: connection string passed through to ``post_data_task``.
      statements: iterable of parsed statements to filter and execute.
      njobs (int): number of worker processes.

    A failure raised inside ``post_data_task`` is logged at ERROR level,
    together with the SQL that failed. Before this change such failures were
    silently dropped because the ``apply_async`` results were never read.
    A failure does not abort the remaining tasks.
    """
    log = logging.getLogger()
    log.info("post_data_exec start")
    multiprocessing.log_to_stderr()

    def _failure_logger(sql):
        # Build a per-statement error callback so the log names the SQL.
        # Pool callbacks run in the parent's result-handler thread, so this
        # closure is never pickled.
        def _log_failure(exc):
            log.error("post_data_exec: failed to apply %r: %s", sql, exc)
        return _log_failure

    pool2 = multiprocessing.Pool(njobs)
    j = 0
    try:
        for stmt in statements:
            if should_apply_post_data_stmt(stmt):
                sql = IndentedStream(expression_level=1)(stmt)
                pool2.apply_async(
                    post_data_task,
                    (target_dsn, sql),
                    error_callback=_failure_logger(sql))
                j += 1
        log.info("post_data_exec: jobs=%s, stmts=%s", njobs, j)
    finally:
        # Always shut the pool down and wait for already-submitted tasks,
        # even if filtering or rendering a statement raised, so that no
        # worker processes are leaked.
        pool2.close()
        pool2.join()
    log.info("post_data_exec end")
Beispiel #7
0
def roundtrip(sql):
    """Assert that *sql* survives a print-and-reparse round trip.

    *sql* is parsed, printed back with both ``RawStream`` and
    ``IndentedStream``, and each printed text is parsed again. After
    ``_remove_stmt_len_and_location`` has been applied to every AST, the
    reparsed trees must compare equal to the original one.

    Raises:
      RuntimeError: if a printed statement cannot be reparsed; the parser's
        exception is chained as ``__cause__``.
      AssertionError: if a reparsed AST differs from the original.

    Errors from parsing *sql* itself propagate unchanged from ``parse_sql``.
    """
    orig_ast = parse_sql(sql)
    _remove_stmt_len_and_location(orig_ast)

    serialized = RawStream()(Node(orig_ast))
    try:
        serialized_ast = parse_sql(serialized)
    except Exception as err:
        raise RuntimeError("Could not reparse %r" % serialized) from err
    _remove_stmt_len_and_location(serialized_ast)
    assert orig_ast == serialized_ast, "%r != %r" % (sql, serialized)

    indented = IndentedStream()(Node(orig_ast))
    try:
        indented_ast = parse_sql(indented)
    except Exception as err:
        raise RuntimeError("Could not reparse %r" % indented) from err
    _remove_stmt_len_and_location(indented_ast)
    assert orig_ast == indented_ast, "%r != %r" % (sql, indented)

    # Run ``pytest -s tests/`` to see the following output
    print()
    print(indented)
Beispiel #8
0
    def restore_schema(self, blacklist_objects=None, failable_objects=None):
        """
        Restore the DDLScript in the target DB.

        Arguments:
          blacklist_objects (list): a list of fully qualified object names
            that we shouldn't restore.

          failable_objects (list): a list of objects for which we ignore failures
            during restoration.

        Returns:
          list: a list of statements to execute after data has been restored.
          This includes table constraints and indexes.

        Raises:
          psycopg2.Error: when a statement fails and the object it creates is
            not in ``failable_objects``. The final COMMIT is then skipped and
            the connection is closed.

        The restoration only restores table definitions, and specifically
        ignores triggers, event triggers, object privileges, views, and RLS.
        """
        # NOTE(review): no ViewStmt filter is visible below; confirm that views
        # are really skipped (e.g. through blacklist_objects), as stated above.
        target_conn = psycopg2.connect(self.target_dsn)
        try:
            cursor = target_conn.cursor()
            cursor.execute("BEGIN")
            self._create_missing_schemas(cursor)
            blacklist_objects = blacklist_objects or []
            failable_objects = failable_objects or []
            # Statements deferred until after the data load (indexes,
            # ADD CONSTRAINT and CLUSTER ON); returned to the caller.
            post_data = []

            for statement in self.ddlscript.statements:
                # Never replayed here: triggers, event triggers, grants,
                # default privileges, policies, comments, casts and
                # ownership changes.
                if statement.node_tag in ('CreateTrigStmt',
                                          'CreateEventTrigStmt',
                                          'GrantStmt',
                                          'AlterDefaultPrivilegesStmt',
                                          'CreatePolicyStmt', 'CommentStmt',
                                          'CreateCastStmt', 'AlterOwnerStmt'):
                    continue

                # Indexes are not skipped: they are deferred to post-data.
                if statement.node_tag == 'IndexStmt':
                    post_data.append(statement)
                    continue

                if statement.node_tag == 'AlterTableStmt':
                    # Ignore ALTER INDEX statement altogether
                    if statement.relkind == ObjectType.OBJECT_INDEX:
                        continue

                    # We assume that if we have an add constraint command,
                    # it's the only one. This is safe because we only expect
                    # to work on pg_dump outputs here.
                    if any(cmd.subtype.value == AlterTableType.AT_AddConstraint
                           for cmd in statement.cmds):
                        post_data.append(statement)
                        continue

                    if any(cmd.subtype.value == AlterTableType.AT_ClusterOn
                           for cmd in statement.cmds):
                        post_data.append(statement)
                        continue

                    # Just ignore ALTER TABLE .. OWNER TO
                    if any(cmd.subtype.value == AlterTableType.AT_ChangeOwner
                           for cmd in statement.cmds):
                        continue

                    # Just ignore ALTER TABLE .. ENABLE ROW LEVEL SECURITY
                    if any(cmd.subtype.value
                           == AlterTableType.AT_EnableRowSecurity
                           for cmd in statement.cmds):
                        continue

                    # Just ignore ALTER TABLE .. REPLICA IDENTITY
                    if any(cmd.subtype.value
                           == AlterTableType.AT_ReplicaIdentity
                           for cmd in statement.cmds):
                        continue

                str_statement = IndentedStream(expression_level=1)(statement)

                obj_creation_tuple = object_creation(statement)
                obj_creation = None
                if obj_creation_tuple:
                    obj_creation = obj_creation_tuple[1]

                if obj_creation in blacklist_objects:
                    logging.debug("Ignore %s", str_statement)
                    continue

                # For failable objects, a savepoint lets the statement fail
                # without aborting the surrounding transaction.
                savepoint_name = None
                if obj_creation in failable_objects:
                    savepoint_name = 'svpt_' + uuid.uuid4().hex
                    cursor.execute("SAVEPOINT %s" % savepoint_name)

                logging.debug("Restoring %s", str_statement)
                try:
                    cursor.execute(str_statement)
                except psycopg2.Error:
                    if savepoint_name:
                        logging.debug("Rollback to savepoint for %s",
                                      str_statement)
                        cursor.execute(
                            "ROLLBACK TO SAVEPOINT %s" % savepoint_name)
                    else:
                        raise

            cursor.execute("COMMIT")
        finally:
            # Always release the connection. If COMMIT was not reached,
            # closing discards the pending transaction.
            target_conn.close()

        return post_data