Exemple #1
0
    def link_refs(self, dbcolumns, dbconstrs, dbfuncs):
        """Attach columns, constraints and functions to their types

        :param dbcolumns: dictionary of columns
        :param dbconstrs: dictionary of constraints
        :param dbfuncs: dictionary of functions

        Composite types found in `dbcolumns` receive their `attributes`
        list, and each attribute is pointed back at its type.  Check
        constraints from `dbconstrs` whose `target` is 'd' are stored in
        the `check_constraints` dictionary of the matching type.  Base
        types get a `dep_funcs` dictionary holding their input, output
        and optional support functions; each of those functions is
        tagged with `_dep_type`.
        """
        # Composite types: hand over the attribute list
        for (sch, typ) in dbcolumns:
            if (sch, typ) not in self:
                continue
            composite = self[(sch, typ)]
            assert isinstance(composite, Composite)
            attributes = dbcolumns[(sch, typ)]
            composite.attributes = attributes
            for attr in attributes:
                attr._type = composite

        # Constraints flagged with target 'd': keep the check constraints
        for (sch, typ, cns) in dbconstrs:
            constr = dbconstrs[(sch, typ, cns)]
            if getattr(constr, "target", None) != "d":
                continue
            dbtype = self[(sch, typ)]
            assert dbtype
            if not isinstance(constr, CheckConstraint):
                continue
            if not hasattr(dbtype, "check_constraints"):
                dbtype.check_constraints = {}
            dbtype.check_constraints[cns] = constr

        # Base types: locate the functions that implement them
        for (sch, typ) in self:
            dbtype = self[(sch, typ)]
            if not isinstance(dbtype, BaseType):
                continue
            if not hasattr(dbtype, "dep_funcs"):
                dbtype.dep_funcs = {}

            # The input function is looked up with a single cstring
            # argument first, then with the three-argument signature.
            # Note that `sch` is rebound by every split below and then
            # serves as the default schema of the next one.
            (sch, in_name) = split_schema_obj(dbtype.input, sch)
            in_key = (sch, in_name, "cstring")
            if in_key not in dbfuncs:
                in_key = (sch, in_name, "cstring, oid, integer")
            in_func = dbfuncs[in_key]
            dbtype.dep_funcs["input"] = in_func
            in_func._dep_type = dbtype

            # The output function takes the type itself
            (sch, out_name) = split_schema_obj(dbtype.output, sch)
            out_func = dbfuncs[(sch, out_name, dbtype.qualname())]
            dbtype.dep_funcs["output"] = out_func
            out_func._dep_type = dbtype

            for attr in OPT_FUNCS:
                if not hasattr(dbtype, attr):
                    continue
                (sch, fnc) = split_schema_obj(getattr(dbtype, attr), sch)
                # NOTE(review): a name in OPT_FUNCS outside the five
                # handled here leaves `arg` at its previous value (or
                # unbound) -- confirm OPT_FUNCS only holds these names
                if attr in ("receive", "analyze"):
                    arg = "internal"
                elif attr == "send":
                    arg = dbtype.qualname()
                elif attr == "typmod_in":
                    arg = "cstring[]"
                elif attr == "typmod_out":
                    arg = "integer"
                opt_func = dbfuncs[(sch, fnc, arg)]
                dbtype.dep_funcs[attr] = opt_func
                opt_func._dep_type = dbtype
Exemple #2
0
    def get_implied_deps(self, db):
        """Collect the objects this operator implicitly depends on

        :param db: object exposing the `types` and `functions`
            dictionaries that are searched
        :return: the dependencies gathered by the parent class, extended
            with the operand types, the implementing function and the
            restriction helper whenever they are found in `db`
        """
        deps = super(Operator, self).get_implied_deps(db)

        # An operand type may be absent from db.types (builtin), and a
        # unary operator has one of its operands set to None
        for operand in (self.leftarg, self.rightarg):
            if operand is None:
                continue
            found = db.types.find(operand)
            if found:
                deps.add(found)

        # The implementing function is expected to exist
        # TODO: another ugly hack to locate the object
        proc_schema, proc_name = split_schema_obj(self.procedure, self.schema)
        operands = [t for t in (self.leftarg, self.rightarg)
                    if t is not None]
        proc_key = (proc_schema, proc_name, ', '.join(operands))
        if proc_key in db.functions:
            deps.add(db.functions[proc_key])

        # The restriction helper may be a builtin, hence the tolerant get
        if self.restrict is not None:
            est_schema, est_name = split_schema_obj(self.restrict)
            estimator = db.functions.get(
                (est_schema, est_name, "internal, oid, internal, integer"))
            if estimator:
                deps.add(estimator)

        return deps
Exemple #3
0
 def _from_catalog(self):
     """Load the constraints from the catalog query results

     Every fetched row becomes a CheckConstraint ('c'), PrimaryKey
     ('p'), ForeignKey ('f') or UniqueConstraint ('u') according to
     its `type` code and is stored under the key
     (schema, table, constraint name).  Rows carrying any other type
     code are not stored.
     """
     for constr in self.fetch():
         constr.unqualify()
         sch, tbl, cns = constr.key()
         sch, tbl = split_schema_obj('%s.%s' % (sch, tbl))
         kind = constr.type
         del constr.type

         if kind == 'f':
             # action code 'a' is dropped, any other code is
             # translated through ACTIONS
             if constr.on_update == 'a':
                 del constr.on_update
             else:
                 constr.on_update = ACTIONS[constr.on_update]
             if constr.on_delete == 'a':
                 del constr.on_delete
             else:
                 constr.on_delete = ACTIONS[constr.on_delete]
             # normalize reference schema/table: the referenced table
             # name is split into separate schema and table attributes
             (constr.ref_schema, constr.ref_table) = split_schema_obj(
                 constr.ref_table)
             self[(sch, tbl, cns)] = ForeignKey(**constr.__dict__)
             continue

         # the remaining kinds do not carry foreign key attributes
         del constr.ref_table
         del constr.on_update
         del constr.on_delete
         if kind == 'c':
             self[(sch, tbl, cns)] = CheckConstraint(**constr.__dict__)
         elif kind == 'p':
             self[(sch, tbl, cns)] = PrimaryKey(**constr.__dict__)
         elif kind == 'u':
             self[(sch, tbl, cns)] = UniqueConstraint(**constr.__dict__)
Exemple #4
0
    def get_implied_deps(self, db):
        """Collect the functions this aggregate implicitly depends on

        :param db: object exposing the `functions` dictionary
        :return: the dependencies gathered by the parent class, extended
            with the function named by `sfunc` and, when the aggregate
            has a `finalfunc` attribute, the function it names
        """
        deps = super(Aggregate, self).get_implied_deps(db)

        # `sfunc` is looked up with the state type followed by the
        # aggregate's own arguments
        trans_schema, trans_name = split_schema_obj(self.sfunc)
        trans_args = ', '.join([self.stype, self.arguments])
        deps.add(db.functions[trans_schema, trans_name, trans_args])

        # `finalfunc`, if any, is looked up with the state type alone
        if hasattr(self, 'finalfunc'):
            final_schema, final_name = split_schema_obj(self.finalfunc)
            deps.add(db.functions[final_schema, final_name, self.stype])

        return deps
Exemple #5
0
 def _from_catalog(self):
     """Load tables, sequences and views from the catalog query results

     The query is replaced by QUERY_PRE91 or QUERY_PRE93 when the
     connection reports a version below 90100 or 90300.  Each fetched
     row becomes a Table ('r'), Sequence ('S'), View ('v') or
     MaterializedView ('m') and is registered under its
     (schema, name) key as well as, in `by_oid`, under its oid.  Rows
     of any other kind are not stored.  Finally, the rows returned by
     `inhquery` fill the `inherits` list of each child table.
     """
     version = self.dbconn.version
     if version < 90100:
         self.query = QUERY_PRE91
     elif version < 90300:
         self.query = QUERY_PRE93

     for table in self.fetch():
         oid = table.oid
         sch, tbl = table.key()
         if hasattr(table, 'persistence'):
             # persistence code 'u' is recorded as the `unlogged` flag
             if table.persistence == 'u':
                 table.unlogged = True
             del table.persistence
         kind = table.kind
         del table.kind

         if kind == 'r':
             inst = Table(**table.__dict__)
         elif kind == 'S':
             inst = Sequence(**table.__dict__)
         elif kind == 'v':
             inst = View(**table.__dict__)
         elif kind == 'm':
             inst = MaterializedView(**table.__dict__)
         else:
             continue
         self.by_oid[oid] = inst
         self[sch, tbl] = inst
         if kind == 'S':
             # sequences run further lookups through the connection
             inst.get_attrs(self.dbconn)
             inst.get_dependent_table(self.dbconn)

     inhtbls = self.dbconn.fetchall(self.inhquery)
     self.dbconn.rollback()
     for (child, parent, _num) in inhtbls:
         child_key = split_schema_obj(child)
         (sch, tbl) = child_key
         table = self[(sch, tbl)]
         if not hasattr(table, 'inherits'):
             table.inherits = []
         table.inherits.append(parent)
Exemple #6
0
 def _from_catalog(self):
     """Load tables, sequences and views from the catalog query results

     The query is replaced by QUERY_PRE91 or QUERY_PRE93 when the
     connection reports a version below 90100 or 90300.  Each fetched
     row becomes a Table ('r'), Sequence ('S'), View ('v') or
     MaterializedView ('m') stored under its (schema, name) key; rows
     of any other kind are not stored.  A `privileges` attribute, when
     present, is split on commas into a list.  Finally, the rows
     returned by `inhquery` fill the `inherits` list of each child
     table.
     """
     version = self.dbconn.version
     if version < 90100:
         self.query = QUERY_PRE91
     elif version < 90300:
         self.query = QUERY_PRE93

     for table in self.fetch():
         sch, tbl = table.key()
         if hasattr(table, 'privileges'):
             table.privileges = table.privileges.split(',')
         if hasattr(table, 'persistence'):
             # persistence code 'u' is recorded as the `unlogged` flag
             if table.persistence == 'u':
                 table.unlogged = True
             del table.persistence
         kind = table.kind
         del table.kind

         if kind == 'r':
             inst = Table(**table.__dict__)
         elif kind == 'S':
             inst = Sequence(**table.__dict__)
         elif kind == 'v':
             inst = View(**table.__dict__)
         elif kind == 'm':
             inst = MaterializedView(**table.__dict__)
         else:
             continue
         self[(sch, tbl)] = inst
         if kind == 'S':
             # sequences run further lookups through the connection
             inst.get_attrs(self.dbconn)
             inst.get_dependent_table(self.dbconn)

     inhtbls = self.dbconn.fetchall(self.inhquery)
     self.dbconn.rollback()
     for (child, parent, _num) in inhtbls:
         (sch, tbl) = split_schema_obj(child)
         table = self[(sch, tbl)]
         if not hasattr(table, 'inherits'):
             table.inherits = []
         table.inherits.append(parent)
Exemple #7
0
    def link_refs(self, dbtypes, dbeventtrigs):
        """Link the functions to event triggers and to the types they define

        :param dbtypes: dictionary of types
        :param dbeventtrigs: dictionary of event triggers

        The name of every event trigger is appended to the
        `event_triggers` list of the function it executes.  Every
        function returned by a type's `find_defining_funcs` gets its
        `_defining` attribute set to that type.
        """
        for trigkey in dbeventtrigs:
            evttrg = dbeventtrigs[trigkey]
            (sch, proc) = split_schema_obj(evttrg.procedure)
            # the last two characters of the procedure are cut off and
            # the function is looked up with an empty argument list
            # (presumably a trailing '()' -- TODO confirm)
            func = self[(sch, proc[:-2], '')]
            try:
                triggers = func.event_triggers
            except AttributeError:
                triggers = func.event_triggers = []
            triggers.append(evttrg.name)

        # TODO: this link is needed from map, not from sql.
        # is this a pattern? I was assuming link_refs would have disappeared
        # but I'm actually still maintaining them. Verify if they are always
        # only used for from_map, not for from_catalog
        for typekey in dbtypes:
            dbtype = dbtypes[typekey]
            for definer in dbtype.find_defining_funcs(self):
                definer._defining = dbtype
Exemple #8
0
 def _from_catalog(self):
     """Load the indexes from the catalog query results

     For rows whose `keycols` is not '0', the text between the last
     '(' of `defn` and its final character is split on commas and each
     piece is parsed into a column entry: either the bare column name
     or a one-item dictionary mapping the name to its `opclass`,
     `nulls` and `order` options.  The `defn` and `keycols` attributes
     are removed before the index is stored under the key
     (schema, table, index name).
     """
     keywords = ('ASC', 'DESC', 'NULLS', 'FIRST', 'LAST')
     for index in self.fetch():
         index.unqualify()
         sch, tbl, idx = index.key()
         sch, tbl = split_schema_obj('%s.%s' % (sch, tbl))
         if index.keycols != '0':
             defn = index.defn
             collist = defn[defn.rfind('(') + 1:-1]
             columns = index.columns = []
             for coldef in collist.split(','):
                 words = coldef.split()
                 colname = words[0]
                 extra = {}
                 # `pos` is the position of `word` within `words`
                 for pos, word in enumerate(words[1:], 1):
                     if word.upper() not in keywords:
                         # anything that is not an ordering keyword is
                         # taken as the operator class
                         extra['opclass'] = word
                     elif word == 'NULLS':
                         # the word after NULLS gives the placement
                         extra['nulls'] = words[pos + 1].lower()
                     elif word == 'DESC':
                         extra['order'] = 'desc'
                 columns.append({colname: extra} if extra else colname)
         del index.defn
         del index.keycols
         self[(sch, tbl, idx)] = index
Exemple #9
0
 def _from_catalog(self):
     """Load the constraints from the catalog query results

     `match_types` is replaced by MATCHTYPES_PRE93 when the connection
     reports a version below 90300.  Every fetched row becomes a
     CheckConstraint ('c'), PrimaryKey ('p'), ForeignKey ('f') or
     UniqueConstraint ('u') according to its `type` code and is
     registered under the key (schema, table, constraint name) as
     well as, in `by_oid`, under its oid.  Rows carrying any other
     type code are not stored.
     """
     if self.dbconn.version < 90300:
         self.match_types = MATCHTYPES_PRE93

     for constr in self.fetch():
         constr.unqualify()
         oid = constr.oid
         sch, tbl, cns = constr.key()
         sch, tbl = split_schema_obj('%s.%s' % (sch, tbl))
         kind = constr.type
         del constr.type

         if kind == 'f':
             # action code 'a' is dropped, any other code is
             # translated through ACTIONS
             if constr.on_update == 'a':
                 del constr.on_update
             else:
                 constr.on_update = ACTIONS[constr.on_update]
             if constr.on_delete == 'a':
                 del constr.on_delete
             else:
                 constr.on_delete = ACTIONS[constr.on_delete]
             # a match type of 'simple' is dropped, others are kept
             # in their translated form
             match = self.match_types[constr.match]
             if match == 'simple':
                 del constr.match
             else:
                 constr.match = match
             # normalize reference schema/table: the referenced table
             # name is split into separate schema and table attributes
             (constr.ref_schema,
              constr.ref_table) = split_schema_obj(constr.ref_table)
             inst = ForeignKey(**constr.__dict__)
         else:
             # the remaining kinds do not carry foreign key attributes
             del constr.ref_table
             del constr.on_update
             del constr.on_delete
             del constr.match
             if kind == 'c':
                 inst = CheckConstraint(**constr.__dict__)
             elif kind == 'p':
                 inst = PrimaryKey(**constr.__dict__)
             elif kind == 'u':
                 inst = UniqueConstraint(**constr.__dict__)
             else:
                 continue

         self.by_oid[oid] = inst
         self[(sch, tbl, cns)] = inst
Exemple #10
0
    def get_implied_deps(self, db):
        """Collect the objects this domain implicitly depends on

        :param db: object exposing the `types` and `functions`
            dictionaries that are searched
        :return: the dependencies gathered by the parent class, extended
            with the base type when it is found in `db.types` and, if
            that type has an `output` attribute, its output function
        """
        deps = super(Domain, self).get_implied_deps(db)

        base_schema, base_name = split_schema_obj(self.type)
        basetype = db.types.get((base_schema, base_name))
        if not basetype:
            # no error when the base type is missing: it may be a builtin
            return deps
        deps.add(basetype)

        # In my testing database there is a dependency on the output
        # function of the base type. TODO: investigate more.
        if hasattr(basetype, 'output'):
            out_schema, out_name = split_schema_obj(basetype.output)
            deps.add(db.functions[out_schema, out_name, basetype.qualname()])

        return deps
Exemple #11
0
 def _from_catalog(self):
     """Load the indexes from the catalog query results

     Each fetched index is stored under the key
     (schema, table, index name); its `keycols` attribute is removed
     when it equals "0".
     """
     for index in self.fetch():
         index.unqualify()
         sch, tbl, idx = index.key()
         qualified = "%s.%s" % (sch, tbl)
         sch, tbl = split_schema_obj(qualified)
         has_no_keycols = index.keycols == "0"
         if has_no_keycols:
             del index.keycols
         self[(sch, tbl, idx)] = index
Exemple #12
0
    def find(self, oper):
        """Look up an operator by its signature

        :param oper: a signature such as '#>=#(hstore,hstore)'
        :return: the operator found, else None

        The lookup key is the schema, the operator name and then one
        element per argument.
        """
        schema, signature = split_schema_obj(oper)
        opname, opargs = split_func_args(signature)
        key = (schema, opname) + tuple(opargs)
        return self.get(key)
Exemple #13
0
    def find(self, obj, schema=None):
        """Look up an object by name, ignoring trailing array brackets

        :param obj: name of the object, which may end with array
            modifiers such as '[]'
        :param schema: passed to `split_schema_obj` as its second
            argument
        :return: the object found, else None
        """
        sch, rawname = split_schema_obj(obj, schema)
        key = (sch, rawname.rstrip('[]'))
        return self.get(key)
Exemple #14
0
 def find_defining_funcs(self, dbfuncs):
     """Return the functions named by this type's I/O attributes

     :param dbfuncs: dictionary of functions, keyed by
         (schema, name, arguments)
     :return: list of the functions found for the `input`, `output`,
         `receive` and `send` attributes, in that order; attributes
         that are missing or empty are skipped
     """
     # attribute name paired with the argument string used for lookup
     signatures = (
         ('input', 'cstring'),
         ('output', self.qualname()),
         ('receive', 'internal'),
         ('send', self.qualname()),
     )
     found = []
     for attr, argtypes in signatures:
         procname = getattr(self, attr, None)
         if procname:
             schema, name = split_schema_obj(procname)
             found.append(dbfuncs[schema, name, argtypes])
     return found
Exemple #15
0
    def find(self, obj):
        """Find a type given its name.

        The name can carry trailing modifiers: a parenthesized list of
        numbers such as '(3)' or '(10,2)' and array markers such as '[]'
        or '[3]'.  Only those modifiers are removed before the lookup,
        so digits that belong to the type name itself (e.g. 'vec3[]')
        are preserved.

        :param obj: name of the type, with or without modifiers
        :return: the type found, else None
        """
        import re  # local import keeps the method self-contained

        schema, name = split_schema_obj(obj)
        # A plain rstrip() with a character set would also eat digits
        # that end the type name, so match the modifier syntax instead:
        # an optional '(n[, m...])' followed by any number of '[..]'.
        name = re.sub(r'(\(\d+(,\s*\d+)*\))?(\[\d*\])*\Z', '', name)
        return self.get((schema, name))
Exemple #16
0
    def find(self, func, args):
        """Look up a function by its name and argument types

        :param func: name of the function, eventually with schema
        :param args: list of type names
        :return: the function found, else None

        The type names are joined with ', ' to form the last element
        of the lookup key.
        """
        schema, name = split_schema_obj(func)
        signature = ', '.join(args)
        key = (schema, name, signature)
        return self.get(key)
Exemple #17
0
    def get_dependent_table(self, dbconn):
        """Record the table (and column) that owns or uses the sequence

        :param dbconn: a DbConnection object

        A first query looks in pg_depend for the object the sequence
        depends on; when it returns a row, `owner_table` and
        `owner_column` are set and nothing else is done.  Otherwise a
        second query looks for a column default referring to the
        sequence and, when found, sets `dependent_table`.
        """
        owner_row = dbconn.fetchone(
            """SELECT refobjid::regclass, refobjsubid
               FROM pg_depend
               WHERE objid = '%s'::regclass
                 AND refclassid = 'pg_class'::regclass""" % self.qualname())
        if owner_row:
            # only the table part of the split name is kept
            (_sch, self.owner_table) = split_schema_obj(
                owner_row[0], self.schema)
            self.owner_column = owner_row[1]
            return

        user_row = dbconn.fetchone(
            """SELECT adrelid::regclass
               FROM pg_attrdef a JOIN pg_depend ON (a.oid = objid)
               WHERE refobjid = '%s'::regclass
               AND classid = 'pg_attrdef'::regclass""" % self.qualname())
        if not user_row:
            return
        (_sch, self.dependent_table) = split_schema_obj(
            user_row[0], self.schema)
Exemple #18
0
    def link_refs(self, dbeventtrigs):
        """Link the functions to the event triggers that execute them

        :param dbeventtrigs: dictionary of event triggers

        The name of every event trigger in `dbeventtrigs` is appended
        to the `event_triggers` list of the function it executes; the
        list is created on first use.
        """
        for trigkey in dbeventtrigs:
            evttrg = dbeventtrigs[trigkey]
            (sch, proc) = split_schema_obj(evttrg.procedure)
            # the last two characters of the procedure are cut off and
            # the function is looked up with an empty argument list
            # (presumably a trailing '()' -- TODO confirm)
            func = self[(sch, proc[:-2], '')]
            try:
                triggers = func.event_triggers
            except AttributeError:
                triggers = func.event_triggers = []
            triggers.append(evttrg.name)
Exemple #19
0
    def link_refs(self, dbeventtrigs):
        """Link the functions to the event triggers that execute them

        :param dbeventtrigs: dictionary of event triggers

        The name of every event trigger in `dbeventtrigs` is appended
        to the `event_triggers` list of the function it executes; the
        list is created on first use.
        """
        for trigkey in dbeventtrigs:
            evttrg = dbeventtrigs[trigkey]
            (sch, proc) = split_schema_obj(evttrg.procedure)
            # the last two characters of the procedure are cut off and
            # the function is looked up with an empty argument list
            # (presumably a trailing '()' -- TODO confirm)
            func = self[(sch, proc[:-2], '')]
            try:
                triggers = func.event_triggers
            except AttributeError:
                triggers = func.event_triggers = []
            triggers.append(evttrg.name)
Exemple #20
0
    def get_implied_deps(self, db):
        """Collect the objects this trigger implicitly depends on

        :param db: object exposing the `tables` and `functions`
            dictionaries that are searched
        :return: the dependencies gathered by the parent class, extended
            with the trigger's table and, unless the trigger carries an
            `_iscfg` attribute, with its procedure
        :raises ValueError: if the procedure contains no '('
        """
        deps = super(Trigger, self).get_implied_deps(db)
        deps.add(db.tables[self.schema, self.table])

        # augment triggers stop here: their procedure is not looked up
        if not hasattr(self, '_iscfg'):
            # the trigger procedure can have arguments, but the trigger
            # definition has always none (they are accessed through
            # `tg_argv`).
            # TODO: this breaks if a function name contains a '('
            # (another case for a robust lookup function in db)
            proc_schema, proc_sig = split_schema_obj(
                self.procedure, self.schema)
            # the two-way unpacking implicitly asserts there is a '('
            proc_name, _ = proc_sig.split('(', 1)
            if not proc_name.startswith('tsvector_update_trigger'):
                deps.add(db.functions[proc_schema, proc_name, ''])

        return deps
Exemple #21
0
 def _from_catalog(self):
     """Load tables, sequences and views from the catalog query results

     Each fetched row becomes a Table ('r'), Sequence ('S') or View
     ('v') stored under its (schema, name) key; rows of any other kind
     are not stored.  The rows returned by `inhquery` then fill the
     `inherits` list of each child table.
     """
     for table in self.fetch():
         sch, tbl = table.key()
         kind = table.kind
         del table.kind
         if kind == 'r':
             inst = Table(**table.__dict__)
         elif kind == 'S':
             inst = Sequence(**table.__dict__)
         elif kind == 'v':
             inst = View(**table.__dict__)
         else:
             continue
         self[(sch, tbl)] = inst
         if kind == 'S':
             # sequences run further lookups through the connection
             inst.get_attrs(self.dbconn)
             inst.get_dependent_table(self.dbconn)

     for (child, parent, _num) in self.dbconn.fetchall(self.inhquery):
         (sch, tbl) = split_schema_obj(child)
         table = self[(sch, tbl)]
         if not hasattr(table, 'inherits'):
             table.inherits = []
         table.inherits.append(parent)
Exemple #22
0
    def apply(self, table):
        """Instantiate this trigger on the given table

        :param table: table on which the trigger will be created

        A Trigger is built from this object's attributes, with the
        '{{table_name}}' placeholder in its name and procedure replaced
        by the table name when they start with it.  The procedure is
        re-qualified with the table's schema if it was split into a
        different one.  The new trigger is stored in `table.triggers`
        under its name.
        """
        placeholder = '{{table_name}}'
        newtrg = Trigger(schema=table.schema, table=table.name,
                         **self.__dict__)
        if newtrg.name.startswith(placeholder):
            newtrg.name = newtrg.name.replace(placeholder, table.name)
        newtrg._table = table
        if not hasattr(table, 'triggers'):
            table.triggers = {}
        if hasattr(newtrg, 'procedure'):
            proc = newtrg.procedure
            if proc.startswith(placeholder):
                proc = newtrg.procedure = proc.replace(
                    placeholder, table.name)
            (sch, fnc) = split_schema_obj(proc)
            if sch != table.schema:
                newtrg.procedure = "%s.%s" % (table.schema, fnc)
        table.triggers[newtrg.name] = newtrg
Exemple #23
0
    def apply(self, table):
        """Instantiate this trigger on the given table

        :param table: table on which the trigger will be created

        A Trigger is built from this object's attributes, with the
        '{{table_name}}' placeholder in its name and procedure replaced
        by the table name when they start with it.  The procedure is
        re-qualified with the table's schema if it was split into a
        different one.  The new trigger is stored in `table.triggers`
        under its name.
        """
        placeholder = '{{table_name}}'
        newtrg = Trigger(schema=table.schema,
                         table=table.name,
                         **self.__dict__)
        if newtrg.name.startswith(placeholder):
            newtrg.name = newtrg.name.replace(placeholder, table.name)
        newtrg._table = table
        if not hasattr(table, 'triggers'):
            table.triggers = {}
        if hasattr(newtrg, 'procedure'):
            proc = newtrg.procedure
            if proc.startswith(placeholder):
                proc = newtrg.procedure = proc.replace(
                    placeholder, table.name)
            (sch, fnc) = split_schema_obj(proc)
            if sch != table.schema:
                newtrg.procedure = "%s.%s" % (table.schema, fnc)
        table.triggers[newtrg.name] = newtrg
Exemple #24
0
    def apply(self, table, augdb):
        """Add this object's columns and triggers to the given table

        :param table: table to which columns/triggers will be added
        :param augdb: augment dictionaries

        Each listed column is applied to the table.  If this object
        has a `triggers` attribute, each listed trigger is applied too
        and, for every trigger then present on the table, the trigger
        function is created from `augdb.functions` and registered,
        together with its language, unless the current database
        already has it.
        """
        currdb = augdb.current
        sch = table.schema
        for colname in self.columns:
            augdb.columns[colname].apply(table)
        if not hasattr(self, 'triggers'):
            return
        for trgname in self.triggers:
            augdb.triggers[trgname].apply(table)
            for tbltrg in table.triggers:
                fncsig = table.triggers[tbltrg].procedure
                (sch, fnc) = split_schema_obj(fncsig, table.schema)
                # NOTE(review): the membership test uses the unsplit
                # procedure string in a two-element key -- verify
                # against how currdb.functions is keyed
                if (sch, fncsig) in currdb.functions:
                    continue
                newfunc = augdb.functions[fnc].apply(
                    sch, augdb.columns.col_trans_tbl, augdb)
                # add new function to the current db
                augdb.add_func(sch, newfunc)
                augdb.add_lang(newfunc.language)
Exemple #25
0
    def apply(self, table):
        """Instantiate this trigger on the given table

        :param table: table on which the trigger will be created

        A Trigger is built from this object's name, description (None
        when absent), procedure, timing, level and events, and flagged
        with `_iscfg`.  The '{{table_name}}' placeholder in its name
        and procedure is replaced by the table name when they start
        with it.  The procedure is re-qualified with the table's schema
        if it was split into a different one.  The new trigger is
        stored in `table.triggers` under its name.
        """
        placeholder = '{{table_name}}'
        description = getattr(self, 'description', None)
        newtrg = Trigger(self.name, table.schema, table.name, description,
                         self.procedure, self.timing, self.level, self.events)
        newtrg._iscfg = True
        if newtrg.name.startswith(placeholder):
            newtrg.name = newtrg.name.replace(placeholder, table.name)
        newtrg._table = table
        if not hasattr(table, 'triggers'):
            table.triggers = {}
        if hasattr(newtrg, 'procedure'):
            proc = newtrg.procedure
            if proc.startswith(placeholder):
                proc = newtrg.procedure = proc.replace(
                    placeholder, table.name)
            (sch, fnc) = split_schema_obj(proc)
            if sch != table.schema:
                newtrg.procedure = "%s.%s" % (table.schema, fnc)
        table.triggers[newtrg.name] = newtrg
Exemple #26
0
    def apply(self, table, augdb):
        """Add this object's columns and triggers to the given table

        :param table: table to which columns/triggers will be added
        :param augdb: augment dictionaries

        Each listed column is applied to the table.  If this object
        has a `triggers` attribute, each listed trigger is applied too
        and, for every trigger then present on the table, the trigger
        function is created from `augdb.functions` and registered,
        together with its language, unless the current database
        already has it.
        """
        currdb = augdb.current
        sch = table.schema
        for colname in self.columns:
            augdb.columns[colname].apply(table)
        if not hasattr(self, 'triggers'):
            return
        for trgname in self.triggers:
            augdb.triggers[trgname].apply(table)
            for tbltrg in table.triggers:
                fncsig = table.triggers[tbltrg].procedure
                # keep what precedes the first '('; when there is no
                # '(' at all, find() yields -1 and the slice drops the
                # last character instead
                paren = fncsig.find('(')
                (sch, fnc) = split_schema_obj(fncsig[:paren])
                # NOTE(review): the membership test uses the unsplit
                # procedure string in a two-element key -- verify
                # against how currdb.functions is keyed
                if (sch, fncsig) in currdb.functions:
                    continue
                newfunc = augdb.functions[fnc].apply(
                    sch, augdb.columns.col_trans_tbl, augdb)
                # add new function to the current db
                augdb.add_func(sch, newfunc)
                augdb.add_lang(newfunc.language)
Exemple #27
0
    def link_refs(self, dbcolumns, dbconstrs, dbindexes, dbrules, dbtriggers):
        """Attach columns, constraints, indexes, rules and triggers to tables

        :param dbcolumns: dictionary of columns
        :param dbconstrs: dictionary of constraints
        :param dbindexes: dictionary of indexes
        :param dbrules: dictionary of rules
        :param dbtriggers: dictionary of triggers

        Each column list in `dbcolumns` is handed to its table.
        Sequences with an integer `owner_column` get it replaced by
        the column name, and tables that inherit are recorded in the
        `_descendants` list of each parent.  Constraints without a
        `target` attribute fill the `check_constraints`,
        `primary_key`, `foreign_keys` and `unique_constraints` of
        their table.  Indexes, rules and triggers, keyed by schema,
        table and object name, fill the `indexes`, `rules` and
        `triggers` dictionaries of their table.
        """
        # columns
        for (sch, tbl) in dbcolumns:
            if (sch, tbl) not in self:
                continue
            owner = self[(sch, tbl)]
            assert isinstance(owner, Table)
            columns = dbcolumns[(sch, tbl)]
            owner.columns = columns
            for col in columns:
                col._table = owner

        # sequence owners and inheritance
        for (sch, tbl) in self:
            table = self[(sch, tbl)]
            if isinstance(table, Sequence) and hasattr(table, 'owner_table'):
                colnum = table.owner_column
                if isinstance(colnum, int):
                    # turn the 1-based column number into a column name
                    owner = self[(sch, table.owner_table)]
                    table.owner_column = owner.column_names()[colnum - 1]
            elif isinstance(table, Table) and hasattr(table, 'inherits'):
                for parname in table.inherits:
                    (parsch, partbl) = split_schema_obj(parname)
                    parent = self[(parsch, partbl)]
                    assert parent
                    if not hasattr(parent, '_descendants'):
                        parent._descendants = []
                    parent._descendants.append(table)

        # constraints; those carrying a `target` attribute are skipped
        for (sch, tbl, cns) in dbconstrs:
            constr = dbconstrs[(sch, tbl, cns)]
            if hasattr(constr, 'target'):
                continue
            table = self[(sch, tbl)]
            assert table
            constr._table = table
            if isinstance(constr, CheckConstraint):
                if not hasattr(table, 'check_constraints'):
                    table.check_constraints = {}
                table.check_constraints[cns] = constr
            elif isinstance(constr, PrimaryKey):
                table.primary_key = constr
            elif isinstance(constr, ForeignKey):
                if not hasattr(table, 'foreign_keys'):
                    table.foreign_keys = {}
                # link referenced and referrer
                referenced = self[(constr.ref_schema, constr.ref_table)]
                constr.references = referenced
                # TODO: there can be more than one
                referenced._referred_by = constr
                table.foreign_keys[cns] = constr
            elif isinstance(constr, UniqueConstraint):
                if not hasattr(table, 'unique_constraints'):
                    table.unique_constraints = {}
                table.unique_constraints[cns] = constr

        def attach(source, schema, tbl, name, attrname):
            # store source[(schema, tbl, name)] in the dictionary held
            # by the table's `attrname` attribute, creating it if needed
            owner = self[(schema, tbl)]
            if not hasattr(owner, attrname):
                setattr(owner, attrname, {})
            getattr(owner, attrname)[name] = source[(schema, tbl, name)]

        for (sch, tbl, idx) in dbindexes:
            attach(dbindexes, sch, tbl, idx, 'indexes')
        for (sch, tbl, rul) in dbrules:
            attach(dbrules, sch, tbl, rul, 'rules')
            dbrules[(sch, tbl, rul)]._table = self[(sch, tbl)]
        for (sch, tbl, trg) in dbtriggers:
            attach(dbtriggers, sch, tbl, trg, 'triggers')
            dbtriggers[(sch, tbl, trg)]._table = self[(sch, tbl)]
Exemple #28
0
    def link_refs(self, dbcolumns, dbconstrs, dbindexes, dbrules, dbtriggers):
        """Attach columns, constraints, indexes, rules and triggers to tables

        :param dbcolumns: dictionary of columns
        :param dbconstrs: dictionary of constraints
        :param dbindexes: dictionary of indexes
        :param dbrules: dictionary of rules
        :param dbtriggers: dictionary of triggers

        Each column list in `dbcolumns` is handed to its table.
        Sequences with an integer `owner_column` get it replaced by
        the column name, and tables that inherit are recorded in the
        `descendants` list of each parent.  Constraints without a
        `target` attribute fill the `check_constraints`,
        `primary_key`, `foreign_keys` and `unique_constraints` of
        their table.  Indexes, rules and triggers, keyed by schema,
        table and object name, fill the `indexes`, `rules` and
        `triggers` dictionaries of their table.
        """
        # columns
        for (sch, tbl) in dbcolumns.keys():
            if (sch, tbl) not in self:
                continue
            owner = self[(sch, tbl)]
            assert isinstance(owner, Table)
            columns = dbcolumns[(sch, tbl)]
            owner.columns = columns
            for col in columns:
                col._table = owner

        # sequence owners and inheritance
        for (sch, tbl) in self.keys():
            table = self[(sch, tbl)]
            if isinstance(table, Sequence) and hasattr(table, 'owner_table'):
                colnum = table.owner_column
                if isinstance(colnum, int):
                    # turn the 1-based column number into a column name
                    owner = self[(sch, table.owner_table)]
                    table.owner_column = owner.column_names()[colnum - 1]
            elif isinstance(table, Table) and hasattr(table, 'inherits'):
                for parname in table.inherits:
                    (parsch, partbl) = split_schema_obj(parname)
                    parent = self[(parsch, partbl)]
                    assert parent
                    if not hasattr(parent, 'descendants'):
                        parent.descendants = []
                    parent.descendants.append(table)

        # constraints; those carrying a `target` attribute are skipped
        for (sch, tbl, cns) in dbconstrs.keys():
            constr = dbconstrs[(sch, tbl, cns)]
            if hasattr(constr, 'target'):
                continue
            table = self[(sch, tbl)]
            assert table
            constr._table = table
            if isinstance(constr, CheckConstraint):
                if not hasattr(table, 'check_constraints'):
                    table.check_constraints = {}
                table.check_constraints[cns] = constr
            elif isinstance(constr, PrimaryKey):
                table.primary_key = constr
            elif isinstance(constr, ForeignKey):
                if not hasattr(table, 'foreign_keys'):
                    table.foreign_keys = {}
                # link referenced and referrer
                referenced = self[(constr.ref_schema, constr.ref_table)]
                constr.references = referenced
                # TODO: there can be more than one
                referenced.referred_by = constr
                table.foreign_keys[cns] = constr
            elif isinstance(constr, UniqueConstraint):
                if not hasattr(table, 'unique_constraints'):
                    table.unique_constraints = {}
                table.unique_constraints[cns] = constr

        # indexes
        for (sch, tbl, idx) in dbindexes.keys():
            table = self[(sch, tbl)]
            assert table
            if not hasattr(table, 'indexes'):
                table.indexes = {}
            table.indexes[idx] = dbindexes[(sch, tbl, idx)]

        # rules, each pointed back at its table
        for (sch, tbl, rul) in dbrules.keys():
            table = self[(sch, tbl)]
            assert table
            if not hasattr(table, 'rules'):
                table.rules = {}
            rule = dbrules[(sch, tbl, rul)]
            table.rules[rul] = rule
            rule._table = table

        # triggers, each pointed back at its table
        for (sch, tbl, trg) in dbtriggers.keys():
            table = self[(sch, tbl)]
            assert table
            if not hasattr(table, 'triggers'):
                table.triggers = {}
            trigger = dbtriggers[(sch, tbl, trg)]
            table.triggers[trg] = trigger
            trigger._table = table
Exemple #29
0
    def link_refs(self, dbcolumns, dbconstrs, dbfuncs):
        """Wire columns, constraints and functions to their owning types

        :param dbcolumns: dictionary of columns
        :param dbconstrs: dictionary of constraints
        :param dbfuncs: dictionary of functions

        Composite types receive their `attributes` list (and each
        attribute a `_type` back-reference).  Types targeted by a
        check constraint whose `target` is 'd' receive it in their
        `check_constraints` dictionary.  Base types receive a
        `dep_funcs` dictionary holding the functions they depend on
        (and each such function a `_dep_type` back-reference).
        """
        # composite types: attach the attribute list
        for (schname, typname) in dbcolumns:
            typkey = (schname, typname)
            if typkey not in self:
                continue
            assert isinstance(self[typkey], Composite)
            self[typkey].attributes = dbcolumns[typkey]
            for attribute in dbcolumns[typkey]:
                attribute._type = self[typkey]

        # constraints: only those carrying target 'd' are linked here
        for (schname, typname, cnsname) in dbconstrs:
            constr = dbconstrs[(schname, typname, cnsname)]
            if getattr(constr, 'target', None) != 'd':
                continue
            dbtype = self[(schname, typname)]
            assert dbtype
            if not isinstance(constr, CheckConstraint):
                continue
            if not hasattr(dbtype, 'check_constraints'):
                dbtype.check_constraints = {}
            dbtype.check_constraints[cnsname] = constr

        # base types: resolve the support functions
        for (schname, typname) in self:
            dbtype = self[(schname, typname)]
            if not isinstance(dbtype, BaseType):
                continue
            if not hasattr(dbtype, 'dep_funcs'):
                dbtype.dep_funcs = {}
            # the schema returned by each split becomes the default
            # schema for the next function name that gets split
            cursch = schname
            (cursch, infnc) = split_schema_obj(dbtype.input, cursch)
            # try the one-argument signature first, then the
            # three-argument one
            inargs = 'cstring'
            if (cursch, infnc, inargs) not in dbfuncs:
                inargs = 'cstring, oid, integer'
            depfunc = dbfuncs[(cursch, infnc, inargs)]
            dbtype.dep_funcs['input'] = depfunc
            depfunc._dep_type = dbtype

            (cursch, outfnc) = split_schema_obj(dbtype.output, cursch)
            depfunc = dbfuncs[(cursch, outfnc, dbtype.qualname())]
            dbtype.dep_funcs['output'] = depfunc
            depfunc._dep_type = dbtype

            for optname in OPT_FUNCS:
                if not hasattr(dbtype, optname):
                    continue
                (cursch, optfnc) = split_schema_obj(getattr(dbtype, optname),
                                                    cursch)
                # argument signature used to look up each optional function
                if optname == 'receive':
                    arg = 'internal'
                elif optname == 'send':
                    arg = dbtype.qualname()
                elif optname == 'typmod_in':
                    arg = 'cstring[]'
                elif optname == 'typmod_out':
                    arg = 'integer'
                elif optname == 'analyze':
                    arg = 'internal'
                depfunc = dbfuncs[(cursch, optfnc, arg)]
                dbtype.dep_funcs[optname] = depfunc
                depfunc._dep_type = dbtype
Exemple #30
0
 def get_implied_deps(self, db):
     """Return the implied dependencies, including the configured parser

     :param db: object exposing a `tsparsers` mapping
     :return: the collection produced by the superclass, with the entry
         of `db.tsparsers` for this configuration's parser added
     """
     implied = super(TSConfiguration, self).get_implied_deps(db)
     # the parser name is split using this object's schema
     parser_key = split_schema_obj(self.parser, self.schema)
     implied.add(db.tsparsers[parser_key])
     return implied
Exemple #31
0
    def link_refs(self, dbtypes, dbtables, dbfunctions, dbopers, dbopfams,
                  dbopcls, dbconvs, dbtsconfigs, dbtsdicts, dbtspars,
                  dbtstmpls, dbftables):
        """Connect types, tables and functions to their respective schemas

        :param dbtypes: dictionary of types and domains
        :param dbtables: dictionary of tables, sequences and views
        :param dbfunctions: dictionary of functions
        :param dbopers: dictionary of operators
        :param dbopfams: dictionary of operator families
        :param dbopcls: dictionary of operator classes
        :param dbconvs: dictionary of conversions
        :param dbtsconfigs: dictionary of text search configurations
        :param dbtsdicts: dictionary of text search dictionaries
        :param dbtspars: dictionary of text search parsers
        :param dbtstmpls: dictionary of text search templates
        :param dbftables: dictionary of foreign tables

        Every input dictionary is keyed by a tuple whose first element
        is the schema name.  Each object is stored in a dictionary
        attribute of its schema (`domains`, `types`, `tables`,
        `sequences`, `views`, `functions`, `operators`, `operclasses`,
        `operfams`, `conversions`, `tsconfigs`, `tsdicts`, `tsparsers`,
        `tstempls`, `ftables`), created on first use and keyed by the
        remainder of the tuple (a bare name for two-element keys).

        A function whose return type names an entry of `dbtables` is
        additionally cross-linked with that entry through
        `dependent_table` / `dependent_funcs`.
        """
        def link_one(targdict, objkeys, attrname):
            """Store targdict[objkeys] in the `attrname` dict of its schema"""
            schema = self[objkeys[0]]
            assert schema
            if not hasattr(schema, attrname):
                setattr(schema, attrname, {})
            key = objkeys[1] if len(objkeys) == 2 else objkeys[1:]
            getattr(schema, attrname)[key] = targdict[objkeys]

        for keys in dbtypes:
            dbtype = dbtypes[keys]
            # the schema must exist even for types that are not linked
            assert self[keys[0]]
            if isinstance(dbtype, Domain):
                link_one(dbtypes, keys, 'domains')
            elif isinstance(dbtype, (Enum, Composite, BaseType)):
                link_one(dbtypes, keys, 'types')

        for keys in dbtables:
            table = dbtables[keys]
            assert self[keys[0]]
            if isinstance(table, Table):
                link_one(dbtables, keys, 'tables')
            elif isinstance(table, Sequence):
                link_one(dbtables, keys, 'sequences')
            elif isinstance(table, View):
                link_one(dbtables, keys, 'views')

        for keys in dbfunctions:
            func = dbfunctions[keys]
            link_one(dbfunctions, keys, 'functions')
            if not hasattr(func, 'returns'):
                continue
            rettype = func.returns
            # strip a leading "SETOF " to get at the bare type name
            if rettype.upper().startswith("SETOF "):
                rettype = rettype[6:]
            (retsch, rettyp) = split_schema_obj(rettype, keys[0])
            # direct membership test instead of building a key list
            # for every function
            if (retsch, rettyp) in dbtables:
                deptbl = dbtables[(retsch, rettyp)]
                if not hasattr(func, 'dependent_table'):
                    func.dependent_table = deptbl
                if not hasattr(deptbl, 'dependent_funcs'):
                    deptbl.dependent_funcs = []
                deptbl.dependent_funcs.append(func)

        # the remaining object kinds are linked without further processing
        for (targdict, attrname) in [
                (dbopers, 'operators'), (dbopcls, 'operclasses'),
                (dbopfams, 'operfams'), (dbconvs, 'conversions'),
                (dbtsconfigs, 'tsconfigs'), (dbtsdicts, 'tsdicts'),
                (dbtspars, 'tsparsers'), (dbtstmpls, 'tstempls'),
                (dbftables, 'ftables')]:
            for keys in targdict:
                link_one(targdict, keys, attrname)
Exemple #32
0
 def _from_catalog(self):
     """Initialize the dictionary of indexes by querying the catalogs

     For each object returned by `self.fetch()`, the key definitions
     are parsed out of `index.defn` into the list `index.keys`; the
     object is then stored under (schema, table, index name).  Plain
     keys are stored as strings, keys with extra information as
     single-entry dictionaries ``{key: extra}``.
     """
     for index in self.fetch():
         index.unqualify()
         sch, tbl, idx = index.key()
         sch, tbl = split_schema_obj('%s.%s' % (sch, tbl))
         # isolate the key list: drop anything from ' WHERE ' onwards,
         # keep what follows ' USING ', then take the text after the
         # first ' (' up to (not including) the last character
         keydefs, _, _ = index.defn.partition(' WHERE ')
         _, _, keydefs = keydefs.partition(' USING ')
         keydefs = keydefs[keydefs.find(' (') + 2:-1]
         # split expressions (result of pg_get_expr)
         if hasattr(index, 'keyexprs'):
             keyexprs = split_exprs(index.keyexprs)
             del index.keyexprs
         # parse the keys
         i = 0              # position of the next expression in keyexprs
         rest = keydefs     # portion of the key list not yet consumed
         index.keys = []
         for col in index.keycols.split():
             keyopts = []
             extra = {}
             if col == '0':
                 # a '0' entry is treated as an expression key; its text
                 # comes from keyexprs
                 # NOTE(review): keyexprs is only bound when the index
                 # has a `keyexprs` attribute -- presumably always the
                 # case when keycols contains '0'; confirm
                 expr = keyexprs[i]
                 # re-add the enclosing parentheses when the definition
                 # text has them
                 if rest and rest[0] == '(':
                     expr = '(' + expr + ')'
                 assert(rest.startswith(expr))
                 key = expr
                 extra = {'type': 'expression'}
                 explen = len(expr)
                 # look for the comma ending this key, searching only
                 # after the expression text
                 loc = rest[explen:].find(',')
                 if loc == 0:
                     # comma immediately follows: no options
                     keyopts = []
                     rest = rest[explen + 1:].lstrip()
                 elif loc == -1:
                     # last key: everything left is options
                     keyopts = rest[explen:].split()
                     rest = ''
                 else:
                     keyopts = rest[explen:explen + loc].split()
                     rest = rest[explen + loc + 1:].lstrip()
                 i += 1
             else:
                 # any other entry: the key is the first word up to the
                 # next comma, the remaining words are its options
                 loc = rest.find(',')
                 key = rest[:loc] if loc != -1 else rest.lstrip()
                 keyopts = key.split()[1:]
                 key = key.split()[0]
                 rest = rest[loc + 1:]
             rest = rest.lstrip()
             # interpret the options; COLLATE and NULLS also consume the
             # word that follows them
             skipnext = False
             for j, opt in enumerate(keyopts):
                 if skipnext:
                     skipnext = False
                     continue
                 # a word that is not a recognized keyword is recorded
                 # as the operator class
                 # NOTE(review): this test is case-insensitive while the
                 # comparisons below are not -- assumes keywords arrive
                 # in upper case; confirm
                 if opt.upper() not in ['COLLATE', 'ASC', 'DESC', 'NULLS',
                                        'FIRST', 'LAST']:
                     extra.update(opclass=opt)
                     continue
                 elif opt == 'COLLATE':
                     extra.update(collation=keyopts[j + 1])
                     skipnext = True
                 elif opt == 'NULLS':
                     extra.update(nulls=keyopts[j + 1].lower())
                     skipnext = True
                 elif opt == 'DESC':
                     extra.update(order='desc')
             if extra:
                 key = {key: extra}
             index.keys.append(key)
         # the raw definition and column list are no longer needed
         del index.defn, index.keycols
         self[(sch, tbl, idx)] = index
Exemple #33
0
    def link_refs(self, db, datacopy):
        """Attach every schema-level object to the schema that owns it

        :param db: dictionary of dictionaries of all objects
        :param datacopy: dictionary of data copying info

        Objects are stored in per-kind dictionary attributes of their
        schema, created on first use.  Keys of `datacopy` must start
        with 'schema '; anything else raises KeyError.
        """
        def attach(source, keys, attrname):
            # first key element names the schema; the rest is the key
            # within the schema (a bare name for two-element keys)
            owner = self[keys[0]]
            if not hasattr(owner, attrname):
                setattr(owner, attrname, {})
            if len(keys) == 2:
                shortkey = keys[1]
            else:
                shortkey = keys[1:]
            getattr(owner, attrname)[shortkey] = source[keys]

        alltypes = db.types
        for keys in alltypes:
            dbtype = alltypes[keys]
            if isinstance(dbtype, Domain):
                attach(alltypes, keys, 'domains')
            elif isinstance(dbtype, (Enum, Composite, BaseType)):
                attach(alltypes, keys, 'types')

        alltables = db.tables
        for keys in alltables:
            table = alltables[keys]
            # MaterializedView is tested before View on purpose
            if isinstance(table, Table):
                attach(alltables, keys, 'tables')
            elif isinstance(table, Sequence):
                attach(alltables, keys, 'sequences')
            elif isinstance(table, MaterializedView):
                attach(alltables, keys, 'matviews')
            elif isinstance(table, View):
                attach(alltables, keys, 'views')

        allfuncs = db.functions
        for keys in allfuncs:
            func = allfuncs[keys]
            attach(allfuncs, keys, 'functions')
            if not hasattr(func, 'returns'):
                continue
            rettype = func.returns
            # strip a leading "SETOF " to get at the bare type name
            if rettype.upper().startswith("SETOF "):
                rettype = rettype[6:]
            (retsch, rettyp) = split_schema_obj(rettype, keys[0])
            if (retsch, rettyp) not in db.tables:
                continue
            # cross-link the function and the table named by its return type
            deptbl = db.tables[(retsch, rettyp)]
            if not hasattr(func, 'dependent_table'):
                func.dependent_table = deptbl
            if not hasattr(deptbl, 'dependent_funcs'):
                deptbl.dependent_funcs = []
            deptbl.dependent_funcs.append(func)

        # these kinds use the same name on `db` and on the schema
        for kind in ('operators', 'operclasses', 'operfams', 'conversions',
                     'tsconfigs', 'tsdicts', 'tsparsers', 'tstempls',
                     'ftables', 'collations'):
            source = getattr(db, kind)
            for keys in source:
                attach(source, keys, kind)

        for key in datacopy:
            if not key.startswith('schema '):
                raise KeyError("Unrecognized object type: %s" % key)
            # the schema name follows the 7-character 'schema ' prefix
            schema = self[key[7:]]
            if not hasattr(schema, 'datacopy'):
                schema.datacopy = []
            # keep only names that are tables of this schema
            schema.datacopy.extend(
                [tbl for tbl in datacopy[key]
                 if hasattr(schema, 'tables') and tbl in schema.tables])
Exemple #34
0
 def find(self, obj, meth):
     """Look up an entry by object name and a third key component

     :param obj: object name, split with `split_schema_obj`
     :param meth: third element of the lookup key
     :return: result of `self.get` for (schema, name, meth)
     """
     (sch, objname) = split_schema_obj(obj)
     lookup_key = (sch, objname, meth)
     return self.get(lookup_key)
Exemple #35
0
 def get_implied_deps(self, db):
     """Return the implied dependencies, including the trigger procedure

     :param db: object exposing a `functions` mapping
     :return: the collection produced by the superclass, with the entry
         of `db.functions` for this trigger's procedure added
     """
     implied = super(EventTrigger, self).get_implied_deps(db)
     (funcsch, funcname) = split_schema_obj(self.procedure)
     # the last two characters of the name are dropped (presumably a
     # trailing "()") and the lookup uses an empty argument string
     funckey = (funcsch, funcname[:-2], '')
     implied.add(db.functions[funckey])
     return implied
Exemple #36
0
 def _from_catalog(self):
     """Initialize the dictionary of indexes by querying the catalogs

     For each object returned by `self.fetch()`, the key definitions
     are parsed out of `index.defn` into the list `index.keys`; the
     object is then stored under (schema, table, index name) and also
     in `self.by_oid` under its `oid`.  Plain keys are stored as
     strings, keys with extra information as single-entry dictionaries
     ``{key: extra}``.
     """
     for index in self.fetch():
         index.unqualify()
         oid = index.oid
         sch, tbl, idx = index.key()
         sch, tbl = split_schema_obj('%s.%s' % (sch, tbl))
         # isolate the key list: drop anything from ' WHERE ' onwards,
         # keep what follows ' USING ', then take the text after the
         # first ' (' up to (not including) the last character
         keydefs, _, _ = index.defn.partition(' WHERE ')
         _, _, keydefs = keydefs.partition(' USING ')
         keydefs = keydefs[keydefs.find(' (') + 2:-1]
         # split expressions (result of pg_get_expr)
         if hasattr(index, 'keyexprs'):
             keyexprs = split_exprs(index.keyexprs)
             del index.keyexprs
         # parse the keys
         i = 0              # position of the next expression in keyexprs
         rest = keydefs     # portion of the key list not yet consumed
         index.keys = []
         for col in index.keycols.split():
             keyopts = []
             extra = {}
             if col == '0':
                 # a '0' entry is treated as an expression key; its text
                 # comes from keyexprs
                 # NOTE(review): keyexprs is only bound when the index
                 # has a `keyexprs` attribute -- presumably always the
                 # case when keycols contains '0'; confirm
                 expr = keyexprs[i]
                 # re-add the enclosing parentheses when the definition
                 # text has them
                 if rest and rest[0] == '(':
                     expr = '(' + expr + ')'
                 assert(rest.startswith(expr))
                 key = expr
                 extra = {'type': 'expression'}
                 explen = len(expr)
                 # look for the comma ending this key, searching only
                 # after the expression text
                 loc = rest[explen:].find(',')
                 if loc == 0:
                     # comma immediately follows: no options
                     keyopts = []
                     rest = rest[explen + 1:].lstrip()
                 elif loc == -1:
                     # last key: everything left is options
                     keyopts = rest[explen:].split()
                     rest = ''
                 else:
                     keyopts = rest[explen:explen + loc].split()
                     rest = rest[explen + loc + 1:].lstrip()
                 i += 1
             else:
                 # any other entry: the key is the first word up to the
                 # next comma, the remaining words are its options
                 loc = rest.find(',')
                 key = rest[:loc] if loc != -1 else rest.lstrip()
                 keyopts = key.split()[1:]
                 key = key.split()[0]
                 rest = rest[loc + 1:]
             rest = rest.lstrip()
             # interpret the options; COLLATE and NULLS also consume the
             # word that follows them
             skipnext = False
             for j, opt in enumerate(keyopts):
                 if skipnext:
                     skipnext = False
                     continue
                 # a word that is not a recognized keyword is recorded
                 # as the operator class
                 # NOTE(review): this test is case-insensitive while the
                 # comparisons below are not -- assumes keywords arrive
                 # in upper case; confirm
                 if opt.upper() not in ['COLLATE', 'ASC', 'DESC', 'NULLS',
                                        'FIRST', 'LAST']:
                     extra.update(opclass=opt)
                     continue
                 elif opt == 'COLLATE':
                     extra.update(collation=keyopts[j + 1])
                     skipnext = True
                 elif opt == 'NULLS':
                     extra.update(nulls=keyopts[j + 1].lower())
                     skipnext = True
                 elif opt == 'DESC':
                     extra.update(order='desc')
             if extra:
                 key = {key: extra}
             index.keys.append(key)
         # the raw definition and column list are no longer needed
         del index.defn, index.keycols
         # register the index both by oid and by its three-part key
         self.by_oid[oid] = self[(sch, tbl, idx)] = index
Exemple #37
0
    def link_refs(self, db, datacopy):
        """Attach every schema-level object to the schema that owns it

        :param db: dictionary of dictionaries of all objects
        :param datacopy: dictionary of data copying info

        Objects are stored in per-kind dictionary attributes of their
        schema, created on first use.  Keys of `datacopy` must start
        with 'schema '; anything else raises KeyError.
        """
        def attach(source, keys, attrname):
            # first key element names the schema; the rest is the key
            # within the schema (a bare name for two-element keys)
            owner = self[keys[0]]
            if not hasattr(owner, attrname):
                setattr(owner, attrname, {})
            if len(keys) == 2:
                shortkey = keys[1]
            else:
                shortkey = keys[1:]
            getattr(owner, attrname)[shortkey] = source[keys]

        alltypes = db.types
        for keys in alltypes:
            dbtype = alltypes[keys]
            if isinstance(dbtype, Domain):
                attach(alltypes, keys, 'domains')
            elif isinstance(dbtype, (Enum, Composite, BaseType)):
                attach(alltypes, keys, 'types')

        alltables = db.tables
        for keys in alltables:
            table = alltables[keys]
            # MaterializedView is tested before View on purpose
            if isinstance(table, Table):
                attach(alltables, keys, 'tables')
            elif isinstance(table, Sequence):
                attach(alltables, keys, 'sequences')
            elif isinstance(table, MaterializedView):
                attach(alltables, keys, 'matviews')
            elif isinstance(table, View):
                attach(alltables, keys, 'views')

        allfuncs = db.functions
        for keys in allfuncs:
            func = allfuncs[keys]
            attach(allfuncs, keys, 'functions')
            if not hasattr(func, 'returns'):
                continue
            rettype = func.returns
            # strip a leading "SETOF " to get at the bare type name
            if rettype.upper().startswith("SETOF "):
                rettype = rettype[6:]
            (retsch, rettyp) = split_schema_obj(rettype, keys[0])
            if (retsch, rettyp) not in db.tables:
                continue
            # cross-link the function and the table named by its return type
            deptbl = db.tables[(retsch, rettyp)]
            if not hasattr(func, 'dependent_table'):
                func.dependent_table = deptbl
            if not hasattr(deptbl, 'dependent_funcs'):
                deptbl.dependent_funcs = []
            deptbl.dependent_funcs.append(func)

        # these kinds use the same name on `db` and on the schema
        for kind in ('operators', 'operclasses', 'operfams', 'conversions',
                     'tsconfigs', 'tsdicts', 'tsparsers', 'tstempls',
                     'ftables', 'collations'):
            source = getattr(db, kind)
            for keys in source:
                attach(source, keys, kind)

        for key in datacopy:
            if not key.startswith('schema '):
                raise KeyError("Unrecognized object type: %s" % key)
            # the schema name follows the 7-character 'schema ' prefix
            schema = self[key[7:]]
            if not hasattr(schema, 'datacopy'):
                schema.datacopy = []
            # keep only names that are tables of this schema
            schema.datacopy.extend(
                [tbl for tbl in datacopy[key]
                 if hasattr(schema, 'tables') and tbl in schema.tables])
Exemple #38
0
 def _from_catalog(self):
     """Initialize the dictionary of indexes by querying the catalogs

     For each object returned by `self.fetch()`, the key definitions
     are parsed out of `index.defn` into the list `index.keys`; the
     object is then stored under (schema, table, index name).  Plain
     keys are stored as strings, keys with extra information as
     single-entry dictionaries ``{key: extra}``.
     """
     for index in self.fetch():
         index.unqualify()
         sch, tbl, idx = index.key()
         sch, tbl = split_schema_obj('%s.%s' % (sch, tbl))
         # isolate the key list: skip past ' USING ' (7 characters),
         # then take the text after the first ' (' up to (not
         # including) the last character
         keydefs = index.defn[index.defn.find(' USING ') + 7:]
         keydefs = keydefs[keydefs.find(' (') + 2:-1]
         # split expressions (result of pg_get_expr)
         keyexprs = []
         if hasattr(index, 'keyexprs'):
             rest = index.keyexprs
             del index.keyexprs
             # split on commas, comparing the counts of '(' and ')' in
             # the candidate to decide whether a comma ends an expression
             # NOTE(review): several paths leave `rest` unchanged (more
             # ')' than '(' in the candidate, or the innermost test
             # failing), in which case this loop would not advance --
             # confirm the inputs cannot reach them
             while len(rest):
                 loc = rest.find(',')
                 # when no comma is found loc is -1 and this slice drops
                 # the last character of rest
                 expr = rest[:loc]
                 cntopen = expr.count('(')
                 cntcls = expr.count(')')
                 if cntopen == cntcls:
                     # same number of '(' and ')': the comma ends the
                     # expression
                     keyexprs.append(rest[:loc])
                     rest = rest[loc + 1:].lstrip()
                 elif cntcls < cntopen:
                     # fewer ')' than '(': the comma lies inside
                     # parentheses, look further
                     loc = rest[loc + 1:].find(',')
                     if loc == -1:
                         # no further comma: the remainder is one
                         # expression
                         keyexprs.append(rest)
                         rest = ''
                     else:
                         loc2 = rest[loc + 1:].find(',')
                         loccls = rest[loc + 1:].find(')')
                         if loc2 != -1 and loccls < loc2:
                             keyexprs.append(rest[:loc + loccls + 2])
                             rest = rest[loc + loccls + 3:].lstrip()
         # parse the keys
         i = 0              # position of the next expression in keyexprs
         rest = keydefs     # portion of the key list not yet consumed
         index.keys = []
         for col in index.keycols.split():
             keyopts = []
             extra = {}
             if col == '0':
                 # a '0' entry is treated as an expression key; its text
                 # comes from keyexprs and must prefix the remaining
                 # definition text
                 expr = keyexprs[i]
                 assert(rest.startswith(expr))
                 key = expr
                 extra = {'type': 'expression'}
                 explen = len(expr)
                 # look for the comma ending this key, searching only
                 # after the expression text
                 loc = rest[explen:].find(',')
                 if loc == 0:
                     # comma immediately follows: no options
                     keyopts = []
                     rest = rest[explen + 1:].lstrip()
                 elif loc == -1:
                     # last key: everything left is options
                     keyopts = rest[explen:].split()
                     rest = ''
                 else:
                     keyopts = rest[explen:explen + loc].split()
                     rest = rest[explen + loc + 1:].lstrip()
                 i += 1
             else:
                 # any other entry: the key is the first word up to the
                 # next comma, the remaining words are its options
                 loc = rest.find(',')
                 key = rest[:loc] if loc != -1 else rest.lstrip()
                 keyopts = key.split()[1:]
                 key = key.split()[0]
                 rest = rest[loc + 1:]
             rest = rest.lstrip()
             for j, opt in enumerate(keyopts):
                 # a word that is not a recognized keyword is recorded
                 # as the operator class
                 if opt.upper() not in ['ASC', 'DESC', 'NULLS',
                                        'FIRST', 'LAST']:
                     extra.update(opclass=opt)
                     continue
                 elif opt == 'NULLS':
                     # the word after NULLS is stored in lower case
                     extra.update(nulls=keyopts[j + 1].lower())
                 elif opt == 'DESC':
                     extra.update(order='desc')
                 else:
                     # remaining recognized keywords record nothing
                     continue
             if extra:
                 key = {key: extra}
             index.keys.append(key)
         # the raw definition and column list are no longer needed
         del index.defn, index.keycols
         self[(sch, tbl, idx)] = index
Exemple #39
0
    def diff_map(self, intables):
        """Generate SQL to transform existing tables and sequences

        :param intables: a YAML map defining the new tables/sequences
        :return: list of SQL statements

        Compares the existing table/sequence definitions, as fetched
        from the catalogs, to the input map and generates SQL
        statements to transform the tables/sequences accordingly.

        The statements are appended in a fixed order: creation or
        renaming of input objects missing from this dictionary
        (owned sequences, tables, views, remaining sequences), then
        per-object differences, then drops of objects missing from the
        input, and finally sequence defaults for newly created tables.
        """
        stmts = []
        # first pass: sequences owned by a table
        for (sch, seq) in intables.keys():
            inseq = intables[(sch, seq)]
            if not isinstance(inseq, Sequence) or \
                    not hasattr(inseq, 'owner_table'):
                continue
            if (sch, seq) not in self:
                if hasattr(inseq, 'oldname'):
                    stmts.append(self._rename(inseq, "sequence"))
                else:
                    # create new sequence
                    stmts.append(inseq.create())

        # check input tables
        # tables that inherit from others are parked in inhstack and
        # created after the loop
        inhstack = []
        for (sch, tbl) in intables.keys():
            intable = intables[(sch, tbl)]
            if not isinstance(intable, Table):
                continue
            # does it exist in the database?
            if (sch, tbl) not in self:
                if not hasattr(intable, 'oldname'):
                    # create new table
                    if hasattr(intable, 'inherits'):
                        inhstack.append(intable)
                    else:
                        stmts.append(intable.create())
                else:
                    stmts.append(self._rename(intable, "table"))
        # create a parked table only once none of its parents is still
        # parked; otherwise put it back at the other end of the stack
        while len(inhstack):
            intable = inhstack.pop()
            createit = True
            for partbl in intable.inherits:
                if intables[split_schema_obj(partbl)] in inhstack:
                    createit = False
            if createit:
                stmts.append(intable.create())
            else:
                inhstack.insert(0, intable)

        # check input views
        for (sch, tbl) in intables.keys():
            intable = intables[(sch, tbl)]
            if not isinstance(intable, View):
                continue
            # does it exist in the database?
            if (sch, tbl) not in self:
                if hasattr(intable, 'oldname'):
                    stmts.append(self._rename(intable, "view"))
                else:
                    # create new view
                    stmts.append(intable.create())

        # second pass: input sequences not owned by tables
        # NOTE(review): owned sequences reach this loop too; those with
        # an `oldname` presumably are found in self after the first
        # pass's _rename -- confirm against _rename
        for (sch, seq) in intables.keys():
            inseq = intables[(sch, seq)]
            if not isinstance(inseq, Sequence):
                continue
            # does it exist in the database?
            if (sch, seq) not in self:
                if hasattr(inseq, 'oldname'):
                    stmts.append(self._rename(inseq, "sequence"))
                elif hasattr(inseq, 'owner_table'):
                    stmts.append(inseq.add_owner())
                else:
                    # create new sequence
                    stmts.append(inseq.create())

        # check database tables, sequences and views
        for (sch, tbl) in self.keys():
            table = self[(sch, tbl)]
            # if missing, mark it for dropping
            # (the mark is `dropped = False`; the loops below test for
            # the attribute being present and false)
            if (sch, tbl) not in intables:
                table.dropped = False
            else:
                # check table/sequence/view objects
                stmts.append(table.diff_map(intables[(sch, tbl)]))

        # now drop the marked tables
        for (sch, tbl) in self.keys():
            table = self[(sch, tbl)]
            # sequences owned by a table are not handled here
            if isinstance(table, Sequence) and hasattr(table, 'owner_table'):
                continue
            if hasattr(table, 'dropped') and not table.dropped:
                # first, drop all foreign keys
                if hasattr(table, 'foreign_keys'):
                    for fgn in table.foreign_keys:
                        stmts.append(table.foreign_keys[fgn].drop())
                # and drop the triggers
                if hasattr(table, 'triggers'):
                    for trg in table.triggers:
                        stmts.append(table.triggers[trg].drop())
                if hasattr(table, 'rules'):
                    for rul in table.rules:
                        stmts.append(table.rules[rul].drop())
                # drop views
                if isinstance(table, View):
                    stmts.append(table.drop())

        # tables with descendants are parked in inhstack and dropped
        # after the loop
        inhstack = []
        for (sch, tbl) in self.keys():
            table = self[(sch, tbl)]
            # skip views, and sequences that have an owner table or a
            # dependent table
            if (isinstance(table, Sequence) \
                    and (hasattr(table, 'owner_table') \
                             or hasattr(table, 'dependent_table'))) \
                             or isinstance(table, View):
                continue
            if hasattr(table, 'dropped') and not table.dropped:
                # next, drop other subordinate objects
                if hasattr(table, 'check_constraints'):
                    for chk in table.check_constraints:
                        stmts.append(table.check_constraints[chk].drop())
                if hasattr(table, 'unique_constraints'):
                    for unq in table.unique_constraints:
                        stmts.append(table.unique_constraints[unq].drop())
                if hasattr(table, 'indexes'):
                    for idx in table.indexes:
                        stmts.append(table.indexes[idx].drop())
                # NOTE(review): rules were already dropped in the
                # previous loop; presumably drop() yields nothing the
                # second time -- confirm
                if hasattr(table, 'rules'):
                    for rul in table.rules:
                        stmts.append(table.rules[rul].drop())
                if hasattr(table, 'primary_key'):
                    # TODO there can be more than one referred_by
                    if hasattr(table, 'referred_by'):
                        stmts.append(table.referred_by.drop())
                    stmts.append(table.primary_key.drop())
                # finally, drop the table itself
                if hasattr(table, 'descendants'):
                    inhstack.append(table)
                else:
                    stmts.append(table.drop())
        # drop a parked table only once none of its descendants is
        # still parked; otherwise put it back at the other end
        while len(inhstack):
            table = inhstack.pop()
            dropit = True
            for childtbl in table.descendants:
                if self[(childtbl.schema, childtbl.name)] in inhstack:
                    dropit = False
            if dropit:
                stmts.append(table.drop())
            else:
                inhstack.insert(0, table)
        # marked sequences with a dependent table are dropped after the
        # tables
        for (sch, tbl) in self.keys():
            table = self[(sch, tbl)]
            if isinstance(table, Sequence) \
                    and hasattr(table, 'dependent_table') \
                    and hasattr(table, 'dropped') and not table.dropped:
                stmts.append(table.drop())

        # last pass to deal with nextval DEFAULTs
        # (only for input tables that are not in this dictionary)
        for (sch, tbl) in intables.keys():
            intable = intables[(sch, tbl)]
            if not isinstance(intable, Table):
                continue
            if (sch, tbl) not in self:
                for col in intable.columns:
                    if hasattr(col, 'default') \
                            and col.default.startswith('nextval'):
                        stmts.append(col.set_sequence_default())

        return stmts
Exemple #40
0
    def diff_map(self, intables):
        """Produce SQL that brings tables, sequences and views up to date

        :param intables: a YAML map defining the new tables/sequences
        :return: list of SQL statements

        The objects held in this dictionary (the definitions fetched
        from the catalogs) are compared against the input map.
        Statements are emitted in several ordered passes: creations
        and renames of input objects first, then per-object diffs,
        then drops of database objects absent from the input, and
        finally the sequence-backed column defaults of new tables.
        """
        stmts = []

        # pass 1: input sequences that are owned by a table
        for (schema, name) in intables:
            newseq = intables[(schema, name)]
            if not (isinstance(newseq, Sequence)
                    and hasattr(newseq, 'owner_table')):
                continue
            if (schema, name) in self:
                continue
            if hasattr(newseq, 'oldname'):
                stmts.append(self._rename(newseq, "sequence"))
            else:
                stmts.append(newseq.create())

        # pass 2: input tables; tables that inherit are deferred so they
        # can be created once none of their parents is still waiting
        pending = []
        for (schema, name) in intables:
            newtbl = intables[(schema, name)]
            if not isinstance(newtbl, Table):
                continue
            if (schema, name) in self:
                continue
            if hasattr(newtbl, 'oldname'):
                stmts.append(self._rename(newtbl, "table"))
            elif hasattr(newtbl, 'inherits'):
                pending.append(newtbl)
            else:
                stmts.append(newtbl.create())
        while pending:
            newtbl = pending.pop()
            blocked = False
            # every parent is examined, even after one is found waiting
            for parent in newtbl.inherits:
                if intables[split_schema_obj(parent)] in pending:
                    blocked = True
            if blocked:
                # a parent is still waiting: requeue at the other end
                pending.insert(0, newtbl)
            else:
                stmts.append(newtbl.create())

        # pass 3: input views
        for (schema, name) in intables:
            newview = intables[(schema, name)]
            if not isinstance(newview, View):
                continue
            if (schema, name) in self:
                continue
            if hasattr(newview, 'oldname'):
                stmts.append(self._rename(newview, "view"))
            else:
                stmts.append(newview.create())

        # pass 4: all input sequences still missing from the database
        for (schema, name) in intables:
            newseq = intables[(schema, name)]
            if not isinstance(newseq, Sequence):
                continue
            if (schema, name) in self:
                continue
            if hasattr(newseq, 'oldname'):
                stmts.append(self._rename(newseq, "sequence"))
            elif hasattr(newseq, 'owner_table'):
                stmts.append(newseq.add_owner())
            else:
                stmts.append(newseq.create())

        # pass 5: existing objects -- diff those present in the input,
        # flag the others (dropped = False) for the drop passes below
        for (schema, name) in self:
            current = self[(schema, name)]
            if (schema, name) in intables:
                stmts.append(current.diff_map(intables[(schema, name)]))
            else:
                current.dropped = False

        # pass 6: for flagged objects, drop foreign keys, triggers and
        # rules, then the object itself if it is a view
        for (schema, name) in self:
            current = self[(schema, name)]
            if isinstance(current, Sequence) \
                    and hasattr(current, 'owner_table'):
                continue
            if not hasattr(current, 'dropped') or current.dropped:
                continue
            for attr in ('foreign_keys', 'triggers', 'rules'):
                if hasattr(current, attr):
                    members = getattr(current, attr)
                    for member in members:
                        stmts.append(members[member].drop())
            if isinstance(current, View):
                stmts.append(current.drop())

        # pass 7: remaining subordinate objects and the flagged objects
        # themselves; objects with descendants are deferred
        pending = []
        for (schema, name) in self:
            current = self[(schema, name)]
            if isinstance(current, Sequence) and (
                    hasattr(current, 'owner_table')
                    or hasattr(current, 'dependent_table')):
                continue
            if isinstance(current, View):
                continue
            if not hasattr(current, 'dropped') or current.dropped:
                continue
            for attr in ('check_constraints', 'unique_constraints',
                         'indexes', 'rules'):
                if hasattr(current, attr):
                    members = getattr(current, attr)
                    for member in members:
                        stmts.append(members[member].drop())
            if hasattr(current, 'primary_key'):
                # TODO there can be more than one referred_by
                if hasattr(current, 'referred_by'):
                    stmts.append(current.referred_by.drop())
                stmts.append(current.primary_key.drop())
            if hasattr(current, 'descendants'):
                pending.append(current)
            else:
                stmts.append(current.drop())
        while pending:
            current = pending.pop()
            blocked = False
            # every descendant is examined, even after one is found waiting
            for child in current.descendants:
                if self[(child.schema, child.name)] in pending:
                    blocked = True
            if blocked:
                # a descendant is still waiting: requeue at the other end
                pending.insert(0, current)
            else:
                stmts.append(current.drop())

        # pass 8: flagged sequences that have a dependent table
        for (schema, name) in self:
            current = self[(schema, name)]
            if not isinstance(current, Sequence):
                continue
            if not hasattr(current, 'dependent_table'):
                continue
            if hasattr(current, 'dropped') and not current.dropped:
                stmts.append(current.drop())

        # pass 9: nextval DEFAULTs on columns of tables not in the database
        for (schema, name) in intables:
            newtbl = intables[(schema, name)]
            if not isinstance(newtbl, Table):
                continue
            if (schema, name) in self:
                continue
            for column in newtbl.columns:
                if hasattr(column, 'default') \
                        and column.default.startswith('nextval'):
                    stmts.append(column.set_sequence_default())

        return stmts