示例#1
0
def update_local_id(old_id, new_id, model, session):
    """
    Updates the tuple matching *old_id* with *new_id*, and updates all
    dependent tuples in other tables as well.

    Dependent tuples are rows of synched models whose tables have foreign
    keys (as reported by ``get_fks``) pointing at *model*'s mapped table
    and holding *old_id*. Each such foreign key column is set to *new_id*.
    The session is flushed before returning, so integrity errors surface
    here.

    Raises ``ValueError`` when *model* is ``None``.
    """
    # Updating either the tuple or the dependent tuples first would
    # cause integrity violations if the transaction is flushed in
    # between. The order doesn't matter.
    if model is None:
        raise ValueError("null model given to update_local_id subtransaction")
    pk_name = get_pk(model)
    # must load fully, don't know yet why
    # NOTE(review): if no row has *old_id*, ``obj`` is None and the
    # setattr below fails with AttributeError.
    obj = query_model(session, model).filter_by(**{pk_name: old_id}).first()
    setattr(obj, pk_name, new_id)

    # Then the dependent ones. Collect (model, fk names) for every synched
    # table with at least one foreign key to this model's table. A plain
    # loop replaces ``ifilter`` with a tuple-unpacking lambda, which is a
    # syntax error on Python 3 (PEP 3113).
    parent_table = class_mapper(model).mapped_table
    mapped_fks = []
    for table in get_related_tables(model):
        related_model = core.synched_models.tables.get(
            table.name, core.null_model).model
        fks = get_fks(table, parent_table)
        if related_model is not None and fks:
            mapped_fks.append((related_model, fks))

    for related_model, fks in mapped_fks:
        for fk in fks:
            for dependent in query_model(session, related_model).\
                    filter_by(**{fk: old_id}):
                setattr(dependent, fk, new_id)
    session.flush()  # raise integrity errors now
示例#2
0
def update_local_id(old_id, new_id, model, session):
    """
    Updates the tuple matching *old_id* with *new_id*, and updates all
    dependent tuples in other tables as well.

    Dependent tuples are rows of synched models whose tables have foreign
    keys (as reported by ``get_fks``) pointing at *model*'s mapped table
    and holding *old_id*. Each such foreign key column is set to *new_id*.
    The session is flushed before returning, so integrity errors surface
    here.

    Raises ``ValueError`` when *model* is ``None``.
    """
    # Updating either the tuple or the dependent tuples first would
    # cause integrity violations if the transaction is flushed in
    # between. The order doesn't matter.
    if model is None:
        raise ValueError("null model given to update_local_id subtransaction")
    pk_name = get_pk(model)
    # must load fully, don't know yet why
    # NOTE(review): if no row has *old_id*, ``obj`` is None and the
    # setattr below fails with AttributeError.
    obj = query_model(session, model).filter_by(**{pk_name: old_id}).first()
    setattr(obj, pk_name, new_id)

    # Then the dependent ones. An explicit loop replaces ``ifilter`` with
    # a tuple-unpacking lambda (a syntax error on Python 3, PEP 3113), and
    # distinct names avoid rebinding the ``model``/``obj`` parameters.
    parent_table = class_mapper(model).mapped_table
    mapped_fks = []
    for table in get_related_tables(model):
        related_model = core.synched_models.tables.get(
            table.name, core.null_model).model
        fks = get_fks(table, parent_table)
        if related_model is not None and fks:
            mapped_fks.append((related_model, fks))

    for related_model, fks in mapped_fks:
        for fk in fks:
            dependents = query_model(session, related_model).\
                filter_by(**{fk: old_id})
            for dependent in dependents:
                setattr(dependent, fk, new_id)
    session.flush()  # raise integrity errors now
示例#3
0
    def fill_for(self,
                 request,
                 swell=False,
                 include_extensions=True,
                 session=None):
        """
        Fills this pull message (response) with versions, operations
        and objects, for the given request (PullRequestMessage).

        The *swell* parameter is deprecated and considered ``True``
        regardless of the value given. This means that parent objects
        will always be added to the message.

        *include_extensions* dictates whether the pull message will
        include model extensions or not.

        Only versions with an id greater than ``request.latest_version_id``
        are added (all versions when it is ``None``). Operations on models
        outside ``pulled_models`` are skipped, and delete operations don't
        pull in an object. Objects are loaded in ``grouper(pks,
        MAX_SQL_VARIABLES)`` batches, presumably to stay under the
        database's bound-parameter limit. Returns ``self``.

        Raises ``ValueError`` if an operation's content type isn't tracked.
        """
        assert isinstance(request, PullRequestMessage), "invalid request"
        versions = session.query(Version)
        if request.latest_version_id is not None:
            versions = versions.\
                filter(Version.version_id > request.latest_version_id)
        required_objects = {}
        for v in versions:
            self.versions.append(v)
            for op in v.operations:
                model = op.tracked_model
                if model is None:
                    # ``model`` is necessarily None here, so report the
                    # operation's content type id instead.
                    raise ValueError("operation linked to content type %s "
                                     "which isn't being tracked"
                                     % op.content_type_id)
                if model not in pulled_models:
                    continue
                self.operations.append(op)
                if op.command != 'd':
                    required_objects.setdefault(model, set()).add(op.row_id)

        required_parents = {}
        for model, pks in required_objects.items():
            pk_column = getattr(model, get_pk(model))
            for batch in grouper(pks, MAX_SQL_VARIABLES):
                for obj in query_model(session, model).filter(
                        pk_column.in_(list(batch))).all():
                    self.add_object(obj, include_extensions=include_extensions)
                    # add parent objects to resolve conflicts in merge
                    for pmodel, ppk in parent_references(
                            obj, synched_models.models.keys()):
                        required_parents.setdefault(pmodel, set()).add(ppk)

        for pmodel, ppks in required_parents.items():
            ppk_column = getattr(pmodel, get_pk(pmodel))
            for batch in grouper(ppks, MAX_SQL_VARIABLES):
                for parent in query_model(session, pmodel).filter(
                        ppk_column.in_(list(batch))).all():
                    self.add_object(parent,
                                    include_extensions=include_extensions)
        return self
示例#4
0
 def tracked(o, **kws):
     """
     Snapshot *o* for delete tracking, then delegate to ``fn(o, **kws)``.

     Nothing is recorded unless ``_has_delete_functions(o)`` is true. When
     ``always`` is set, ``(copy(o), None)`` is appended to ``deleted``.
     Otherwise the row with *o*'s primary key is fetched through
     ``query_model`` (presumably the stored state before this change) and,
     if one exists, ``(copy(row), o)`` is appended. ``always``, ``deleted``,
     ``session`` and ``fn`` come from the enclosing scope (not shown here).
     """
     if _has_delete_functions(o):
         if always:
             snapshot = (copy(o), None)
         else:
             pk_name = get_pk(o)
             stored = query_model(session, type(o)).\
                 filter_by(**{pk_name: getattr(o, pk_name, None)}).first()
             snapshot = None if stored is None else (copy(stored), o)
         if snapshot is not None:
             deleted.append(snapshot)
     return fn(o, **kws)
示例#5
0
    def fill_for(self, request, swell=False, include_extensions=True, session=None):
        """
        Fills this pull message (response) with versions, operations
        and objects, for the given request (PullRequestMessage).

        The *swell* parameter is deprecated and considered ``True``
        regardless of the value given. This means that parent objects
        will always be added to the message.

        *include_extensions* dictates whether the pull message will
        include model extensions or not.

        Only versions with an id greater than ``request.latest_version_id``
        are added (all of them when it is ``None``). Operations on models
        outside ``pulled_models`` are skipped and delete operations don't
        pull in an object. Returns ``self``.

        Raises ``ValueError`` if an operation's content type isn't tracked.
        """
        assert isinstance(request, PullRequestMessage), "invalid request"

        def load_batched(pks_by_model):
            # Yield the objects for each model's primary keys, querying in
            # grouper(pks, MAX_SQL_VARIABLES) batches.
            for batch_model, pks in pks_by_model.items():
                pk_column = getattr(batch_model, get_pk(batch_model))
                for batch in grouper(pks, MAX_SQL_VARIABLES):
                    for loaded in query_model(session, batch_model).filter(pk_column.in_(list(batch))).all():
                        yield loaded

        versions = session.query(Version)
        if request.latest_version_id is not None:
            versions = versions.filter(Version.version_id > request.latest_version_id)
        required_objects = {}
        for v in versions:
            self.versions.append(v)
            for op in v.operations:
                model = op.tracked_model
                if model is None:
                    # ``model`` is necessarily None here; report the content type id instead
                    raise ValueError(
                        "operation linked to content type %s which isn't being tracked" % op.content_type_id
                    )
                if model not in pulled_models:
                    continue
                self.operations.append(op)
                if op.command != "d":
                    required_objects.setdefault(model, set()).add(op.row_id)

        required_parents = {}
        for obj in load_batched(required_objects):
            self.add_object(obj, include_extensions=include_extensions)
            # add parent objects to resolve conflicts in merge
            for pmodel, ppk in parent_references(obj, synched_models.models.keys()):
                required_parents.setdefault(pmodel, set()).add(ppk)

        for parent in load_batched(required_parents):
            self.add_object(parent, include_extensions=include_extensions)
        return self
示例#6
0
def related_local_ids(operation, session):
    """
    For the given operation, return a set of row id values mapped to
    content type ids that correspond to objects that are dependent by
    foreign key on the object being operated upon. The lookups are
    performed in the local database.

    The result is a set of ``(primary_key, content_type_id)`` tuples. It is
    empty when the operation's model isn't tracked.
    """
    parent_model = operation.tracked_model
    if parent_model is None:
        return set()
    parent_table = class_mapper(parent_model).mapped_table

    # (model, fk names) for every synched table referencing the parent
    # table. A plain loop replaces ``ifilter`` with a tuple-unpacking
    # lambda, which is a syntax error on Python 3 (PEP 3113).
    mapped_fks = []
    for table in get_related_tables(parent_model):
        model = synched_models.tables.get(table.name, null_model).model
        fks = get_fks(table, parent_table)
        if model is not None and fks:
            mapped_fks.append((model, fks))

    related = set()
    for model, fks in mapped_fks:
        content_type = synched_models.models.get(model, None)
        if content_type is None:
            continue  # every row would be discarded; skip the query
        references_row = or_(*(getattr(model, fk) == operation.row_id
                               for fk in fks))
        for obj in query_model(session, model, only_pk=True).\
                filter(references_row).all():
            related.add((getattr(obj, get_pk(obj)), content_type.id))
    return related
示例#7
0
def related_remote_ids(operation, container):
    """
    Like *related_local_ids*, but the lookups are performed in
    *container*, that's an instance of
    *dbsync.messages.base.BaseMessage*.

    The result is a set of ``(primary_key, content_type_id)`` tuples. It is
    empty when the operation's model isn't tracked.
    """
    parent_model = operation.tracked_model
    if parent_model is None:
        return set()
    parent_table = class_mapper(parent_model).mapped_table

    # A plain loop replaces ``ifilter`` with a tuple-unpacking lambda,
    # which is a syntax error on Python 3 (PEP 3113).
    mapped_fks = []
    for table in get_related_tables(parent_model):
        model = synched_models.tables.get(table.name, null_model).model
        fks = get_fks(table, parent_table)
        if model is not None and fks:
            mapped_fks.append((model, fks))

    related = set()
    for model, fks in mapped_fks:
        content_type = synched_models.models.get(model, None)
        if content_type is None:
            continue  # every match would be discarded; skip the lookup
        # the predicate is consumed within this iteration, so ``fks`` is current
        matches = container.query(model).filter(
            lambda candidate: any(getattr(candidate, fk) == operation.row_id
                                  for fk in fks))
        for obj in matches:
            related.add((getattr(obj, get_pk(obj)), content_type.id))
    return related
示例#8
0
 def add_object(self, obj, include_extensions=True):
     """
     Add *obj* to the message payload unless it is already there.

     Payload entries are grouped by class name. An object counts as present
     when an ``ObjectType`` with the same class name and primary key value
     is already in its group. With *include_extensions*, every entry in
     ``model_extensions`` for the class adds one property, computed by the
     load function in the entry's second slot. Returns ``self``.
     """
     model_cls = type(obj)
     name = model_cls.__name__
     entries = self.payload.get(name, set())
     if ObjectType(name, getattr(obj, get_pk(model_cls))) not in entries:
         props = properties_dict(obj)
         if include_extensions:
             for field, (_, load, _, _) in \
                     model_extensions.get(name, {}).iteritems():
                 props[field] = load(obj)
         entries.add(
             ObjectType(name, getattr(obj, get_pk(model_cls)), **props))
         self.payload[name] = entries
     return self
示例#9
0
 def add_object(self, obj, include_extensions=True):
     """
     Insert *obj* into ``self.payload`` if no equal entry exists yet.

     Entries are ``ObjectType`` values built from the class name, the
     primary key value and (for new entries) ``properties_dict(obj)``,
     optionally extended with the fields loaded through
     ``model_extensions``. Returns ``self`` so calls can be chained.
     """
     cls = type(obj)
     classname = cls.__name__
     bucket = self.payload.get(classname, set())
     marker = ObjectType(classname, getattr(obj, get_pk(cls)))
     if marker in bucket:
         return self  # already present, nothing to do
     fields = properties_dict(obj)
     extensions = model_extensions.get(classname, {}) \
         if include_extensions else {}
     for field, ext in extensions.iteritems():
         _, loadfn, _, _ = ext
         fields[field] = loadfn(obj)
     bucket.add(ObjectType(classname, getattr(obj, get_pk(cls)), **fields))
     self.payload[classname] = bucket
     return self
示例#10
0
 def add_unversioned_operations(self, session=None, include_extensions=True):
     """
     Add every operation not yet attached to a version, plus the objects
     those operations need, to this message.

     Raises ``ValueError`` if any such operation has a content type id
     outside ``synched_models.ids``. Operations on models outside
     ``pushed_models`` are skipped, and delete operations don't bring in an
     object. When the message has a ``key`` it is re-signed at the end.
     Returns ``self``.
     """
     pending = session.query(Operation).\
         filter(Operation.version_id == None).all()  # noqa: E711 -> IS NULL
     if any(op.content_type_id not in synched_models.ids for op in pending):
         raise ValueError("version includes operation linked "
                          "to model not currently being tracked")
     needed = {}
     for op in pending:
         model = op.tracked_model
         if model in pushed_models:
             self.operations.append(op)
             if op.command != 'd':
                 needed.setdefault(model, set()).add(op.row_id)
     for model, pks in needed.iteritems():
         for batch in grouper(pks, MAX_SQL_VARIABLES):
             pk_attr = getattr(model, get_pk(model))
             rows = query_model(session, model).\
                 filter(pk_attr.in_(list(batch))).all()
             for obj in rows:
                 self.add_object(obj, include_extensions=include_extensions)
     if self.key is not None:
         # overwrite since it's probably an incorrect key
         self._sign()
     return self
示例#11
0
def related_local_ids(operation, session):
    """
    For the given operation, return a set of row id values mapped to
    content type ids that correspond to objects that are dependent by
    foreign key on the object being operated upon. The lookups are
    performed in the local database.

    The result is a set of ``(primary_key, content_type_id)`` tuples. It is
    empty when the operation's model isn't tracked.
    """
    parent_model = operation.tracked_model
    if parent_model is None:
        return set()
    parent_table = class_mapper(parent_model).mapped_table

    # Explicit loop instead of ``ifilter`` with a tuple-unpacking lambda,
    # which Python 3 rejects (PEP 3113).
    mapped_fks = []
    for table in get_related_tables(parent_model):
        model = synched_models.tables.get(table.name, null_model).model
        fks = get_fks(table, parent_table)
        if model is not None and fks:
            mapped_fks.append((model, fks))

    related = set()
    for model, fks in mapped_fks:
        content_type = synched_models.models.get(model, None)
        if content_type is None:
            continue  # every row would be discarded; skip the query
        query = query_model(session, model, only_pk=True).filter(
            or_(*(getattr(model, fk) == operation.row_id for fk in fks)))
        related.update((getattr(obj, get_pk(obj)), content_type.id)
                       for obj in query.all())
    return related
示例#12
0
def related_remote_ids(operation, container):
    """
    Like *related_local_ids*, but the lookups are performed in
    *container*, that's an instance of
    *dbsync.messages.base.BaseMessage*.

    The result is a set of ``(primary_key, content_type_id)`` tuples. It is
    empty when the operation's model isn't tracked.
    """
    parent_model = operation.tracked_model
    if parent_model is None:
        return set()
    parent_table = class_mapper(parent_model).mapped_table

    # Explicit loop instead of ``ifilter`` with a tuple-unpacking lambda,
    # which Python 3 rejects (PEP 3113).
    mapped_fks = []
    for table in get_related_tables(parent_model):
        model = synched_models.tables.get(table.name, null_model).model
        fks = get_fks(table, parent_table)
        if model is not None and fks:
            mapped_fks.append((model, fks))

    related = set()
    for model, fks in mapped_fks:
        content_type = synched_models.models.get(model, None)
        if content_type is None:
            continue  # every match would be discarded; skip the lookup
        # the predicate is consumed within this iteration, so ``fks`` is current
        matches = container.query(model).filter(
            lambda candidate: any(getattr(candidate, fk) == operation.row_id
                                  for fk in fks))
        related.update((getattr(obj, get_pk(obj)), content_type.id)
                       for obj in matches)
    return related
示例#13
0
 def references(self, obj):
     """
     Return whether this operation refers to *obj*.

     True only when ``self.row_id`` equals *obj*'s primary key value
     (``None`` if the attribute is missing) and the operation's tracked
     model is exactly ``type(obj)``. Operations whose model isn't tracked
     never match.
     """
     obj_pk = getattr(obj, get_pk(obj), None)
     if self.row_id != obj_pk:
         return False
     tracked = self.tracked_model
     # an untracked content type can't refer to any object
     return tracked is not None and tracked is type(obj)
示例#14
0
 def _from_raw(self, data):
     """
     Populate ``self.payload`` from the raw message *data*.

     For every ``(class name, serialized objects)`` pair in
     ``data['payload']`` whose name maps to a synched model, the objects
     are decoded with ``decode_dict(model)``. They are stored as a set of
     ``ObjectType`` values built from the class name, the decoded primary
     key and all decoded fields. Unknown class names are ignored.
     """
     # An explicit loop replaces ``ifilter``/``imap`` with tuple-unpacking
     # lambdas, which are a syntax error on Python 3 (PEP 3113).
     for classname, raw_objects in data['payload'].items():
         model = synched_models.model_names.get(classname, null_model).model
         if model is None:
             continue
         decode = decode_dict(model)
         pk_name = get_pk(model)
         self.payload[classname] = set(
             ObjectType(classname, fields[pk_name], **fields)
             for fields in map(decode, raw_objects))
示例#15
0
 def references(self, obj):
     """
     Tell whether *obj* is the target of this operation.

     The primary key of *obj* (``None`` if the attribute is missing) must
     equal ``self.row_id``, and the tracked model must be the exact class
     of *obj*. Returns ``False`` for operations on untracked models.
     """
     key_mismatch = self.row_id != getattr(obj, get_pk(obj), None)
     if key_mismatch:
         return False
     model = self.tracked_model
     return False if model is None else model is type(obj)
示例#16
0
 def _from_raw(self, data):
     """
     Rebuild ``self.payload`` from the raw message *data*.

     ``data['payload']`` is read as class name -> serialized objects.
     Names without a synched model are skipped. For the others, each
     serialized object is decoded with ``decode_dict(model)`` and wrapped
     in an ``ObjectType`` (class name, primary key value, decoded fields).
     """
     # Plain loop instead of ``ifilter``/``imap`` with tuple-unpacking
     # lambdas, which Python 3 rejects (PEP 3113).
     for classname, raw_objects in data['payload'].items():
         model = synched_models.model_names.get(classname, null_model).model
         if model is None:
             continue
         decoder = decode_dict(model)
         pk_name = get_pk(model)
         decoded = map(decoder, raw_objects)
         self.payload[classname] = set(
             ObjectType(classname, fields[pk_name], **fields)
             for fields in decoded)
示例#17
0
def find_unique_conflicts(push_message, session):
    """
    Returns a list of conflicts caused by unique constraints in the
    given push message contrasted against the database. Each conflict
    is a dictionary with the following fields::

        object: the conflicting object in database, bound to the
                session
        columns: tuple of column names in the unique constraint
        new_values: tuple of values that can be used to update the
                    conflicting object.

    Delete operations and operations on untracked models are ignored.
    For joined-table inheritance models, the unique constraints of every
    table in the (possibly nested) join are checked.
    """

    def table_constraints(selectable):
        # A joined-inheritance mapper maps onto a Join, nested once per
        # inheritance level; a Join has no ``constraints`` of its own.
        if isinstance(selectable, Join):
            return (table_constraints(selectable.left) |
                    table_constraints(selectable.right))
        return set(selectable.constraints)

    conflicts = []

    for pk, model in ((op.row_id, op.tracked_model)
                      for op in push_message.operations if op.command != 'd'):
        if model is None:
            continue

        unique_constraints = [
            c for c in table_constraints(class_mapper(model).mapped_table)
            if isinstance(c, UniqueConstraint)
        ]
        if not unique_constraints:
            continue

        # The pushed object is the same for every constraint; look it up once.
        remote_obj = push_message.query(model).\
            filter(attr('__pk__') == pk).first()

        for constraint in unique_constraints:
            unique_columns = tuple(col.name for col in constraint.columns)
            remote_values = tuple(
                getattr(remote_obj, col, None) for col in unique_columns)

            if all(value is None for value in remote_values):
                continue
            local_obj = query_model(session, model).\
                filter_by(**dict(zip(unique_columns, remote_values))).first()
            if local_obj is None:
                continue
            local_pk = getattr(local_obj, get_pk(model))
            if local_pk == pk:
                continue

            push_obj = push_message.query(model).\
                filter(attr('__pk__') == local_pk).first()
            if push_obj is None:
                continue  # push will fail

            conflicts.append({
                'object': local_obj,
                'columns': unique_columns,
                'new_values': tuple(getattr(push_obj, col)
                                    for col in unique_columns),
            })

    return conflicts
示例#18
0
 def verify_constraint(model, columns, values):
     """
     Look up a local object of *model* whose *columns* hold *values*.

     Returns ``(match, pk)``: the first matching object (or ``None``) and
     its primary key value (``None`` when nothing matched). ``session``
     comes from the enclosing scope.
     """
     query = query_model(session, model, only_pk=True)
     # only_pk presumably defers the non-key columns; undefer the ones compared
     query = query.options(*[undefer(column) for column in columns])
     criteria = dict(izip(columns, values))
     match = query.filter_by(**criteria).first()
     return match, getattr(match, get_pk(model), None)
示例#19
0
 def verify_constraint(model, columns, values):
     """
     Check whether some local object already uses *values* in *columns*.

     The result is a pair of the first conflicting object (``None`` if
     there is none) and that object's primary key value (``None`` likewise).
     """
     undeferred = [undefer(column) for column in columns]
     wanted = dict(izip(columns, values))
     match = query_model(session, model, only_pk=True).\
         options(*undeferred).\
         filter_by(**wanted).\
         first()
     pk_name = get_pk(model)
     return match, getattr(match, pk_name, None)
示例#20
0
def create_payload_request_message(obj: SQLClass, name: str = "default"):
    """
    Build the JSON text of a ``request_field_payload`` message for *obj*.

    The message carries the payload field *name*, the object's table
    (``__tablename__``), its primary key column name and value, and the
    class name and module of *obj*. It is serialized with
    ``SyncdbJSONEncoder``.
    """
    pk_column = get_pk(obj)
    obj_class = obj.__class__
    message = {
        "type": "request_field_payload",
        "name": name,
        "table": obj.__tablename__,
        "id_field": pk_column,
        "id": getattr(obj, pk_column),
        "class_name": obj_class.__name__,
        "package_name": obj_class.__module__,
    }

    # imported here rather than at module level (presumably to avoid an
    # import cycle with dbsync.messages)
    from dbsync.messages.codecs import SyncdbJSONEncoder
    return json.dumps(message, cls=SyncdbJSONEncoder)
示例#21
0
 def add_object(self, obj, include_extensions=True):
     """
     Add *obj* to the payload unless an entry with the same class name and
     primary key is already present.

     When *include_extensions* is true, each extension field with a truthy
     ``loadfn`` among ``get_model_extensions_for_obj(obj)`` is loaded and
     stored next to the regular properties. Returns ``self``.
     """
     cls = type(obj)
     classname = cls.__name__
     known = self.payload.get(classname, set())
     if ObjectType(classname, getattr(obj, get_pk(cls))) in known:
         return self
     props = properties_dict(obj)
     if include_extensions:
         extensions: List[Extension] = get_model_extensions_for_obj(obj)
         for extension in extensions:
             for field_name, field_ext in list(extension.fields.items()):
                 loader = field_ext.loadfn
                 if loader:
                     props[field_name] = loader(obj)
     known.add(ObjectType(classname, getattr(obj, get_pk(cls)), **props))
     self.payload[classname] = known
     return self
示例#22
0
def related_local_ids(operation, session):
    """
    For the given operation, return a set of row id values mapped to
    content type ids that correspond to objects that are dependent by
    foreign key on the object being operated upon. The lookups are
    performed in the local database.

    The result is a set of ``(primary_key, content_type_id)`` tuples. It is
    empty when the operation's model isn't tracked. Query failures are
    logged and re-raised.
    """
    parent_model = operation.tracked_model
    if parent_model is None:
        return set()
    parent_table = class_mapper(parent_model).mapped_table

    # (model, fk names) for every synched table referencing the parent table
    mapped_fks = []
    for table in get_related_tables(parent_model):
        model = synched_models.tables.get(entity_name(table), null_model).model
        fks = get_fks(table, parent_table)
        if model is not None and fks:
            mapped_fks.append((model, fks))

    related = set()
    try:
        for model, fks in mapped_fks:
            content_type = synched_models.models.get(model, None)
            if content_type is None:
                continue  # every row would be discarded; skip the query
            # removed the pk_only param, because that fails with joins
            query = query_model(session, model).filter(
                or_(*(getattr(model, fk) == operation.row_id for fk in fks))
            )
            for obj in query.all():
                related.add((getattr(obj, get_pk(obj)), content_type.id))
    except Exception as ex:
        # Log and re-raise the original error. (The former debug loads of
        # whole tables here were unused and could mask this exception.)
        logger.exception("collecting conflicts failed: %s", ex)
        raise
    return related
示例#23
0
def is_synched(obj, session=None):
    """
    Returns whether the given tracked object is synched.

    The object is synched when its latest operation (highest
    ``Operation.order``) is already attached to a version, or when no
    operation is recorded for it.

    Raises a TypeError if the given object is not being tracked
    (i.e. the content type doesn't exist).

    If *session* is given it is used for the lookup. Otherwise a new
    ``Session()`` is opened and closed before returning.
    """
    if type(obj) not in synched_models.models:
        raise TypeError("the given object of class {0} isn't being tracked".\
                            format(obj.__class__.__name__))
    # The original always replaced the caller's session and never closed
    # the one it created.
    # NOTE(review): assumes Session() returns a fresh, non-shared session.
    owns_session = session is None
    if owns_session:
        session = Session()
    try:
        last_op = session.query(Operation).\
            filter(Operation.content_type_id == synched_models.models[type(obj)].id,
                   Operation.row_id == getattr(obj, get_pk(obj))).\
            order_by(Operation.order.desc()).first()
        return last_op is None or last_op.version_id is not None
    finally:
        if owns_session:
            session.close()
示例#24
0
def max_local(sa_class, session):
    """
    Returns the maximum primary key used for the given table.

    The largest stored primary key value is always considered. On SQLite,
    the table's counter in ``sqlite_sequence`` is considered as well
    (presumably because ``max()`` misses ids of deleted rows). Returns
    ``None`` when neither source has a value, e.g. an empty table without
    a ``sqlite_sequence`` entry.
    """
    engine = session.bind
    dialect = engine.name
    table_name = class_mapper(sa_class).mapped_table.name
    # default, strictly incorrect query
    found = session.query(func.max(getattr(sa_class, get_pk(sa_class)))).scalar()
    if dialect != 'sqlite':
        return found
    cursor = engine.execute("SELECT seq FROM sqlite_sequence WHERE name = ?",
                            table_name)
    try:
        row = cursor.fetchone()
    finally:
        cursor.close()
    # A table without AUTOINCREMENT (or never inserted into) has no
    # sqlite_sequence row, and ``found`` is None for an empty table.
    seq = row[0] if row is not None else None
    known = [value for value in (seq, found) if value is not None]
    return max(known) if known else None
示例#25
0
def max_local(sa_class, session):
    """
    Returns the maximum primary key used for the given table.

    The largest stored primary key value is always considered. On SQLite,
    the table's counter in ``sqlite_sequence`` is considered as well
    (presumably because ``max()`` misses ids of deleted rows). Returns
    ``None`` when neither source has a value.
    """
    engine = session.bind
    dialect = engine.name
    table_name = class_mapper(sa_class).mapped_table.name
    # default, strictly incorrect query
    found = session.query(func.max(getattr(sa_class, get_pk(sa_class)))).scalar()
    if dialect != 'sqlite':
        return found
    cursor = engine.execute("SELECT seq FROM sqlite_sequence WHERE name = ?",
                            table_name)
    try:
        row = cursor.fetchone()
    finally:
        cursor.close()
    # ``row`` is None when the table has no sqlite_sequence entry; ``found``
    # is None for an empty table. Either may be missing.
    candidates = [value
                  for value in ((row[0] if row is not None else None), found)
                  if value is not None]
    if not candidates:
        return None
    return max(candidates)
示例#26
0
def find_unique_conflicts(push_message, session):
    """
    Returns a list of conflicts caused by unique constraints in the
    given push message contrasted against the database. Each conflict
    is a dictionary with the following fields::

        object: the conflicting object in database, bound to the
                session
        columns: tuple of column names in the unique constraint
        new_values: tuple of values that can be used to update the
                    conflicting object.

    Delete operations and operations on untracked models are ignored.
    For joined-table inheritance models, the unique constraints of every
    table in the join are checked.
    """
    # Local import: Join is only needed to unwrap joined-inheritance mappings.
    from sqlalchemy.sql.expression import Join

    def table_constraints(selectable):
        # A joined-inheritance mapper maps onto a (possibly nested) Join,
        # which has no ``constraints`` attribute of its own.
        if isinstance(selectable, Join):
            return (table_constraints(selectable.left) |
                    table_constraints(selectable.right))
        return set(selectable.constraints)

    conflicts = []

    for pk, model in ((op.row_id, op.tracked_model)
                      for op in push_message.operations
                      if op.command != 'd'):
        if model is None: continue

        unique_constraints = [
            c for c in table_constraints(class_mapper(model).mapped_table)
            if isinstance(c, UniqueConstraint)]
        if not unique_constraints: continue

        # The pushed object is the same for every constraint; look it up once.
        remote_obj = push_message.query(model).\
            filter(attr('__pk__') == pk).first()

        for constraint in unique_constraints:
            unique_columns = tuple(col.name for col in constraint.columns)
            remote_values = tuple(getattr(remote_obj, col, None)
                                  for col in unique_columns)

            if all(value is None for value in remote_values): continue
            local_obj = query_model(session, model).\
                filter_by(**dict(izip(unique_columns, remote_values))).first()
            if local_obj is None: continue
            local_pk = getattr(local_obj, get_pk(model))
            if local_pk == pk: continue

            push_obj = push_message.query(model).\
                filter(attr('__pk__') == local_pk).first()
            if push_obj is None: continue # push will fail

            conflicts.append(
                {'object': local_obj,
                 'columns': unique_columns,
                 'new_values': tuple(getattr(push_obj, col)
                                     for col in unique_columns)})

    return conflicts
示例#27
0
    def _from_raw(self, data):
        """
        Rebuild ``self.payload`` entries from the raw message *data*.

        ``data['payload']`` is read as a mapping from class name to
        serialized objects. Names that don't resolve to a synched model are
        skipped. For the rest, each serialized object is decoded with
        ``decode_dict(model)`` and wrapped in an ``ObjectType`` carrying the
        class name, its primary key value and all decoded fields.
        """
        def model_for(name):
            return synched_models.model_names.get(name, null_model).model

        resolved = [(name, raw, model_for(name))
                    for name, raw in list(data['payload'].items())]
        for name, raw, model in resolved:
            if model is None:
                continue
            decoder = decode_dict(model)
            self.payload[name] = {
                ObjectType(name, fields[get_pk(model)], **fields)
                for fields in map(decoder, raw)
            }
示例#28
0
    def add_operation(self, op, swell=True, session=None):
        """
        Adds an operation to the message, including the required
        object if it's possible to include it.

        If *swell* is given and set to ``False``, the operation and
        object will be added bare, without parent objects. Otherwise,
        the parent objects will be added to aid in conflict
        resolution.

        A delete operation doesn't include the associated object. If
        *session* is given, the procedure won't instantiate a new
        session.

        This operation might fail (due to database inconsistency), in
        which case the internal state of the message won't be affected
        (i.e. it won't end in an inconsistent state).

        Operations on models outside ``pulled_models`` are ignored.
        Raises ``ValueError`` if the operation's content type isn't
        tracked. Returns ``self``.

        DEPRECATED in favor of `fill_for`
        """
        model = op.tracked_model
        if model is None:
            # ``model`` is necessarily None here, so report the operation's
            # content type id instead.
            raise ValueError("operation linked to content type %s "
                             "which isn't being tracked" % op.content_type_id)
        if model not in pulled_models:
            return self
        obj = None
        if op.command != 'd':
            obj = query_model(session, model).\
                filter_by(**{get_pk(model): op.row_id}).first()
        # Run the object and parent lookups before mutating the message, so
        # a failure there leaves it unchanged, as promised above.
        parents = []
        if obj is not None and swell:
            # parent objects help resolve possible conflicts in merge
            parents = list(parent_objects(obj, synched_models.models.keys(),
                                          session))
        self.operations.append(op)
        # if the object isn't there it's because the operation is old,
        # and should be able to be compressed out when performing the
        # conflict resolution phase
        if obj is not None:
            self.add_object(obj)
            for parent in parents:
                self.add_object(parent)
        return self
示例#29
0
    def add_operation(self, op, swell=True, session=None):
        """
        Adds an operation to the message, including the required
        object if it's possible to include it.

        If *swell* is given and set to ``False``, the operation and
        object will be added bare, without parent objects. Otherwise,
        the parent objects will be added to aid in conflict
        resolution.

        A delete operation doesn't include the associated object. If
        *session* is given, the procedure won't instantiate a new
        session.

        This operation might fail (due to database inconsistency), in
        which case the internal state of the message won't be affected
        (i.e. it won't end in an inconsistent state).

        Operations on models outside ``pulled_models`` are ignored.
        Raises ``ValueError`` if the operation's content type isn't
        tracked. Returns ``self``.

        DEPRECATED in favor of `fill_for`
        """
        model = op.tracked_model
        if model is None:
            # ``model`` is necessarily None here; report the content type id instead
            raise ValueError("operation linked to content type %s which isn't being tracked" % op.content_type_id)
        if model not in pulled_models:
            return self
        obj = query_model(session, model).filter_by(**{get_pk(model): op.row_id}).first() if op.command != "d" else None
        # Look up parents before mutating the message so a failure leaves it untouched.
        if obj is not None and swell:
            # add parent objects to resolve possible conflicts in merge
            parents = list(parent_objects(obj, synched_models.models.keys(), session))
        else:
            parents = []
        self.operations.append(op)
        # if the object isn't there it's because the operation is old,
        # and should be able to be compressed out when performing the
        # conflict resolution phase
        if obj is not None:
            self.add_object(obj)
            for parent in parents:
                self.add_object(parent)
        return self
示例#30
0
def compress(session=None):
    """
    Compresses unversioned operations in the database.

    For each row in the operations table, this deletes unnecessary
    operations that would otherwise bloat the message.

    This procedure is called internally before the 'push' request
    happens, and before the local 'merge' happens.

    Returns the remaining unversioned operations, ordered by ascending
    ``order``.
    """
    unversioned = session.query(Operation).\
        filter(Operation.version_id == None).order_by(Operation.order.desc())
    # One list of operations per tracked row. The query is ordered by
    # descending ``order``, so (assuming group_by keeps encounter order)
    # seq[0] is the newest operation and seq[-1] the oldest.
    seqs = group_by(lambda op: (op.row_id, op.content_type_id), unversioned)

    # Check errors on sequences
    for seq in seqs.itervalues():
        _assert_operation_sequence(seq, session)

    # Drop operations made redundant by others on the same row. Explicit
    # loops are used instead of ``map(session.delete, ...)``, which only
    # performs the deletes where ``map`` is eager (Python 2).
    for seq in seqs.itervalues():
        if len(seq) <= 1:
            continue
        redundant = []
        if seq[-1].command == 'i':
            if all(op.command == 'u' for op in seq[:-1]):
                # updates are superfluous
                redundant = seq[:-1]
            elif seq[0].command == 'd':
                # it's as if the object never existed
                redundant = seq
        elif seq[-1].command == 'u':
            if all(op.command == 'u' for op in seq[:-1]):
                # leave a single update
                redundant = seq[1:]
            elif seq[0].command == 'd':
                # leave the delete statement
                redundant = seq[1:]
        for op in redundant:
            session.delete(op)
    session.flush()

    # repair inconsistencies
    for operation in session.query(Operation).\
            filter(Operation.version_id == None).\
            order_by(Operation.order.desc()).all():
        # flush pending deletes from the previous iteration so the
        # queries below no longer see them
        session.flush()
        model = operation.tracked_model
        if not model:
            logger.error(
                "operation linked to content type "
                "not tracked: %s" % operation.content_type_id)
            continue
        if operation.command in ('i', 'u'):
            # inserts/updates whose row no longer exists are meaningless
            if query_model(session, model, only_pk=True).\
                    filter_by(**{get_pk(model): operation.row_id}).count() == 0:
                logger.warning(
                    "deleting operation %s for model %s "
                    "for absence of backing object" % (operation, model.__name__))
                session.delete(operation)
                continue
        if operation.command == 'u':
            subsequent = session.query(Operation).\
                filter(Operation.content_type_id == operation.content_type_id,
                       Operation.version_id == None,
                       Operation.row_id == operation.row_id,
                       Operation.order > operation.order).all()
            if any(op.command == 'i' for op in subsequent) and \
                    all(op.command != 'd' for op in subsequent):
                logger.warning(
                    "deleting update operation %s for model %s "
                    "for preceding an insert operation" %\
                        (operation, model.__name__))
                session.delete(operation)
                continue
        # another unversioned operation with the same command on the
        # same row makes this one redundant
        if session.query(Operation).\
                filter(Operation.content_type_id == operation.content_type_id,
                       Operation.command == operation.command,
                       Operation.version_id == None,
                       Operation.row_id == operation.row_id,
                       Operation.order != operation.order).count() > 0:
            logger.warning(
                "deleting operation %s for model %s "
                "for being redundant after compression" %\
                    (operation, model.__name__))
            session.delete(operation)
            continue
    return session.query(Operation).\
        filter(Operation.version_id == None).\
        order_by(Operation.order.asc()).all()
示例#31
0
def handle_push(data, session=None):
    """
    Handle the push request and return a dictionary object to be sent
    back to the node.

    If the push is rejected, this procedure will raise a
    dbsync.server.handlers.PushRejected exception. If the node's
    latest version is missing or older than the server's, PullSuggested
    is raised instead so the node can pull first.

    *data* must be a dictionary-like object, usually the product of
    parsing a JSON string.

    Returns ``{'new_version_id': <id of the version created>}``.
    """
    try:
        message = PushMessage(data)
    except KeyError:
        raise PushRejected("request object isn't a valid PushMessage", data)
    latest_version_id = core.get_latest_version_id(session=session)
    if latest_version_id != message.latest_version_id:
        exc = "version identifier isn't the latest one; "\
            "given: %s" % message.latest_version_id
        if latest_version_id is None:
            raise PushRejected(exc)
        if message.latest_version_id is None:
            raise PullSuggested(exc)
        if message.latest_version_id < latest_version_id:
            raise PullSuggested(exc)
        raise PushRejected(exc)
    if not message.operations:
        raise PushRejected("message doesn't contain operations")
    if not message.islegit(session):
        raise PushRejected("message isn't properly signed")

    for listener in before_push:
        listener(session, message)

    # I) detect unique constraint conflicts and resolve them if possible
    unique_conflicts = find_unique_conflicts(message, session)
    conflicting_objects = set()
    for uc in unique_conflicts:
        obj = uc['object']
        conflicting_objects.add(obj)
        for key, value in izip(uc['columns'], uc['new_values']):
            setattr(obj, key, value)
    for obj in conflicting_objects:
        make_transient(obj) # remove from session
    for model in set(type(obj) for obj in conflicting_objects):
        pk_name = get_pk(model)
        pks = [getattr(obj, pk_name)
               for obj in conflicting_objects
               if type(obj) is model]
        session.query(model).filter(getattr(model, pk_name).in_(pks)).\
            delete(synchronize_session=False) # remove from the database
    session.add_all(conflicting_objects) # reinsert
    session.flush()

    # II) perform the operations
    # Materialized as a list because it is iterated twice (here and in
    # step IV); a lazy ``filter`` iterator would be exhausted after this
    # loop and step IV would silently insert nothing.
    operations = [op for op in message.operations
                  if op.tracked_model is not None]
    try:
        for op in operations:
            op.perform(message, session, message.node_id)
    except OperationError as e:
        logger.exception(u"Couldn't perform operation in push from node %s.",
                         message.node_id)
        raise PushRejected("at least one operation couldn't be performed",
                           *e.args)

    # III) insert a new version
    version = Version(created=datetime.datetime.now(), node_id=message.node_id)
    session.add(version)

    # IV) insert the operations, discarding the 'order' column
    for op in sorted(operations, key=attr('order')):
        new_op = Operation()
        for k in ifilter(lambda k: k != 'order', properties_dict(op)):
            setattr(new_op, k, getattr(op, k))
        session.add(new_op)
        new_op.version = version
        session.flush()

    for listener in after_push:
        listener(session, message)

    # return the new version id back to the node
    return {'new_version_id': version.version_id}
示例#32
0
    async def perform_async(
        operation: "Operation",
        container: "BaseMessage",
        session: Session,
        node_id=None,
        websocket: Optional[WebSocketCommonProtocol] = None
    ) -> "Tuple[Optional[SQLClass], Optional[SQLClass]]":
        """
        Performs *operation*, looking for required data in
        *container*, and using *session* to perform it.

        *container* is an instance of
        dbsync.messages.base.BaseMessage.

        *node_id* is the node responsible for the operation, if known
        (else ``None``).

        *websocket*, if given, is forwarded to
        ``request_payloads_for_extension`` for insert and update
        operations.

        Returns a pair ``(new_obj, old_obj)``:

        - insert: ``(pulled object, None)`` once it was added,
        - update: ``(pulled object, shallow copy of the previous local
          object, or None if there was none)``,
        - delete: ``(deleted local object, None)``,
        - ``(None, None)`` when nothing was done (the operation was
          skipped via ``SkipOperation``, an identical object already
          existed, or the object to delete was already missing).

        If at any moment this operation fails for predictable causes,
        it will raise an *OperationError*.
        """
        from dbsync.core import mode
        model: DeclarativeMeta = operation.tracked_model
        res: Tuple[Optional[SQLClass], Optional[SQLClass]] = (None, None)
        if model is None:
            raise OperationError("no content type for this operation",
                                 operation)

        if operation.command == 'i':
            # check if the given object is already in the database
            obj = query_model(session, model). \
                filter(getattr(model, get_pk(model)) == operation.row_id).first()

            # retrieve the incoming object from the container
            pull_obj = container.query(model). \
                filter(attr('__pk__') == operation.row_id).first()
            if pull_obj is None:
                raise OperationError(
                    f"no object backing the operation in container on {mode}",
                    operation)
            if obj is None:
                # Log ``row_id`` rather than ``pull_obj.id``: the primary
                # key attribute isn't necessarily named ``id``, and by the
                # query above its value equals ``row_id``.
                logger.info(
                    f"insert: calling request_payloads_for_extension for: {operation.row_id}"
                )
                try:
                    operation.call_before_operation_fn(session, pull_obj)
                    await request_payloads_for_extension(
                        operation, pull_obj, websocket, session)
                    session.add(pull_obj)
                    res = pull_obj, None
                except SkipOperation:
                    logger.info(f"operation {operation} skipped")
            else:
                # Don't raise an exception if the incoming object is
                # exactly the same as the local one.
                if properties_dict(obj) == properties_dict(pull_obj):
                    logger.warning("insert attempted when an identical object "
                                   "already existed in local database: "
                                   "model {0} pk {1}".format(
                                       model.__name__, operation.row_id))
                else:
                    raise OperationError(
                        "insert attempted when the object already existed: "
                        "model {0} pk {1}".format(model.__name__,
                                                  operation.row_id))

        elif operation.command == 'u':
            obj = query_model(session, model). \
                filter(getattr(model, get_pk(model)) == operation.row_id).one_or_none()
            if obj is not None:
                logger.info(
                    f"update: calling request_payloads_for_extension for: {operation.row_id}"
                )
            else:
                # Nothing should be deleted without using dbsync, but the
                # object can legitimately be missing when tracking of it
                # was suppressed and later re-activated during a 'u'
                # operation. Instead of raising, the record is created
                # again by the merge below.
                logger.warning(
                    "The referenced object doesn't exist in database. "
                    "Node %s. Operation %s", node_id, operation)

            # get new object from the PushMessage
            pull_obj = container.query(model). \
                filter(attr('__pk__') == operation.row_id).first()
            if pull_obj is None:
                raise OperationError(
                    "no object backing the operation in container", operation)

            try:
                operation.call_before_operation_fn(session, pull_obj, obj)
                await request_payloads_for_extension(operation, pull_obj,
                                                     websocket, session)
                # snapshot the local object before session.merge copies
                # the incoming state onto it
                old_obj = copy(obj) if obj is not None else None
                session.merge(pull_obj)
                res = pull_obj, old_obj
            except SkipOperation:
                logger.info(f"operation {operation} skipped")

        elif operation.command == 'd':
            try:
                obj = query_model(session, model, only_pk=True). \
                    filter(getattr(model, get_pk(model)) == operation.row_id).first()
            except NoSuchColumnError:
                # for joins only_pk doesnt seem to work
                obj = query_model(session, model, only_pk=False). \
                    filter(getattr(model, get_pk(model)) == operation.row_id).first()

            if obj is None:
                # The object is already deleted in the server
                # The final state in node and server are the same. But
                # it's an error because nothing should be deleted
                # without using dbsync
                logger.warning(
                    "The referenced object doesn't exist in database. "
                    "Node %s. Operation %s", node_id, operation)
            else:
                try:
                    operation.call_before_operation_fn(session, obj)
                    session.delete(obj)
                    res = obj, None
                except SkipOperation:
                    logger.info(f"operation {operation} skipped")

        else:
            raise OperationError(
                "the operation doesn't specify a valid command ('i', 'u', 'd')",
                operation)

        return res
示例#33
0
def max_remote(model, container):
    """
    Return the highest primary-key value found among the *model*
    objects held by *container*.

    ``max`` raises ValueError when *container* holds no such objects.
    """
    remote_pks = (getattr(remote_obj, get_pk(remote_obj))
                  for remote_obj in container.query(model))
    return max(remote_pks)
示例#34
0
def compress(session=None) -> List[Operation]:
    """
    Compresses unversioned operations in the database.

    For each row in the operations table, this deletes unnecessary
    operations that would otherwise bloat the message.

    This procedure is called internally before the 'push' request
    happens, and before the local 'merge' happens.

    The changes are committed on *session* before returning. Returns
    the remaining unversioned operations, ordered by ascending
    ``order``.
    """
    unversioned: Query = session.query(Operation).\
        filter(Operation.version_id == None).order_by(Operation.order.desc())
    # One list of operations per tracked row. The query is ordered by
    # descending ``order``, so (assuming group_by keeps encounter order)
    # seq[0] is the newest operation and seq[-1] the oldest.
    seqs = group_by(lambda op: (op.row_id, op.content_type_id), unversioned)

    # Check errors on sequences
    for seq in seqs.values():
        _assert_operation_sequence(seq, session)

    # Drop operations made redundant by others on the same row.
    for seq in seqs.values():
        if len(seq) <= 1:
            continue
        redundant = []
        if seq[-1].command == 'i':
            if all(op.command == 'u' for op in seq[:-1]):
                # updates are superfluous
                redundant = seq[:-1]
            elif seq[0].command == 'd':
                # it's as if the object never existed
                redundant = seq
        elif seq[-1].command == 'u':
            if all(op.command == 'u' for op in seq[:-1]):
                # leave a single update
                redundant = seq[1:]
            elif seq[0].command == 'd':
                # leave the delete statement
                redundant = seq[1:]
        for op in redundant:
            session.delete(op)
    session.flush()

    # repair inconsistencies
    for operation in session.query(Operation).\
            filter(Operation.version_id == None).\
            order_by(Operation.order.desc()).all():
        # flush pending deletes from the previous iteration so the
        # queries below no longer see them
        session.flush()
        model = operation.tracked_model
        if not model:
            logger.error("operation linked to content type "
                         "not tracked: %s" % operation.content_type_id)
            continue
        if operation.command in ('i', 'u'):
            # inserts/updates whose row no longer exists are meaningless
            if query_model(session, model, only_pk=True).\
                    filter_by(**{get_pk(model): operation.row_id}).count() == 0:
                logger.warning("deleting operation %s for model %s "
                               "for absence of backing object" %
                               (operation, model.__name__))
                session.delete(operation)
                continue
        if operation.command == 'u':
            subsequent = session.query(Operation).\
                filter(Operation.content_type_id == operation.content_type_id,
                       Operation.version_id == None,
                       Operation.row_id == operation.row_id,
                       Operation.order > operation.order).all()
            if any(op.command == 'i' for op in subsequent) and \
                    all(op.command != 'd' for op in subsequent):
                logger.warning(
                    "deleting update operation %s for model %s "
                    "for preceding an insert operation" %\
                        (operation, model.__name__))
                session.delete(operation)
                continue
        # another unversioned operation with the same command on the
        # same row makes this one redundant
        if session.query(Operation).\
                filter(Operation.content_type_id == operation.content_type_id,
                       Operation.command == operation.command,
                       Operation.version_id == None,
                       Operation.row_id == operation.row_id,
                       Operation.order != operation.order).count() > 0:
            logger.warning(
                "deleting operation %s for model %s "
                "for being redundant after compression" %\
                    (operation, model.__name__))
            session.delete(operation)
            continue

    session.commit()
    return session.query(Operation).\
        filter(Operation.version_id == None).\
        order_by(Operation.order.asc()).all()
示例#35
0
def merge(pull_message, session=None):
    """
    Merges a message from the server with the local database.

    *pull_message* is an instance of dbsync.messages.pull.PullMessage.
    *session* is used for every query and write.

    The merge runs in four phases:

    I) unique constraint conflicts reported by ``find_unique_conflicts``
       are resolved by writing the suggested column values into the
       local objects, deleting them from the database and reinserting
       them. If unresolvable errors are reported instead,
       ``UniqueConstraintError`` is raised before any pulled operation
       is performed.
    II) direct, dependency, reversed-dependency and insert conflicts
        between the pulled operations and the local unversioned ones
        are detected.
    III) each compressed pulled operation is performed unless a
         conflict forbids it; resolving conflicts may rewrite or delete
         local unversioned operations.
    IV) the versions carried by *pull_message* are added to the session.

    Raises TypeError if *pull_message* isn't a PullMessage.
    """
    if not isinstance(pull_message, PullMessage):
        raise TypeError("need an instance of dbsync.messages.pull.PullMessage "
                        "to perform the local merge operation")
    # content type ids of the synchronized models
    valid_cts = set(core.synched_models.ids)

    # local operations not yet pushed, compressed beforehand
    unversioned_ops = compress(session=session)
    # only consider pulled operations on synchronized content types
    pull_ops = filter(
        attr('content_type_id').in_(valid_cts), pull_message.operations)
    pull_ops = compressed_operations(pull_ops)

    # I) first phase: resolve unique constraint conflicts if
    # possible. Abort early if a human error is detected
    unique_conflicts, unique_errors = find_unique_conflicts(
        pull_ops, unversioned_ops, pull_message, session)

    if unique_errors:
        raise UniqueConstraintError(unique_errors)

    conflicting_objects = set()
    for uc in unique_conflicts:
        obj = uc['object']
        conflicting_objects.add(obj)
        for key, value in izip(uc['columns'], uc['new_values']):
            setattr(obj, key, value)
    # Resolve potential cyclical conflicts by deleting and reinserting
    for obj in conflicting_objects:
        make_transient(obj)  # remove from session
    for model in set(type(obj) for obj in conflicting_objects):
        pk_name = get_pk(model)
        pks = [
            getattr(obj, pk_name) for obj in conflicting_objects
            if type(obj) is model
        ]
        session.query(model).filter(getattr(model, pk_name).in_(pks)).\
            delete(synchronize_session=False) # remove from the database
    session.add_all(conflicting_objects)  # reinsert them
    session.flush()

    # II) second phase: detect conflicts between pulled operations and
    # unversioned ones
    direct_conflicts = find_direct_conflicts(pull_ops, unversioned_ops)

    # in which the delete operation is registered on the pull message
    dependency_conflicts = find_dependency_conflicts(pull_ops, unversioned_ops,
                                                     session)

    # in which the delete operation was performed locally
    reversed_dependency_conflicts = find_reversed_dependency_conflicts(
        pull_ops, unversioned_ops, pull_message)

    insert_conflicts = find_insert_conflicts(pull_ops, unversioned_ops)

    # III) third phase: perform pull operations, when allowed and
    # while resolving conflicts
    def extract(op, conflicts):
        # local operations paired with *op* in a list of
        # (remote, local) conflict tuples
        return [local for remote, local in conflicts if remote is op]

    def purgelocal(local):
        # Delete a local operation and drop every conflict mentioning it
        # (mfilter presumably filters each list in place), so later
        # iterations don't try to resolve it again.
        session.delete(local)
        exclude = lambda tup: tup[1] is not local
        mfilter(exclude, direct_conflicts)
        mfilter(exclude, dependency_conflicts)
        mfilter(exclude, reversed_dependency_conflicts)
        mfilter(exclude, insert_conflicts)
        unversioned_ops.remove(local)

    for pull_op in pull_ops:
        # flag to control whether the remote operation is free of obstacles
        can_perform = True
        # flag to detect the early exclusion of a remote operation
        reverted = False
        # the class of the operation
        class_ = pull_op.tracked_model

        direct = extract(pull_op, direct_conflicts)
        if direct:
            if pull_op.command == 'd':
                can_perform = False
            for local in direct:
                pair = (pull_op.command, local.command)
                if pair == ('u', 'u'):
                    can_perform = False  # favor local changes over remote ones
                elif pair == ('u', 'd'):
                    pull_op.command = 'i'  # negate the local delete
                    purgelocal(local)
                elif pair == ('d', 'u'):
                    local.command = 'i'  # negate the remote delete
                    session.flush()
                    reverted = True
                else:  # ('d', 'd')
                    purgelocal(local)

        dependency = extract(pull_op, dependency_conflicts)
        if dependency and not reverted:
            # the remote delete is not performed; an 'i' operation is
            # recorded as the oldest unversioned one instead
            can_perform = False
            order = min(op.order for op in unversioned_ops)
            # first move all operations further in order, to make way
            # for the new one
            for op in unversioned_ops:
                op.order = op.order + 1
            session.flush()
            # then create operation to reflect the reinsertion and
            # maintain a correct operation history
            session.add(
                Operation(row_id=pull_op.row_id,
                          content_type_id=pull_op.content_type_id,
                          command='i',
                          order=order))

        reversed_dependency = extract(pull_op, reversed_dependency_conflicts)
        for local in reversed_dependency:
            # reinsert record
            local.command = 'i'
            local.perform(pull_message, session)
            # delete trace of deletion
            purgelocal(local)

        insert = extract(pull_op, insert_conflicts)
        for local in insert:
            # both sides inserted the same row_id: move the local object
            # to a primary key above every remote and local one
            session.flush()
            next_id = max(max_remote(class_, pull_message),
                          max_local(class_, session)) + 1
            update_local_id(local.row_id, next_id, class_, session)
            local.row_id = next_id
        if can_perform:
            pull_op.perform(pull_message, session)

            session.flush()

    # IV) fourth phase: insert versions from the pull_message
    for pull_version in pull_message.versions:
        session.add(pull_version)
示例#36
0
def merge(pull_message, session=None):
    """
    Merges a message from the server with the local database.

    *pull_message* is an instance of dbsync.messages.pull.PullMessage.
    *session* is used for every query and write.

    The merge runs in four phases:

    I) unique constraint conflicts reported by ``find_unique_conflicts``
       are resolved by writing the suggested column values into the
       local objects, deleting them from the database and reinserting
       them. If unresolvable errors are reported instead,
       ``UniqueConstraintError`` is raised before any pulled operation
       is performed.
    II) direct, dependency, reversed-dependency and insert conflicts
        between the pulled operations and the local unversioned ones
        are detected.
    III) each compressed pulled operation is performed unless a
         conflict forbids it; resolving conflicts may rewrite or delete
         local unversioned operations.
    IV) the versions carried by *pull_message* are added to the session.

    Raises TypeError if *pull_message* isn't a PullMessage.
    """
    if not isinstance(pull_message, PullMessage):
        raise TypeError("need an instance of dbsync.messages.pull.PullMessage "
                        "to perform the local merge operation")
    # content type ids of the synchronized models
    valid_cts = set(core.synched_models.ids)

    # local operations not yet pushed, compressed beforehand
    unversioned_ops = compress(session=session)
    # only consider pulled operations on synchronized content types
    pull_ops = filter(attr('content_type_id').in_(valid_cts),
                      pull_message.operations)
    pull_ops = compressed_operations(pull_ops)

    # I) first phase: resolve unique constraint conflicts if
    # possible. Abort early if a human error is detected
    unique_conflicts, unique_errors = find_unique_conflicts(
        pull_ops, unversioned_ops, pull_message, session)

    if unique_errors:
        raise UniqueConstraintError(unique_errors)

    conflicting_objects = set()
    for uc in unique_conflicts:
        obj = uc['object']
        conflicting_objects.add(obj)
        for key, value in izip(uc['columns'], uc['new_values']):
            setattr(obj, key, value)
    # Resolve potential cyclical conflicts by deleting and reinserting
    for obj in conflicting_objects:
        make_transient(obj) # remove from session
    for model in set(type(obj) for obj in conflicting_objects):
        pk_name = get_pk(model)
        pks = [getattr(obj, pk_name)
               for obj in conflicting_objects
               if type(obj) is model]
        session.query(model).filter(getattr(model, pk_name).in_(pks)).\
            delete(synchronize_session=False) # remove from the database
    session.add_all(conflicting_objects) # reinsert them
    session.flush()

    # II) second phase: detect conflicts between pulled operations and
    # unversioned ones
    direct_conflicts = find_direct_conflicts(pull_ops, unversioned_ops)

    # in which the delete operation is registered on the pull message
    dependency_conflicts = find_dependency_conflicts(
        pull_ops, unversioned_ops, session)

    # in which the delete operation was performed locally
    reversed_dependency_conflicts = find_reversed_dependency_conflicts(
        pull_ops, unversioned_ops, pull_message)

    insert_conflicts = find_insert_conflicts(pull_ops, unversioned_ops)

    # III) third phase: perform pull operations, when allowed and
    # while resolving conflicts
    def extract(op, conflicts):
        # local operations paired with *op* in a list of
        # (remote, local) conflict tuples
        return [local for remote, local in conflicts if remote is op]

    def purgelocal(local):
        # Delete a local operation and drop every conflict mentioning it
        # (mfilter presumably filters each list in place), so later
        # iterations don't try to resolve it again.
        session.delete(local)
        exclude = lambda tup: tup[1] is not local
        mfilter(exclude, direct_conflicts)
        mfilter(exclude, dependency_conflicts)
        mfilter(exclude, reversed_dependency_conflicts)
        mfilter(exclude, insert_conflicts)
        unversioned_ops.remove(local)

    for pull_op in pull_ops:
        # flag to control whether the remote operation is free of obstacles
        can_perform = True
        # flag to detect the early exclusion of a remote operation
        reverted = False
        # the class of the operation
        class_ = pull_op.tracked_model

        direct = extract(pull_op, direct_conflicts)
        if direct:
            if pull_op.command == 'd':
                can_perform = False
            for local in direct:
                pair = (pull_op.command, local.command)
                if pair == ('u', 'u'):
                    can_perform = False # favor local changes over remote ones
                elif pair == ('u', 'd'):
                    pull_op.command = 'i' # negate the local delete
                    purgelocal(local)
                elif pair == ('d', 'u'):
                    local.command = 'i' # negate the remote delete
                    session.flush()
                    reverted = True
                else: # ('d', 'd')
                    purgelocal(local)

        dependency = extract(pull_op, dependency_conflicts)
        if dependency and not reverted:
            # the remote delete is not performed; an 'i' operation is
            # recorded as the oldest unversioned one instead
            can_perform = False
            order = min(op.order for op in unversioned_ops)
            # first move all operations further in order, to make way
            # for the new one
            for op in unversioned_ops:
                op.order = op.order + 1
            session.flush()
            # then create operation to reflect the reinsertion and
            # maintain a correct operation history
            session.add(Operation(row_id=pull_op.row_id,
                                  content_type_id=pull_op.content_type_id,
                                  command='i',
                                  order=order))

        reversed_dependency = extract(pull_op, reversed_dependency_conflicts)
        for local in reversed_dependency:
            # reinsert record
            local.command = 'i'
            local.perform(pull_message, session)
            # delete trace of deletion
            purgelocal(local)

        insert = extract(pull_op, insert_conflicts)
        for local in insert:
            # both sides inserted the same row_id: move the local object
            # to a primary key above every remote and local one
            session.flush()
            next_id = max(max_remote(class_, pull_message),
                          max_local(class_, session)) + 1
            update_local_id(local.row_id, next_id, class_, session)
            local.row_id = next_id
        if can_perform:
            pull_op.perform(pull_message, session)

            session.flush()

    # IV) fourth phase: insert versions from the pull_message
    for pull_version in pull_message.versions:
        session.add(pull_version)
示例#37
0
def max_remote(model, container):
    """
    Returns the maximum value for the primary key of the given model
    in the container.

    *container* must provide ``query(model)`` yielding instances of
    *model* (the merge code passes the pull message here).

    Raises ``ValueError`` if the container holds no object of *model*,
    since there is no maximum to report.
    """
    # The primary key attribute name depends only on the model, so resolve
    # it once instead of once per object.
    pk_name = get_pk(model)
    pks = [getattr(obj, pk_name) for obj in container.query(model)]
    if not pks:
        # A bare max() would fail with "max() arg is an empty sequence";
        # name the model so the failure can be diagnosed.
        raise ValueError(
            "no objects of model %r found in container" % (model,))
    return max(pks)
示例#38
0
    def perform(operation, container, session, node_id=None):
        """
        Performs *operation*, looking for required data in
        *container*, and using *session* to perform it.

        *container* is an instance of
        dbsync.messages.base.BaseMessage.

        *node_id* is the node responsible for the operation, if known
        (else ``None``). It is only used in log messages.

        Behaviour per command:

        - ``'i'``: the object backing the operation in *container* is
          added to *session*. If a local object with the same primary key
          already exists, a warning is logged when both have equal
          properties; otherwise an *OperationError* is raised.
        - ``'u'``: the container object is merged into *session*. A
          missing local object only produces a warning (the merge
          recreates it).
        - ``'d'``: the local object is deleted from *session*; if it is
          already missing, a warning is logged instead.

        If at any moment this operation fails for predictable causes,
        it will raise an *OperationError*.
        """
        model = operation.tracked_model
        if model is None:
            raise OperationError("no content type for this operation", operation)

        def pull_object():
            # Fetch the incoming copy of the row from the message; inserts
            # and updates both need it, so its absence is an error.
            pull_obj = container.query(model).\
                filter(attr('__pk__') == operation.row_id).first()
            if pull_obj is None:
                raise OperationError(
                    "no object backing the operation in container", operation)
            return pull_obj

        if operation.command == 'i':
            obj = query_model(session, model).\
                filter(getattr(model, get_pk(model)) == operation.row_id).first()
            pull_obj = pull_object()
            if obj is None:
                session.add(pull_obj)
            elif properties_dict(obj) == properties_dict(pull_obj):
                # Don't raise an exception if the incoming object is
                # exactly the same as the local one.
                logger.warning(u"insert attempted when an identical object "
                               u"already existed in local database: "
                               u"model {0} pk {1}".format(model.__name__,
                                                          operation.row_id))
            else:
                # Pass the operation along like every other OperationError
                # raised here, so callers that forward ``e.args`` keep it.
                raise OperationError(
                    u"insert attempted when the object already existed: "
                    u"model {0} pk {1}".format(model.__name__,
                                               operation.row_id),
                    operation)

        elif operation.command == 'u':
            obj = query_model(session, model).\
                filter(getattr(model, get_pk(model)) == operation.row_id).first()
            if obj is None:
                # For now, the record will be created again by the merge
                # below, but it is an error because nothing should be
                # deleted without using dbsync.
                logger.warning(
                    u"The referenced object doesn't exist in database. "
                    u"Node %s. Operation %s",
                    node_id,
                    operation)
            session.merge(pull_object())

        elif operation.command == 'd':
            obj = query_model(session, model, only_pk=True).\
                filter(getattr(model, get_pk(model)) == operation.row_id).first()
            if obj is None:
                # The object is already deleted: the final state in node and
                # server is the same, but it is still an error because
                # nothing should be deleted without using dbsync.
                logger.warning(
                    "The referenced object doesn't exist in database. "
                    u"Node %s. Operation %s",
                    node_id,
                    operation)
            else:
                session.delete(obj)

        else:
            raise OperationError(
                "the operation doesn't specify a valid command ('i', 'u', 'd')",
                operation)
示例#39
0
    def fill_for(self,
                 request,
                 swell=False,
                 include_extensions=True,
                 session=None,
                 connection=None,
                 **kw):
        """
        Fills this pull message (response) with versions, operations
        and objects, for the given request (PullRequestMessage).

        Only versions and operations newer than
        ``request.latest_version_id`` are included (all of them when it is
        ``None``). Operations are skipped when their model is not tracked,
        not in ``pulled_models``, when a non-delete operation has no
        backing object, or when the before-add hook raises
        *SkipOperation*.

        The *swell* parameter is deprecated and considered ``True``
        regardless of the value given. This means that parent objects
        will always be added to the message.

        *include_extensions* dictates whether the pull message will
        include model extensions or not.

        *session* is used for every query. *connection* is only forwarded
        to the ``call_filter_operations`` and
        ``call_before_server_add_operation_fn`` hooks. Extra keyword
        arguments are accepted and ignored.

        Returns ``self``.
        """
        assert isinstance(request, PullRequestMessage), "invalid request"
        versions = session.query(Version)
        if request.latest_version_id is not None:
            versions = versions. \
                filter(Version.version_id > request.latest_version_id)
        required_objects = {}
        required_parents = {}

        # TODO: since there can be really many versions here
        # we should rebuild this part so that we make an aggregate query
        # getting all operations ordered by their version and by their permissions
        # something like
        # select * from
        #     operation, version
        # where
        #     operation.version_id=version.id
        #     and
        #     version.version_id > request.latest_version_id
        #     and
        #     " one of the user's roles is in operation.allowed_users_and_roles "
        # order by
        #     operation.order
        #

        # the basic query can be done here,
        # per dep injection we must add the query for allowed_users

        self.versions = versions.all()
        ops: Query = session.query(Operation)
        if request.latest_version_id is not None:
            ops = ops.filter(Operation.version_id > request.latest_version_id)

        ops = call_filter_operations(connection, session, ops)
        ops = ops.order_by(Operation.order)

        self.operations = []
        logger.info(f"request.latest_version_id = {request.latest_version_id}")
        logger.info(f"querying for {ops}")
        # Execute the query exactly once: counting with a separate
        # ops.all() and then iterating ``ops`` would run it twice.
        candidate_ops = ops.all()
        logger.info(f"query result #ops: {len(candidate_ops)}")
        for op in candidate_ops:
            model = op.tracked_model
            if model is None:
                logger.warning(
                    f"op {op} has no model (perhaps removed from tracking)")
                continue
            if model not in pulled_models:
                continue
            obj = query_model(session, model).get(op.row_id)
            if obj is None and op.command != 'd':
                logger.error(
                    f"this should not happen, obj is None for op: {op} - ignoring"
                )
                continue
            try:
                call_before_server_add_operation_fn(connection, session, op,
                                                    obj)
            except SkipOperation:
                continue
            self.operations.append(op)

            # deleted rows have no object to ship along with the operation
            if op.command != 'd':
                required_objects.setdefault(model, set()).add(op.row_id)

        synched = list(synched_models.models.keys())
        # Query required objects in batches to stay under the SQL
        # variable limit of the database.
        for model, pks in required_objects.items():
            pk_column = getattr(model, get_pk(model))
            for batch in grouper(pks, MAX_SQL_VARIABLES):
                for obj in query_model(session, model).filter(
                        pk_column.in_(list(batch))).all():
                    self.add_object(obj, include_extensions=include_extensions)
                    # add parent objects to resolve conflicts in merge
                    for pmodel, ppk in parent_references(obj, synched):
                        required_parents.setdefault(pmodel, set()).add(ppk)

        for pmodel, ppks in required_parents.items():
            ppk_column = getattr(pmodel, get_pk(pmodel))
            for batch in grouper(ppks, MAX_SQL_VARIABLES):
                for parent in query_model(session, pmodel).filter(
                        ppk_column.in_(list(batch))).all():
                    self.add_object(parent,
                                    include_extensions=include_extensions)

        logger.info(f"operations result: {self.operations}")
        return self
示例#40
0
def handle_push(data: Dict[str, Any],
                session: Optional[Session] = None) -> Dict[str, int]:
    """
    Handle the push request and return a dictionary object to be sent
    back to the node.

    If the push is rejected, this procedure will raise a
    dbsync.server.handlers.PushRejected exception. When the node has no
    version or an older version than the server, PullSuggested is raised
    instead.

    *data* must be a dictionary-like object, usually the product of
    parsing a JSON string.

    Returns ``{}`` when the message carries no operations, otherwise
    ``{'new_version_id': <id of the version created for this push>}``.
    """
    message: PushMessage
    try:
        message = PushMessage(data)
    except KeyError as err:
        raise PushRejected("request object isn't a valid PushMessage",
                           data) from err
    latest_version_id = core.get_latest_version_id(session=session)
    if latest_version_id != message.latest_version_id:
        reason = "version identifier isn't the latest one; "\
            "given: %s" % message.latest_version_id
        if latest_version_id is None:
            raise PushRejected(reason)
        if message.latest_version_id is None:
            raise PullSuggested(reason)
        if message.latest_version_id < latest_version_id:
            raise PullSuggested(reason)
        raise PushRejected(reason)
    if not message.operations:
        # Nothing to apply: answer with an empty result instead of rejecting.
        return {}
    if not message.islegit(session):
        raise PushRejected("message isn't properly signed")

    for listener in before_push:
        listener(session, message)

    # I) detect unique constraint conflicts and resolve them if possible
    unique_conflicts = find_unique_conflicts(message, session)
    conflicting_objects = set()
    for uc in unique_conflicts:
        obj = uc['object']
        conflicting_objects.add(obj)
        for key, value in zip(uc['columns'], uc['new_values']):
            setattr(obj, key, value)
    for obj in conflicting_objects:
        make_transient(obj)  # remove from session
    for model in set(type(obj) for obj in conflicting_objects):
        pk_name = get_pk(model)
        pks = [
            getattr(obj, pk_name) for obj in conflicting_objects
            if type(obj) is model
        ]
        session.query(model).filter(getattr(model, pk_name).in_(pks)).\
            delete(synchronize_session=False)  # remove from the database
    session.add_all(conflicting_objects)  # reinsert
    session.flush()

    # II) perform the operations (operations on untracked models are ignored)
    operations = [o for o in message.operations if o.tracked_model is not None]
    try:
        for op in operations:
            op.perform(message, session, message.node_id)
    except OperationError as e:
        logger.exception("Couldn't perform operation in push from node %s.",
                         message.node_id)
        raise PushRejected("at least one operation couldn't be performed",
                           *e.args) from e

    # III) insert a new version. Flush immediately: version_id is read at
    # the end and, if database-generated, only exists after a flush. When
    # every operation was filtered out above, the loop below never flushes.
    version = Version(created=datetime.datetime.now(), node_id=message.node_id)
    session.add(version)
    session.flush()

    # IV) insert the operations, discarding the 'order' column
    for op in sorted(operations, key=attr('order')):
        new_op = Operation()
        for k in [k for k in properties_dict(op) if k != 'order']:
            setattr(new_op, k, getattr(op, k))
        session.add(new_op)
        new_op.version = version
        session.flush()

    for listener in after_push:
        listener(session, message)

    # return the new version id back to the node
    return {'new_version_id': version.version_id}
示例#41
0
async def handle_push(connection: Connection,
                      session: sqlalchemy.orm.Session) -> Optional[dict]:
    """
    Receive push messages from ``connection.socket`` and apply them.

    For each message: compare the node's version with the server's
    (raising PushRejected / PullSuggested on mismatch), verify the
    signature, resolve unique-constraint conflicts, perform the operations
    and record the performed ones under a new Version. Every performed
    operation is acknowledged with a ``{"type": "info", "op": {...}}``
    message, and each push is answered with a
    ``{"type": "result", "new_version_id": ...}`` message.

    Returns ``{'new_version_id': ...}`` as soon as a message produced a
    new version. When no operation of a message was performed, the reply
    carries ``new_version_id=None`` and the socket is closed.

    Raises PushRejected when the push cannot be applied and PullSuggested
    when the node should pull first.
    """
    version: Optional[Version] = None
    async for msg in connection.socket:
        pushmsg = PushMessage(json.loads(msg))
        if not pushmsg.operations:
            logger.warning("empty operations list in client PushMessage")
        for op in pushmsg.operations:
            logger.info(f"operation: {op}")
        logger.info(f"message key={pushmsg.key}")

        latest_version_id = core.get_latest_version_id(session=session)
        logger.info(
            f"** version on server:{latest_version_id}, version in pushmsg:{pushmsg.latest_version_id}"
        )
        if latest_version_id != pushmsg.latest_version_id:
            exc = f"version identifier isn't the latest one; " \
                  f"incoming: {pushmsg.latest_version_id}, on server:{latest_version_id}"
            # log every rejection path, not only some of them
            logger.warning(exc)
            if latest_version_id is None:
                raise PushRejected(exc)
            if pushmsg.latest_version_id is None:
                raise PullSuggested(exc)
            if pushmsg.latest_version_id < latest_version_id:
                raise PullSuggested(exc)
            raise PushRejected(exc)
        if not pushmsg.islegit(session):
            raise PushRejected("message isn't properly signed")

        for listener in before_push:
            listener(session, pushmsg)

        # I) detect unique constraint conflicts and resolve them if possible
        unique_conflicts = find_unique_conflicts(pushmsg, session)
        conflicting_objects = set()
        for uc in unique_conflicts:
            obj = uc['object']
            conflicting_objects.add(obj)
            for key, value in zip(uc['columns'], uc['new_values']):
                setattr(obj, key, value)
        for obj in conflicting_objects:
            make_transient(obj)  # remove from session
        for model in set(type(obj) for obj in conflicting_objects):
            pk_name = get_pk(model)
            pks = [
                getattr(obj, pk_name) for obj in conflicting_objects
                if type(obj) is model
            ]
            session.query(model).filter(getattr(model, pk_name).in_(pks)). \
                delete(synchronize_session=False)  # remove from the database
        session.add_all(conflicting_objects)  # reinsert
        session.flush()

        # II) perform the operations (operations on untracked models are ignored)
        operations = [
            o for o in pushmsg.operations if o.tracked_model is not None
        ]
        post_operations: List[Tuple[Operation, SQLClass,
                                    Optional[SQLClass]]] = []
        try:
            op: Operation
            for op in operations:
                (obj,
                 old_obj) = await op.perform_async(pushmsg, session,
                                                   pushmsg.node_id,
                                                   connection.socket)

                if obj is not None:
                    # if the op has been skipped, it wont be appended for post_operation handling
                    post_operations.append((op, obj, old_obj))

                    resp = dict(type="info",
                                op=dict(
                                    row_id=op.row_id,
                                    version=op.version,
                                    command=op.command,
                                    content_type_id=op.content_type_id,
                                ))
                    call_after_tracking_fn(session, op, obj)
                    await connection.socket.send(json.dumps(resp))

        except OperationError as e:
            logger.exception(
                "Couldn't perform operation in push from node %s.",
                pushmsg.node_id)
            raise PushRejected("at least one operation couldn't be performed",
                               *e.args) from e

        # III) insert a new version, only if operations have been performed
        if post_operations:
            version = Version(created=datetime.datetime.now(),
                              node_id=pushmsg.node_id)
            session.add(version)

        # IV) insert the operations, discarding the 'order' column
        accomplished_operations = [
            op for (op, obj, old_obj) in post_operations
        ]
        for op in sorted(accomplished_operations, key=attr('order')):
            new_op = Operation()
            for k in [k for k in properties_dict(op) if k != 'order']:
                setattr(new_op, k, getattr(op, k))
            session.add(new_op)
            new_op.version = version
            session.flush()

        for op, obj, old_obj in post_operations:
            op.call_after_operation_fn(session, obj, old_obj)

        for listener in after_push:
            listener(session, pushmsg)

        # return the new version id back to the client
        logger.info(f"version is: {version}")
        if version:
            await connection.socket.send(
                json.dumps(
                    dict(type="result", new_version_id=version.version_id)))
            return {'new_version_id': version.version_id}
        else:
            await connection.socket.send(
                json.dumps(dict(type="result", new_version_id=None)))
            logger.info("sent nothing message")
            await connection.socket.close()
            # the socket is closed; there is nothing more to receive
            break

    logger.info("push ready")