Exemplo n.º 1
0
 def __init__(self, model):
     """Set up a fresh transaction bound to `model'.

     The transaction starts without a connection and with a new, empty
     ObjectCache; `m_retry' is set to True.  The isolation level
     'SERIALIZABLE' and the finalization policy 'ROLLBACK' are then
     requested through the regular setters.
     """
     self.m_retry = True
     self.m_obcache = ObjectCache()
     self.m_connection = None
     self.m_model = model
     # The setters run last, once all attributes above are in place.
     self.set_isolation_level('SERIALIZABLE')
     self.set_finalization_policy('ROLLBACK')
Exemplo n.º 2
0
class Transaction(object):
    """All objects in a Draco model are part of a transaction.

    A transaction is the only way to query, insert and delete
    objects in a database.
    """

    # Policy names accepted by set_finalization_policy(), in upper case.
    finalization_policies = set(('ROLLBACK', 'COMMIT'))

    def __init__(self, model):
        """Set up a fresh transaction bound to `model'.

        The transaction starts without a connection and with a new, empty
        ObjectCache; `m_retry' is set to True.  The isolation level
        'SERIALIZABLE' and the finalization policy 'ROLLBACK' are then
        requested through the regular setters.
        """
        self.m_retry = True
        self.m_obcache = ObjectCache()
        self.m_connection = None
        self.m_model = model
        # The setters run last: set_isolation_level() needs `m_model' and
        # `m_connection' to exist because it connects on demand.
        self.set_isolation_level('SERIALIZABLE')
        self.set_finalization_policy('ROLLBACK')

    def _finalize(self):
        """Finalize the transaction."""
        if self.m_finalization_policy == 'COMMIT':
            self.commit()
        elif self.m_finalization_policy == 'ROLLBACK':
            self.rollback()

    def model(self):
        """Return the model this transaction belongs to."""
        return self.m_model

    def _connect(self):
        """Connect to the database."""
        database = self.m_model.database()
        connection = database.connection()
        self.m_connection = connection

    def cursor(self):
        """Return a new cursor."""
        if self.m_connection is None:
            self._connect()
        return self.m_connection.cursor()

    def set_isolation_level(self, level):
        """Set the transaction isolation level."""
        if self.m_connection is None:
            self._connect()
        database = self.model().database()
        database.set_isolation_level(self.m_connection, level)

    def set_finalization_policy(self, policy):
        """Set the transaction finalization policy to 'policy'.

        The finalization policy can be either 'ROLLBACK' (the default)
        or 'COMMIT'.
        """
        if policy.upper() not in self.finalization_policies:
            m = 'Illegal finalization policy: %s'
            raise ValueError, m % policy
        self.m_finalization_policy = policy.upper()

    def _what_clause(self, typ, alias, attrs, extra_attrs):
        """Return a what clause (SELECT ....) selecting the attributes
        selected by `attrs' of object `typ'.
        """
        if not typ.attributes:
            columns = [ '*' ]
        elif attrs is None:
            # This includes the primary keys.
            columns = [ at.name for at in typ.attributes ]
            columns = [ '%s.%s' % (alias, col) for col in columns ]
        else:
            # Add primary keys explicitly.
            columns = [ at.name for at in typ.primary_key ]
            columns += [ at for at in attrs if at not in columns ]
            columns = [ '%s.%s' % (alias, col) for col in columns ]
        if extra_attrs:
            columns += extra_attrs
        return ','.join(columns)

    def _what_alias(self, desc):
        """Work out the alias under which the what clause names the
        selected table.

        For a plain `Object' subclass this is its `name'.  For a
        (desc, rolename, relationship, kind) tuple the nesting is followed
        through the first element until an `Object' subclass is reached;
        the rolename found at that innermost level is the alias.
        """
        def is_entity(candidate):
            return isinstance(candidate, type) and issubclass(candidate, Object)

        if is_entity(desc):
            return desc.name
        inner, rolename, _rel, _kind = desc
        while not is_entity(inner):
            inner, rolename, _rel, _kind = inner
        return rolename

    def _join_condition(self, rolename, rel):
        """Return a join condition (ON ....), joining the relationship
        table `rel' to the entity table of its `role' role.
        """
        role = rel._get_role(rolename)
        name,ent,card,fk = role
        assert len(fk) == len(ent.primary_key)
        cond = []
        for fkat,pkat in zip(fk, ent.primary_key):
            cond.append('%s.%s = %s.%s' %
                        (rel.name, fkat.name, rolename, pkat.name))
        cond = ' AND '.join(cond)
        return cond

    def _from_clause(self, desc):
        """Build the table expression of a FROM clause from `desc'.

        `desc' is either a single entity (an `Object' subclass), which
        yields its name, or a tuple (desc, rolename, relationship, kind)
        whose first element is again such a description.  For a tuple the
        inner description is joined to the relationship table, which is
        then joined to the entity of every other role of the
        relationship, using `kind' as the join type throughout.  The
        resulting expression is wrapped in parentheses.
        """
        def is_entity(candidate):
            return isinstance(candidate, type) and issubclass(candidate, Object)

        if is_entity(desc):
            return desc.name
        inner, rolename, rel, kind = desc
        left = self._from_clause(inner)
        if is_entity(inner):
            # A bare entity table is aliased to the role it plays.
            left = '%s AS %s' % (left, rolename)
        pieces = ['%s %s JOIN %s ON %s' %
                  (left, kind, rel.name, self._join_condition(rolename, rel))]
        for role in rel.roles:
            if role[0] == rolename:
                # This role is already on the left-hand side.
                continue
            other, entity, _card, _fks = role
            pieces.append(' %s JOIN %s AS %s ON %s' %
                          (kind, entity.name, other,
                           self._join_condition(other, rel)))
        return '(%s)' % ''.join(pieces)

    def _select(self, typ, query, args, lock=False):
        """Do a low-level select query."""
        result = []
        cursor = self.cursor()
        self.retry(cursor.execute, query, args)
        pkcols = [ at.name for at in typ.primary_key ]  # can be empty
        columns = [ de[0] for de in cursor.description ]
        while True:
            row = cursor.fetchone()
            if not row:
                break
            if pkcols:
                # The object cache is used to guarantee that no two
                # python objects point to the same database object.
                try:
                    pk = [ row[columns.index(pk)] for pk in pkcols ]
                except ValueError:
                    raise ModelInternalError, 'Query did not return primary key.'
                obj = self.m_obcache.select(typ, pk)
            else:
                obj = None
            if obj is None:
                obj = typ()
                obj._set_transaction(self)
                obj._select(row, cursor.description, lock)
            if pkcols:
                self.m_obcache.insert(obj)
            result.append(obj)
        return result

    def count(self, typ, where=None, args=None, join=None):
        """Select the number of objects that match a certain query.

        This function mimics the SQL SELECT COUNT(*) construct.

        `typ' is an `Object' subclass or an object name, which is
        resolved through the model.  `where' is an optional SQL condition
        and `args' its query arguments.  `join' is an optional join
        description as accepted by _from_clause(); it defaults to `typ'.

        Returns the first column of the single result row.  Raises
        TypeError when `typ' is neither a name nor an `Object' subclass.
        """
        if isinstance(typ, basestring):
            typ = self.model().object(typ)
        elif not issubclass(typ, Object):
            # Call-style raise: the `raise Class, value' statement form is
            # a syntax error on Python 3.
            raise TypeError('Expecting `Object\' subclass or object name.')
        if join is None:
            join = typ
        query = 'SELECT COUNT(*) '
        query += 'FROM %s ' % self._from_clause(join)
        if where is not None:
            query += ' WHERE %s' % where
        cursor = self.cursor()
        self.retry(cursor.execute, query, args)
        row = cursor.fetchone()
        assert row is not None
        return row[0]

    def select(self, typ, where=None, args=None, order=None, offset=None,
               limit=None, join=None, lock=False, attrs=None, extra_attrs=None):
        """Select objects from the current transaction.

        The keyword interface mimics the SQL SELECT command.

        `typ' is an `Object' subclass or an object name, which is
        resolved through the model.  `where' and `order' are SQL
        fragments, `args' holds the query arguments, and `offset' and
        `limit' are integers.  `join' is an optional join description as
        accepted by _from_clause(); it defaults to `typ'.  `attrs' and
        `extra_attrs' choose the selected columns, see _what_clause().
        With `lock' set, FOR UPDATE is appended to the query.

        Returns the list of matching objects.  Raises TypeError when
        `typ' is neither a name nor an `Object' subclass.
        """
        if isinstance(typ, basestring):
            typ = self.model().object(typ)
        elif not issubclass(typ, Object):
            # Call-style raise: the `raise Class, value' statement form is
            # a syntax error on Python 3.
            raise TypeError('Expecting `Object\' subclass or object name.')
        if join is None:
            join = typ
        alias = self._what_alias(join)
        query = 'SELECT %s ' % self._what_clause(typ, alias, attrs, extra_attrs)
        query += 'FROM %s ' % self._from_clause(join)
        if where is not None:
            query += ' WHERE %s' % where
        if order is not None:
            query += ' ORDER BY %s ' % order
        # LIMIT goes before OFFSET: PostgreSQL takes the two in either
        # order, but SQLite and MySQL only accept LIMIT ... OFFSET.
        if limit is not None:
            query += ' LIMIT %d' % limit
        if offset is not None:
            query += ' OFFSET %d' % offset
        if lock:
            query += ' FOR UPDATE'
        return self._select(typ, query, args, lock)

    def insert(self, obj):
        """Insert an object in the transaction.

        The object must be a new object, i.e. it must not have been
        selected from or inserted into any transaction before.
        """
        obj._set_transaction(self)
        obj._insert()
        self.m_obcache.insert(obj)

    def _merge(self, obj, func, lock):
        """Internal helper function for merge()."""
        pkcond = obj._primary_key_condition()
        pkval = obj._primary_key()
        database = self.model().database()
        result = self.select(type(obj), pkcond, pkval, lock=lock)
        if result:
            assert len(result) == 1
            ret = result[0]
            if func is not None:
                func(ret, obj)
        else:
            try:
                self.insert(obj)
            except DatabaseDBAPIError, err:
                if database.is_primary_key_error(err):
                    raise database.serialization_error()
                raise
            ret = obj
        ret._merge(func)
        return ret