Example #1
0
def _create_forward_many_to_many_manager(*args, **kwargs):
    """Build a forward many-to-many manager whose ``get_queryset`` is signalified.

    The arguments are first bound to the signature of
    ``create_forward_many_to_many_manager`` (via ``inspect.getcallargs``);
    that mapping is passed as keyword context to ``signalify_queryset``
    together with the ``parse_many_related_queryset`` parser.

    Returns the manager class produced by the Django factory, with its
    ``get_queryset`` attribute replaced by the wrapped version.
    """
    call_context = inspect.getcallargs(
        create_forward_many_to_many_manager, *args, **kwargs)
    manager_cls = create_forward_many_to_many_manager(*args, **kwargs)
    plain_get_queryset = manager_cls.get_queryset
    manager_cls.get_queryset = signalify_queryset(
        plain_get_queryset,
        parser=parse_many_related_queryset,
        **call_context
    )
    return manager_cls
Example #2
0
 def related_manager_cls(self):
     """Return the ordered-through many-to-many manager class for this descriptor.

     The base manager class is the class of the related model's default
     manager, where the related model is ``rel.related_model`` on the reverse
     side and ``rel.model`` otherwise. The Django-generated manager is then
     wrapped by ``create_m2m_ordered_through_manager``.
     """
     if self.reverse:
         target_model = self.rel.related_model
     else:
         target_model = self.rel.model
     base_manager_cls = create_forward_many_to_many_manager(
         target_model._default_manager.__class__,
         self.rel,
         reverse=self.reverse,
     )
     return create_m2m_ordered_through_manager(base_manager_cls, self.rel)
Example #3
0
def _create_forward_many_to_many_manager(*args, **kwargs):
    """Create a forward M2M manager class and wrap its ``get_queryset``.

    All arguments go unchanged to ``create_forward_many_to_many_manager``.
    The same arguments are bound to that function's parameter names and
    forwarded as keyword context to ``signalify_queryset``, which wraps the
    manager's ``get_queryset`` using ``parse_many_related_queryset``.
    """
    bound_arguments = inspect.getcallargs(
        create_forward_many_to_many_manager, *args, **kwargs)
    manager = create_forward_many_to_many_manager(*args, **kwargs)
    wrapped_get_queryset = signalify_queryset(
        manager.get_queryset, parser=parse_many_related_queryset,
        **bound_arguments)
    manager.get_queryset = wrapped_get_queryset
    return manager
Example #4
0
 def related_manager_cls(self):
     '''
     Build the related manager class for this many-to-many descriptor.

     Mirrors ManyToManyDescriptor.related_manager_cls, except that the base
     manager class comes from
     self.get_manager_base_class(related_model=self.field.remote_field.to)
     instead of model._default_manager.__class__.
     '''
     # This call replaces model._default_manager.__class__ (the customisation).
     base_manager_cls = self.get_manager_base_class(
         related_model=self.field.remote_field.to)
     return create_forward_many_to_many_manager(
         base_manager_cls, self.rel, reverse=self.reverse)
Example #5
0
    def related_manager_cls(self):
        """Return an M2M manager class ordered by the through model's ``position``.

        The base class is built by ``create_forward_many_to_many_manager`` from
        the default manager of ``rel.related_model`` (reverse side) or
        ``rel.model`` (forward side). The returned ``OrderedManyRelatedManager``:

        * orders ``get_queryset()`` by ``<through model_name>__position``;
        * makes ``add()``/``remove()`` usable on Django < 2.2, whose related
          managers reject these calls on relations with an intermediary model,
          by temporarily flagging the through model as auto-created.

        On Django >= 2.2, ``add()``/``remove()`` raise ``RuntimeError`` as a
        reminder that the workaround can be deleted.
        """
        model = self.rel.related_model if self.reverse else self.rel.model

        def add_custom_queryset_to_many_related_manager(
                many_related_manage_cls):

            def call_with_auto_created_through(manager, method_name, objs):
                # Run super().<method_name>(*objs) while the through model is
                # marked auto-created, which is the condition older Django
                # checks before allowing add()/remove().
                # see: https://github.com/django/django/blob/stable/1.11.x/django/db/models/fields/related_descriptors.py#L926
                # see: https://github.com/django/django/blob/stable/1.11.x/django/db/models/fields/related_descriptors.py#L944
                # Django 2.0 and 2.1 still have that restriction (it was lifted
                # in 2.2), so the guard must not fire on every 2.x release.
                if django.VERSION[:2] >= (2, 2):
                    raise RuntimeError(
                        'This method is no longer necessary in Django>=2.2')
                through_meta = manager.through._meta
                previous_auto_created = through_meta.auto_created
                try:
                    through_meta.auto_created = True
                    parent = super(OrderedManyRelatedManager, manager)
                    getattr(parent, method_name)(*objs)
                finally:
                    # Restore the original value instead of forcing False.
                    through_meta.auto_created = previous_auto_created

            class OrderedManyRelatedManager(many_related_manage_cls):
                def get_queryset(self):
                    return super(OrderedManyRelatedManager,
                                 self).get_queryset().order_by(
                                     '%s__position' %
                                     self.through._meta.model_name)

                def add(self, *objs):
                    call_with_auto_created_through(self, 'add', objs)
                # Keep Django's protection against calling this from templates.
                add.alters_data = True

                def remove(self, *objs):
                    call_with_auto_created_through(self, 'remove', objs)
                remove.alters_data = True

            return OrderedManyRelatedManager

        return add_custom_queryset_to_many_related_manager(
            create_forward_many_to_many_manager(
                model._default_manager.__class__,
                self.rel,
                reverse=self.reverse,
            ))
Example #6
0
    def related_manager_cls(self):
        """Return an M2M manager class whose queryset is ordered by position.

        The Django-generated forward many-to-many manager is built on the
        default manager of ``rel.related_model`` (reverse side) or
        ``rel.model`` (forward side). It is then subclassed so that
        ``get_queryset()`` orders by ``<through model_name>__position``.
        """
        if self.reverse:
            model = self.rel.related_model
        else:
            model = self.rel.model

        base_manager_cls = create_forward_many_to_many_manager(
            model._default_manager.__class__,
            self.rel,
            reverse=self.reverse,
        )

        class OrderedManyRelatedManager(base_manager_cls):
            def get_queryset(self):
                base_qs = super(OrderedManyRelatedManager, self).get_queryset()
                ordering = '%s__position' % self.through._meta.model_name
                return base_qs.order_by(ordering)

        return OrderedManyRelatedManager
def create_sorted_forward_many_to_many_manager(superclass, rel, *args, **kwargs):
    """Create a forward M2M manager class that keeps links in insertion order.

    Wraps ``create_forward_many_to_many_manager(superclass, rel, *args,
    **kwargs)`` and returns a subclass that:

    * orders ``get_queryset()`` by the through table's ``id`` column;
    * always clears before ``set()``, so the stored order follows ``objs``;
    * inserts new through rows in the order objects were passed to ``add()``.

    Extra positional/keyword arguments (e.g. ``reverse``) are forwarded to the
    Django factory. They are also forwarded when the manager is re-created
    via ``manager(manager='name')``.
    """
    # Keep the factory arguments under distinct names: ``__call__`` below has
    # its own ``kwargs`` and must rebuild the manager with the same options.
    factory_args = args
    factory_kwargs = kwargs
    RelatedManager = create_forward_many_to_many_manager(
        superclass, rel, *factory_args, **factory_kwargs)

    class ManyRelatedManager(RelatedManager):
        def __call__(self, **kwargs):
            # Rebuild this manager on top of another manager of the related
            # model; the ``manager`` keyword names that manager attribute.
            manager = getattr(self.model, kwargs.pop('manager'))
            manager_class = create_sorted_forward_many_to_many_manager(
                manager.__class__, rel, *factory_args, **factory_kwargs)
            return manager_class(instance=self.instance)
        do_not_call_in_templates = True

        def get_queryset(self):
            qs = super(ManyRelatedManager, self).get_queryset()
            # Order by the through row id; this reflects insertion order
            # presuming ids are assigned increasingly (auto-increment).
            return qs.extra(order_by=['%s.%s' % (rel.through._meta.db_table, 'id', )])

        get_query_set = get_queryset

        def set(self, objs, **kwargs):
            # Choosing to clear first will ensure the order is maintained.
            kwargs['clear'] = True
            super(ManyRelatedManager, self).set(objs, **kwargs)
        set.alters_data = True

        def _add_items(self, source_field_name, target_field_name, *objs,
                       **kwargs):
            """Insert through rows for ``objs`` in the order they were given.

            ``objs`` may be instances of ``self.model`` (their ``pk`` is used)
            or raw key values. Ids already linked to this instance, and
            repeats within ``objs``, are skipped so the bulk insert does not
            create duplicate links. An optional ``through_defaults`` mapping
            (passed by newer Django versions) supplies extra field values for
            the new through rows; its values are used as-is.
            """
            through_defaults = kwargs.get('through_defaults') or {}
            requested_ids = [
                obj.pk if isinstance(obj, self.model) else obj for obj in objs
            ]
            db = router.db_for_write(self.through, instance=self.instance)

            with transaction.atomic(using=db, savepoint=False):
                already_linked = set(
                    self.through._default_manager.using(db)
                    .values_list(target_field_name, flat=True)
                    .filter(**{
                        source_field_name: self.related_val[0],
                        '%s__in' % target_field_name: requested_ids,
                    })
                )
                # Keep the caller's order while dropping existing/duplicate ids.
                new_ids = []
                for obj_id in requested_ids:
                    if obj_id not in already_linked:
                        already_linked.add(obj_id)
                        new_ids.append(obj_id)

                # Only the forward side (or reverse managers) send signals, so
                # the mirror row of a symmetrical relation does not send twice.
                if self.reverse or source_field_name == self.source_field_name:
                    signals.m2m_changed.send(sender=self.through, action='pre_add', instance=self.instance, reverse=self.reverse, model=self.model, pk_set=new_ids, using=db)

                self.through._default_manager.using(db).bulk_create([
                    self.through(**dict(through_defaults, **{
                        '%s_id' % source_field_name: self.related_val[0],
                        '%s_id' % target_field_name: obj_id,
                    }))
                    for obj_id in new_ids
                ])

                if self.reverse or source_field_name == self.source_field_name:
                    signals.m2m_changed.send(sender=self.through, action='post_add', instance=self.instance, reverse=self.reverse, model=self.model, pk_set=new_ids, using=db)

    return ManyRelatedManager
Example #8
0
        class ManyRelatedManager(
                create_forward_many_to_many_manager(
                    related_model._default_manager.__class__,
                    self.rel,
                    reverse=self.reverse,
                )):
            """Forward many-to-many manager with ``async_*`` wrappers.

            Each ``async_<name>`` method calls the synchronous ``<name>``
            method under ``sync_to_async_threadsafe`` and forwards its return
            value. For example, ``async_create`` yields the created object and
            ``async_get_or_create`` yields the ``(object, created)`` result.
            """

            @sync_to_async_threadsafe
            def async_create(self, **kwargs):
                # Previously the created object was discarded.
                return self.create(**kwargs)

            @sync_to_async_threadsafe
            def async_get_or_create(self, *, through_defaults=None, **kwargs):
                # Previously the (object, created) result was discarded.
                return self.get_or_create(through_defaults=through_defaults,
                                          **kwargs)

            @sync_to_async_threadsafe
            def async_update_or_create(self,
                                       *,
                                       through_defaults=None,
                                       **kwargs):
                return self.update_or_create(through_defaults=through_defaults,
                                             **kwargs)

            @sync_to_async_threadsafe
            def async_clear(self):
                return self.clear()

            @sync_to_async_threadsafe
            def async_set(self, objs, *, clear=False, through_defaults=None):
                return self.set(objs,
                                clear=clear,
                                through_defaults=through_defaults)

            @sync_to_async_threadsafe
            def async_remove(self, *args, **kwargs):
                return self.remove(*args, **kwargs)

            @sync_to_async_threadsafe
            def async_add(self, *args, **kwargs):
                return self.add(*args, **kwargs)
def create_many_related_history_manager(superclass, rel, reverse):
    """Build a many-to-many manager class that keeps link history in the through table.

    Through rows are never deleted. Adding creates a row stamped with
    ``time_from``; removing or clearing stamps ``time_to`` on the current row.
    Rows with ``time_to=None`` are the current links. Every change sends both
    Django's ``m2m_changed`` and ``m2m_history_changed``; the latter also
    carries the field name and the change time.

    :param superclass: manager class forwarded to
        ``create_forward_many_to_many_manager``.
    :param rel: the many-to-many relation, forwarded unchanged.
    :param reverse: forwarded unchanged to ``create_forward_many_to_many_manager``.
    :returns: the ``ManyToManyHistoryThroughManager`` class.
    """
    baseManagerClass = create_forward_many_to_many_manager(
        superclass, rel, reverse)

    class ManyToManyHistoryThroughManager(baseManagerClass):

        # time of altering transaction; set lazily by get_time() and reused by
        # every change made through this manager instance
        time = None

        @property
        def db(self):
            """Database alias for the through table, as routed for writes."""
            return router.db_for_write(self.through, instance=self.instance)

        @property
        def versions(self):
            """``ManyToManyHistoryVersion`` records for this instance and field."""
            return ManyToManyHistoryVersion.objects.filter(
                content_type=ContentType.objects.get_for_model(self.instance),
                object_id=self.instance.pk,
                field_name=self.prefetch_cache_name)

        def get_time(self):
            """Return the change timestamp, initialising it to ``timezone.now()``.

            NOTE(review): the value is cached on the manager instance, so all
            changes made through the same instance share one timestamp.
            """
            if not self.time:
                self.time = timezone.now()
            return self.time

        def last_update_time(self):
            """Return the latest ``time_from``/``time_to`` among the through rows.

            Raises ``IndexError`` when no through row has a non-null
            ``time_from``, including when there are no rows at all.
            """
            # TODO: refactor and optimize this method to one query
            qs = self.get_queryset_through()
            try:
                time_to = qs.exclude(
                    time_to=None).order_by('-time_to')[0].time_to
                time_from = qs.exclude(
                    time_from=None).order_by('-time_from')[0].time_from
                return time_to if time_to > time_from else time_from
            except IndexError:
                return qs.exclude(
                    time_from=None).order_by('-time_from')[0].time_from

        def _prepare_queryset(self, qs, only_pk=False, unique=True):
            """Turn a through-table queryset into target ids or target objects.

            :param qs: queryset over the through model.
            :param only_pk: when True, return the ``target_field_name`` values;
                otherwise return target model instances whose pk is among them.
            :param unique: apply ``distinct()``; must be True when ``only_pk``
                is False.
            :raises ValueError: if ``only_pk`` is False and ``unique`` is False.
            """
            qs = qs.values_list(self.target_field_name, flat=True)
            if not only_pk:
                if unique is False:
                    raise ValueError(
                        "Argument `unique` should be True if argument only_pk is False"
                    )
                if StrictVersion(django.get_version()) >= StrictVersion('1.7'):
                    qs = super(ManyToManyHistoryThroughManager,
                               self).get_queryset().using(
                                   self.db).filter(pk__in=qs)
                else:
                    qs = super(ManyToManyHistoryThroughManager,
                               self).get_query_set().using(
                                   self.db).filter(pk__in=qs)

            if unique:
                qs = qs.distinct()
            return qs

        def get_query_set(self, **kwargs):
            # DEPRECATED. backward compatibility
            return self.get_queryset(**kwargs)

        def get_query_set_through(self):
            # DEPRECATED. backward compatibility
            return self.get_queryset_through()

        @property
        def queryset_through(self):
            return self.get_query_set_through()

        def get_queryset_through(self):
            """All through rows of this instance, current and historical."""
            qs = self.through._default_manager.using(
                self.db).filter(**{
                    self.source_field_name: self._fk_val,
                })
            return qs

        def get_queryset(self, **kwargs):
            """Currently linked objects (through rows with ``time_to=None``).

            ``kwargs`` (``only_pk``, ``unique``) go to ``_prepare_queryset``.
            """
            qs = self.get_queryset_through().filter(time_to=None)
            return self._prepare_queryset(qs, **kwargs)

        def were_between(self, time_from, time_to, **kwargs):
            """Objects linked at any moment overlapping ``time_from``..``time_to``.

            A through row with ``time_from=None`` has no known start; one with
            ``time_to=None`` is still current.

            :raises ValueError: if ``time_to`` is not later than ``time_from``.
            """
            if time_to <= time_from:
                raise ValueError(
                    'Argument time_to should be later, than time_from')
            qs = self.get_queryset_through().filter(
                Q(time_from=None, time_to=None)
                | Q(time_from=None, time_to__gte=time_to)
                | Q(time_from__lte=time_from, time_to=None)
                | Q(time_from__gte=time_from, time_to__lte=time_to)
                | Q(time_from__lte=time_from, time_to__gte=time_to)
                | Q(time_from__lt=time_to, time_to__gt=time_from)
                # Partial overlaps with one open end: SQL comparisons with
                # NULL are never true, so the clause above misses these rows
                # (e.g. a link added inside the window that is still current).
                | Q(time_from=None, time_to__gt=time_from)
                | Q(time_from__lt=time_to, time_to=None))
            return self._prepare_queryset(qs, **kwargs)

        def added_between(self, time_from, time_to, **kwargs):
            """Objects whose link ``time_from`` lies within the range (inclusive).

            :raises ValueError: if ``time_to`` is not later than ``time_from``.
            """
            if time_to <= time_from:
                raise ValueError(
                    'Argument time_to should be later, than time_from')
            qs = self.get_queryset_through().filter(time_from__gte=time_from,
                                                    time_from__lte=time_to)
            return self._prepare_queryset(qs, **kwargs)

        def removed_between(self, time_from, time_to, **kwargs):
            """Objects whose link ``time_to`` lies within the range (inclusive).

            :raises ValueError: if ``time_to`` is not later than ``time_from``.
            """
            if time_to <= time_from:
                raise ValueError(
                    'Argument time_to should be later, than time_from')
            qs = self.get_queryset_through().filter(time_to__gte=time_from,
                                                    time_to__lte=time_to)
            return self._prepare_queryset(qs, **kwargs)

        def were_at(self, time, **kwargs):
            """Objects linked at ``time`` (start inclusive, end exclusive)."""
            qs = self.get_queryset_through().filter(
                Q(time_from=None, time_to=None)
                | Q(time_from=None, time_to__gt=time)
                | Q(time_from__lte=time, time_to=None)
                | Q(time_from__lte=time, time_to__gt=time))
            return self._prepare_queryset(qs, **kwargs)

        def added_at(self, time, **kwargs):
            """Objects whose link ``time_from`` equals ``time``."""
            qs = self.get_queryset_through().filter(time_from=time)
            return self._prepare_queryset(qs, **kwargs)

        def removed_at(self, time, **kwargs):
            """Objects whose link ``time_to`` equals ``time``."""
            qs = self.get_queryset_through().filter(time_to=time)
            return self._prepare_queryset(qs, **kwargs)

        def clear(self, *objs):
            """End all current links except those to ``objs``."""
            self._clear_items(self.source_field_name, self.target_field_name,
                              *objs)

            # If this is a symmetrical m2m relation to self, clear the mirror entry in the m2m table
            if self.symmetrical:
                self._clear_items(self.target_field_name,
                                  self.source_field_name, *objs)

        clear.alters_data = True

        def send_signal(self, source_field_name, action, ids):
            """Send ``m2m_changed`` and ``m2m_history_changed`` for ``ids``."""
            if self.reverse or source_field_name == self.source_field_name:
                # Don't send the signal when we are inserting the
                # duplicate data row for symmetrical reverse entries.
                signals.m2m_changed.send(sender=self.through,
                                         action=action,
                                         instance=self.instance,
                                         reverse=self.reverse,
                                         model=self.model,
                                         pk_set=ids,
                                         using=self.db)

                m2m_history_changed.send(sender=self.through,
                                         action=action,
                                         instance=self.instance,
                                         reverse=self.reverse,
                                         model=self.model,
                                         pk_set=ids,
                                         using=self.db,
                                         field_name=self.prefetch_cache_name,
                                         time=self.get_time())

        def get_set_of_values(self,
                              objs,
                              target_field_name,
                              check_values=False):
            """Normalise ``objs`` to a set of target-field values.

            Instances of ``self.model`` are converted with ``_get_fk_val``;
            any other non-model value is used as-is.

            :param check_values: also verify the router allows the relation
                and that the resolved value is not None.
            :raises ValueError: on a failed ``check_values`` check.
            :raises TypeError: for model instances of a different model.
            """
            values = set()
            # Check that all the objects are of the right type
            for obj in objs:
                if isinstance(obj, self.model):
                    if check_values and not router.allow_relation(
                            obj, self.instance):
                        raise ValueError(
                            'Cannot add "%r": instance is on database "%s", value is on database "%s"'
                            % (obj, self.instance._state.db, obj._state.db))
                    fk_val = self._get_fk_val(obj, target_field_name)
                    if check_values and fk_val is None:
                        raise ValueError(
                            'Cannot add "%r": the value for field "%s" is None'
                            % (obj, target_field_name))
                    values.add(fk_val)
                elif isinstance(obj, models.Model):
                    raise TypeError("'%s' instance expected, got %r" %
                                    (self.model._meta.object_name, obj))
                else:
                    values.add(obj)
            return values

        def _add_items(self, source_field_name, target_field_name, *objs):
            # source_field_name: the PK fieldname in join table for the source object
            # target_field_name: the PK fieldname in join table for the target object
            # *objs - objects to add. Either object instances, or primary keys of object instances.

            # If there aren't any objects, there is nothing to do.
            if objs:
                new_ids = self.get_set_of_values(objs,
                                                 target_field_name,
                                                 check_values=True)
                # ids among new_ids that already have a current link
                current_ids = self.through._default_manager.using(self.db) \
                    .values_list(target_field_name, flat=True) \
                    .filter(**{
                        source_field_name: self._fk_val,
                        '%s__in' % target_field_name: new_ids,
                        'time_to': None,
                    })
                # remove current from new, otherwise integrity error while bulk_create
                new_ids = new_ids.difference(set(current_ids))

                self.send_signal(source_field_name, 'pre_add', new_ids)

                # Add the ones that aren't there already
                self.through._default_manager.using(self.db).bulk_create([
                    self.through(
                        **{
                            '%s_id' % source_field_name: self._fk_val,
                            '%s_id' % target_field_name: obj_id,
                            'time_from': self.get_time(),
                        }) for obj_id in new_ids
                ])

                self.send_signal(source_field_name, 'post_add', new_ids)

        def _remove_items(self, source_field_name, target_field_name, *objs):
            # source_field_name: the PK colname in join table for the source object
            # target_field_name: the PK colname in join table for the target object
            # *objs - objects to remove

            # If there aren't any objects, there is nothing to do.
            if objs:
                old_ids = self.get_set_of_values(objs, target_field_name)
                self.send_signal(source_field_name, 'pre_remove', old_ids)

                # Remove the specified objects from the join table
                # (history is kept: the current rows are closed via time_to)
                qs = self.through._default_manager.using(self.db).filter(
                    **{
                        source_field_name: self._fk_val,
                        'time_to': None,
                        '%s__in' % target_field_name: old_ids,
                    })
                qs.update(time_to=self.get_time())

                self.send_signal(source_field_name, 'post_remove', old_ids)

        def _clear_items(self, source_field_name, target_field_name, *objs):
            # source_field_name: the PK colname in join table for the source object
            # target_field_name: the PK colname in join table for the target object
            # *objs - objects to keep; every other current link is closed

            new_ids = self.get_set_of_values(objs,
                                             target_field_name,
                                             check_values=True)
            current_ids = self.through._default_manager.using(self.db) \
                .values_list(target_field_name, flat=True) \
                .filter(**{
                    source_field_name: self._fk_val,
                    'time_to': None,
                })
            old_ids = set(current_ids).difference(new_ids)
            self.send_signal(source_field_name, 'pre_clear', old_ids)

            qs = self.through._default_manager.using(self.db).filter(
                **{
                    source_field_name: self._fk_val,
                    'time_to': None,
                    '%s__in' % target_field_name: old_ids,
                })
            qs.update(time_to=self.get_time())

            # NOTE(review): this reports every link closed at the cached
            # change time, including earlier removals through this instance.
            self.send_signal(
                source_field_name, 'post_clear',
                set(self.removed_at(self.get_time(), only_pk=True)))

        # compatibility with Django 1.7
        if StrictVersion(django.get_version()) >= StrictVersion('1.7'):

            @property
            def _fk_val(self):
                return self.related_val[0]

            def _get_fk_val(self, obj, target_field_name):
                return self.through._meta.get_field(
                    target_field_name).get_foreign_related_value(obj)[0]

    return ManyToManyHistoryThroughManager
Example #10
0
def create_versioned_forward_many_to_many_manager(superclass,
                                                  rel,
                                                  reverse=None):
    """Build a forward many-to-many manager class aware of versioned models.

    The returned ``VersionedManyRelatedManager`` subclasses the Django
    manager produced by ``create_forward_many_to_many_manager`` and:

    * applies ``as_of(instance._querytime.time)`` in ``get_queryset()`` when
      the instance has an active query time differing from the queryset's;
    * "removes" links by calling ``_delete_at`` on the matching through rows
      instead of deleting them;
    * only allows ``add()`` on the current version and offers ``add_at()`` /
      ``remove_at()`` to change links at a given timestamp.
    """
    many_related_manager_klass = create_forward_many_to_many_manager(
        superclass, rel, reverse)

    class VersionedManyRelatedManager(many_related_manager_klass):
        def __init__(self, *args, **kwargs):
            super(VersionedManyRelatedManager, self).__init__(*args, **kwargs)
            # Additional core filters are:
            # version_start_date <= t &
            #   (version_end_date > t | version_end_date IS NULL)
            # but we cannot work with the Django core filters, since they
            # don't support ORing filters, which is a thing we need to
            # consider the "version_end_date IS NULL" case;
            # So, we define our own set of core filters being applied when
            # versioning
            # The code below only verifies that the through model has the
            # versioning fields.
            try:
                _ = self.through._meta.get_field('version_start_date')
                _ = self.through._meta.get_field('version_end_date')
            except FieldDoesNotExist as e:
                # FIXME: this probably does not work when auto-referencing
                fields = [f.name for f in self.through._meta.get_fields()]
                print(str(e) + "; available fields are " + ", ".join(fields))
                # Bare raise keeps the original traceback on Python 2 as well.
                raise

        def get_queryset(self):
            """
            Add a filter to the queryset, limiting the results to be pointed
            by relationship that are valid for the given timestamp (which is
            taken at the current instance, or set to now, if not available).
            Long story short, apply the temporal validity filter also to the
            intermediary model.
            """
            queryset = super(VersionedManyRelatedManager, self).get_queryset()
            if hasattr(queryset, 'querytime'):
                if self.instance._querytime.active and \
                                self.instance._querytime != queryset.querytime:
                    queryset = queryset.as_of(self.instance._querytime.time)
            return queryset

        def _remove_items(self, source_field_name, target_field_name, *objs):
            """
            Instead of removing items, we simply set the version_end_date of
            the current item to the current timestamp --> t[now].
            Like that, there is no more current entry having that identity -
            which is equal to not existing for timestamps greater than t[now].
            """
            return self._remove_items_at(None, source_field_name,
                                         target_field_name, *objs)

        def _remove_items_at(self, timestamp, source_field_name,
                             target_field_name, *objs):
            """End the through rows linking this instance to ``objs`` at ``timestamp``.

            ``timestamp`` defaults to ``get_utc_now()`` when None. ``objs``
            may be ``self.model`` instances or raw key values.

            :raises TypeError: if the manager has no ``target_field`` to
                resolve instance keys with.
            """
            if objs:
                if timestamp is None:
                    timestamp = get_utc_now()
                old_ids = set()
                for obj in objs:
                    if isinstance(obj, self.model):
                        # The Django 1.7-way is preferred
                        if hasattr(self, 'target_field'):
                            fk_val = \
                                self.target_field \
                                    .get_foreign_related_value(obj)[0]
                        else:
                            raise TypeError(
                                "We couldn't find the value of the foreign "
                                "key, this might be due to the use of an "
                                "unsupported version of Django")
                        old_ids.add(fk_val)
                    else:
                        old_ids.add(obj)
                db = router.db_for_write(self.through, instance=self.instance)
                qs = self.through._default_manager.using(db).filter(
                    **{
                        source_field_name: self.instance.id,
                        '%s__in' % target_field_name: old_ids
                    }).as_of(timestamp)
                for relation in qs:
                    relation._delete_at(timestamp)

        if 'add' in dir(many_related_manager_klass):

            def add(self, *objs):
                """Add ``objs``; only allowed on the current version.

                :raises SuspiciousOperation: if the instance is not current.
                """
                if not self.instance.is_current:
                    raise SuspiciousOperation(
                        "Adding many-to-many related objects is only possible "
                        "on the current version")

                # The ManyRelatedManager.add() method uses the through model's
                # default manager to get a queryset when looking at which
                # objects already exist in the database.
                # In order to restrict the query to the current versions when
                # that is done, we temporarily replace the queryset's using
                # method so that the version validity condition can be
                # specified.
                klass = self.through._default_manager.get_queryset().__class__
                __using_backup = klass.using

                def using_replacement(self, *args, **kwargs):
                    qs = __using_backup(self, *args, **kwargs)
                    return qs.as_of(None)

                klass.using = using_replacement
                try:
                    super(VersionedManyRelatedManager, self).add(*objs)
                finally:
                    # Always restore: the patch is on the shared queryset
                    # class, so leaking it would affect every later query.
                    klass.using = __using_backup

            def add_at(self, timestamp, *objs):
                """
                This function adds an object at a certain point in time
                (timestamp)
                """

                # First off, define the new constructor
                def _through_init(self, *args, **kwargs):
                    super(self.__class__, self).__init__(*args, **kwargs)
                    self.version_birth_date = timestamp
                    self.version_start_date = timestamp

                # Through-classes have an empty constructor, so it can easily
                # be overwritten when needed;
                # This is not the default case, so the overwrite only takes
                # place when we "modify the past"
                self.through.__init_backup__ = self.through.__init__
                self.through.__init__ = _through_init

                try:
                    # Do the add operation
                    self.add(*objs)
                finally:
                    # Remove the constructor again (by replacing it with the
                    # original empty constructor), even if add() failed.
                    self.through.__init__ = self.through.__init_backup__
                    del self.through.__init_backup__

            add_at.alters_data = True

        if 'remove' in dir(many_related_manager_klass):

            def remove_at(self, timestamp, *objs):
                """
                Performs the act of removing specified relationships at a
                specified time (timestamp);
                So, not the objects at a given time are removed, but their
                relationship!
                """
                self._remove_items_at(timestamp, self.source_field_name,
                                      self.target_field_name, *objs)

                # For consistency, also handle the symmetrical case
                if self.symmetrical:
                    self._remove_items_at(timestamp, self.target_field_name,
                                          self.source_field_name, *objs)

            remove_at.alters_data = True

    return VersionedManyRelatedManager
Example #11
0
def create_sorted_many_related_manager(superclass, rel, *args, **kwargs):
    """
    Build a many-related manager class that keeps relations in insertion order.

    The returned ``SortedRelatedManager`` subclasses the manager produced by
    ``create_forward_many_to_many_manager(superclass, rel, *args, **kwargs)``
    and orders results by the intermediary (through) model's sort column,
    whose name is read from ``rel.through._sort_field_name``.

    :param superclass: base manager class, forwarded unchanged.
    :param rel: relation object; ``rel.through`` is the intermediary model.
    :returns: the ``SortedRelatedManager`` class (not an instance).
    """
    RelatedManager = create_forward_many_to_many_manager(
        superclass, rel, *args, **kwargs)

    def _sort_order_by():
        # Raw "<through table>.<sort column>" used for the extra ordering.
        return '%s.%s' % (rel.through._meta.db_table,
                          rel.through._sort_field_name)

    class SortedRelatedManager(RelatedManager):
        def get_queryset(self):
            # We use ``extra`` here because we have no other access to the
            # extra sorting field of the intermediary model. The fields are
            # hidden for joins because we set ``auto_created`` on the
            # intermediary's meta options.
            try:
                return self.instance._prefetched_objects_cache[
                    self.prefetch_cache_name]
            except (AttributeError, KeyError):
                return super(SortedRelatedManager, self).get_queryset().extra(
                    order_by=[_sort_order_by()])

        # Legacy alias for the older ``get_query_set`` method name.
        get_query_set = get_queryset

        def get_prefetch_queryset(self, instances, queryset=None):
            # Give prefetched results the same ordering as regular lookups.
            result = super(SortedRelatedManager,
                           self).get_prefetch_queryset(instances, queryset)
            queryset = result[0]
            queryset.query.extra_order_by = [_sort_order_by()]
            return (queryset, ) + result[1:]

        get_prefetch_query_set = get_prefetch_queryset

        def set(self, objs, **kwargs):
            # Choosing to clear first will ensure the order is maintained.
            kwargs['clear'] = True
            super(SortedRelatedManager, self).set(objs, **kwargs)

        set.alters_data = True

        def _add_items(self, source_field_name, target_field_name, *objs,
                       **kwargs):
            """
            Link ``objs`` to ``self.instance`` through the intermediary model.

            New rows are numbered after the current maximum sort value of the
            source object, so the order of ``objs`` is preserved.

            :param source_field_name: through-model field pointing at the
                source object.
            :param target_field_name: through-model field pointing at the
                target objects.
            :param objs: model instances or raw primary-key values to add.
            :param kwargs: may contain ``through_defaults``, extra field
                values applied to every new through row.
            :raises ValueError: if an instance is on an incompatible database
                or its foreign-key value is None.
            :raises TypeError: if a model instance of the wrong type is given.
            """
            through_defaults = kwargs.get('through_defaults') or {}

            if not objs:
                return

            # Django uses a set here; we need a list to keep the ordering.
            new_ids = []
            for obj in objs:
                if isinstance(obj, self.model):
                    if not router.allow_relation(obj, self.instance):
                        raise ValueError(
                            'Cannot add "%r": instance is on database "%s", value is on database "%s"'
                            %
                            (obj, self.instance._state.db, obj._state.db))
                    fk_val = self.through._meta.get_field(
                        target_field_name).get_foreign_related_value(obj)[0]
                    if fk_val is None:
                        raise ValueError(
                            'Cannot add "%r": the value for field "%s" is None'
                            % (obj, target_field_name))
                    new_ids.append(fk_val)
                elif isinstance(obj, Model):
                    raise TypeError("'%s' instance expected, got %r" %
                                    (self.model._meta.object_name, obj))
                else:
                    new_ids.append(obj)

            db = router.db_for_write(self.through, instance=self.instance)
            manager = self.through._default_manager.using(db)

            # Target ids already linked to this source object.
            vals = set(manager.filter(**{
                source_field_name: self.related_val[0],
                '%s__in' % target_field_name: new_ids,
            }).values_list(target_field_name, flat=True))

            # Drop already-linked ids while keeping the caller's order.
            new_ids_set = set(new_ids)
            new_ids_set.difference_update(vals)
            new_ids = [_id for _id in new_ids if _id in new_ids_set]

            # Don't send the signals when we are inserting the duplicate
            # data row for symmetrical reverse entries.
            send_signals = (self.reverse or
                            source_field_name == self.source_field_name)

            # Add the ones that aren't there already
            with transaction.atomic(using=db):
                # pre_add is sent exactly once per insertion, inside the
                # transaction.
                if send_signals:
                    signals.m2m_changed.send(sender=rel.through,
                                             action='pre_add',
                                             instance=self.instance,
                                             reverse=self.reverse,
                                             model=self.model,
                                             pk_set=new_ids_set,
                                             using=db)

                rel_source_fk = self.related_val[0]
                sort_field_name = self.through._sort_field_name
                source_queryset = manager.filter(
                    **{'%s_id' % source_field_name: rel_source_fk})
                # Continue numbering after the highest existing sort value.
                sort_value_max = source_queryset.aggregate(
                    max=Max(sort_field_name))['max'] or 0

                bulk_data = [{
                    **through_defaults,
                    '%s_id' % source_field_name: rel_source_fk,
                    '%s_id' % target_field_name: obj_id,
                    sort_field_name: i,
                } for i, obj_id in enumerate(new_ids, sort_value_max + 1)]

                manager.bulk_create(
                    [self.through(**data) for data in bulk_data])

                if send_signals:
                    signals.m2m_changed.send(sender=rel.through,
                                             action='post_add',
                                             instance=self.instance,
                                             reverse=self.reverse,
                                             model=self.model,
                                             pk_set=new_ids_set,
                                             using=db)

    return SortedRelatedManager
Example #12
0
def create_sorted_many_related_manager(superclass, rel, *args, **kwargs):
    """
    Return a many-related manager class whose results follow the sort column
    stored on the intermediary (through) model.

    Every positional and keyword argument is forwarded unchanged to
    ``create_forward_many_to_many_manager``. The generated class reads the
    sort column name from ``self.through._sort_field_name``.
    """
    base_manager_cls = create_forward_many_to_many_manager(
        superclass, rel, *args, **kwargs)

    class SortedRelatedManager(base_manager_cls):
        def _apply_rel_ordering(self, queryset):
            """Order ``queryset`` by ``<through table>.<sort column>``."""
            table_name = self.through._meta.db_table
            sort_column = self.through._sort_field_name  # pylint: disable=protected-access
            return queryset.extra(order_by=['%s.%s' % (table_name, sort_column)])

        def get_queryset(self):
            # The sort column of the intermediary model is not reachable via
            # ordinary joins (its fields are hidden because ``auto_created``
            # is set on the intermediary's meta options), hence ``extra``.
            try:
                # pylint: disable=protected-access
                prefetched = self.instance._prefetched_objects_cache[self.prefetch_cache_name]
            except (AttributeError, KeyError):
                unordered = super(SortedRelatedManager, self).get_queryset()
                return self._apply_rel_ordering(unordered)
            return prefetched

        def get_prefetch_queryset(self, instances, queryset=None):
            # Prefetched results receive the same ordering as plain lookups.
            prefetch_result = super(SortedRelatedManager, self).get_prefetch_queryset(instances, queryset)
            ordered_qs = self._apply_rel_ordering(prefetch_result[0])
            return (ordered_qs,) + prefetch_result[1:]

        def set(self, objs, **kwargs):  # pylint: disable=arguments-differ
            # Clearing before re-adding is what keeps the stored order intact.
            kwargs.update(clear=True)
            super(SortedRelatedManager, self).set(objs, **kwargs)
        set.alters_data = True

        # pylint: disable=arguments-differ
        def _add_items(self, source_field_name, target_field_name, *objs, **kwargs):
            """
            Insert intermediary rows linking ``self.instance`` to ``objs``.

            Rows are numbered after the current maximum sort value of the
            source object, so the caller's order is preserved.

            ``source_field_name`` / ``target_field_name`` name the through-model
            fields pointing at the source and target objects. ``objs`` may mix
            model instances and raw primary-key values. ``kwargs`` may carry
            ``through_defaults`` (in Django >= 2.2): extra field values for
            each new row.
            """
            through_defaults = kwargs.get('through_defaults') or {}

            if not objs:
                # Nothing to link.
                return

            # A list rather than Django's set, so the given order survives.
            new_ids = []
            for obj in objs:
                if not isinstance(obj, self.model):
                    if isinstance(obj, Model):
                        raise TypeError(
                            "'%s' instance expected, got %r" %
                            (self.model._meta.object_name, obj)
                        )
                    new_ids.append(obj)
                    continue

                if not router.allow_relation(obj, self.instance):
                    raise ValueError(
                        'Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                        (obj, self.instance._state.db, obj._state.db)  # pylint: disable=protected-access
                    )

                target_field = self.through._meta.get_field(target_field_name)
                fk_val = target_field.get_foreign_related_value(obj)[0]
                if fk_val is None:
                    raise ValueError(
                        'Cannot add "%r": the value for field "%s" is None' %
                        (obj, target_field_name)
                    )
                new_ids.append(fk_val)

            db = router.db_for_write(self.through, instance=self.instance)
            manager = self.through._default_manager.using(db)  # pylint: disable=protected-access

            # Target ids that are already linked to this source object.
            already_linked = manager.values_list(target_field_name, flat=True).filter(**{
                source_field_name: self.related_val[0],
                '%s__in' % target_field_name: new_ids,
            })

            # set.difference_update() that still keeps list ordering.
            new_ids_set = set(new_ids)
            new_ids_set.difference_update(already_linked)
            new_ids = [pk for pk in new_ids if pk in new_ids_set]

            # Add the ones that aren't there already
            with transaction.atomic(using=db, savepoint=False):
                # No signals for the duplicate row inserted for symmetrical
                # reverse entries.
                if self.reverse or source_field_name == self.source_field_name:
                    signals.m2m_changed.send(
                        sender=self.through, action='pre_add',
                        instance=self.instance, reverse=self.reverse,
                        model=self.model, pk_set=new_ids_set, using=db,
                    )

                rel_source_fk = self.related_val[0]
                through_model = self.through
                sort_field_name = through_model._sort_field_name  # pylint: disable=protected-access

                # Continue numbering after the highest existing sort value.
                # An autoincrement field might do this more efficiently.
                siblings = manager.filter(**{'%s_id' % source_field_name: rel_source_fk})
                first_position = (siblings.aggregate(max=Max(sort_field_name))['max'] or 0) + 1

                rows = []
                for position, obj_id in enumerate(new_ids, first_position):
                    row_values = dict(through_defaults)
                    row_values.update({
                        '%s_id' % source_field_name: rel_source_fk,
                        '%s_id' % target_field_name: obj_id,
                        sort_field_name: position,
                    })
                    rows.append(through_model(**row_values))

                manager.bulk_create(rows)

                if self.reverse or source_field_name == self.source_field_name:
                    signals.m2m_changed.send(
                        sender=self.through, action='post_add',
                        instance=self.instance, reverse=self.reverse,
                        model=self.model, pk_set=new_ids_set, using=db,
                    )

    return SortedRelatedManager
Example #13
0
def create_versioned_forward_many_to_many_manager(superclass, rel,
                                                  reverse=None):
    """
    Build a many-related manager class aware of versioned intermediary rows.

    The returned ``VersionedManyRelatedManager`` subclasses the manager
    produced by ``create_forward_many_to_many_manager(superclass, rel,
    reverse)`` and:

    * requires the through model to have ``version_start_date`` and
      ``version_end_date`` fields (checked on construction);
    * re-scopes querysets to the instance's query time when that is active;
    * "removes" relations by calling ``_delete_at`` on the matching through
      rows instead of deleting them;
    * adds ``add_at`` / ``remove_at`` (only when the base manager class has
      ``add`` / ``remove``) to act at an explicit timestamp.

    :param superclass: base manager class, forwarded unchanged.
    :param rel: relation object, forwarded unchanged.
    :param reverse: forwarded unchanged.
    :returns: the ``VersionedManyRelatedManager`` class (not an instance).
    """
    many_related_manager_klass = create_forward_many_to_many_manager(
        superclass, rel, reverse)

    class VersionedManyRelatedManager(many_related_manager_klass):
        def __init__(self, *args, **kwargs):
            super(VersionedManyRelatedManager, self).__init__(*args, **kwargs)
            # Additional core filters are:
            # version_start_date <= t &
            #   (version_end_date > t | version_end_date IS NULL)
            # but we cannot work with the Django core filters, since they
            # don't support ORing filters, which is a thing we need to
            # consider the "version_end_date IS NULL" case;
            # So, we define our own set of core filters being applied when
            # versioning
            try:
                self.through._meta.get_field('version_start_date')
                self.through._meta.get_field('version_end_date')
            except FieldDoesNotExist as e:
                # FIXME(review): this probably does not work when
                # auto-referencing.
                fields = [f.name for f in self.through._meta.get_fields()]
                print(str(e) + "; available fields are " + ", ".join(fields))
                raise

        def get_queryset(self):
            """
            Add a filter to the queryset, limiting the results to be pointed
            by relationship that are valid for the given timestamp (which is
            taken at the current instance, or set to now, if not available).
            Long story short, apply the temporal validity filter also to the
            intermediary model.
            """
            queryset = super(VersionedManyRelatedManager, self).get_queryset()
            if hasattr(queryset, 'querytime'):
                if self.instance._querytime.active and \
                        self.instance._querytime != queryset.querytime:
                    queryset = queryset.as_of(self.instance._querytime.time)
            return queryset

        def _remove_items(self, source_field_name, target_field_name, *objs):
            """
            Instead of removing items, we simply set the version_end_date of
            the current item to the current timestamp --> t[now].
            Like that, there is no more current entry having that identity -
            which is equal to not existing for timestamps greater than t[now].
            """
            return self._remove_items_at(None, source_field_name,
                                         target_field_name, *objs)

        def _remove_items_at(self, timestamp, source_field_name,
                             target_field_name, *objs):
            """
            End the validity of the through rows linking ``self.instance`` to
            ``objs`` at ``timestamp`` (``get_utc_now()`` when None).

            ``objs`` may mix model instances and raw target key values.

            :raises TypeError: if an instance is given but this manager has
                no ``target_field`` to resolve its foreign-key value.
            """
            if not objs:
                return
            if timestamp is None:
                timestamp = get_utc_now()
            old_ids = set()
            for obj in objs:
                if isinstance(obj, self.model):
                    # The Django 1.7-way is preferred
                    if not hasattr(self, 'target_field'):
                        raise TypeError(
                            "We couldn't find the value of the foreign "
                            "key, this might be due to the use of an "
                            "unsupported version of Django")
                    old_ids.add(
                        self.target_field.get_foreign_related_value(obj)[0])
                else:
                    old_ids.add(obj)
            db = router.db_for_write(self.through, instance=self.instance)
            qs = self.through._default_manager.using(db).filter(**{
                source_field_name: self.instance.id,
                '%s__in' % target_field_name: old_ids
            }).as_of(timestamp)
            for relation in qs:
                relation._delete_at(timestamp)

        if 'add' in dir(many_related_manager_klass):
            def add(self, *objs):
                """
                Add ``objs`` to the relation; only allowed on the current
                version of ``self.instance``.

                :raises SuspiciousOperation: if ``self.instance`` is not the
                    current version.
                """
                if not self.instance.is_current:
                    raise SuspiciousOperation(
                        "Adding many-to-many related objects is only possible "
                        "on the current version")

                # The ManyRelatedManager.add() method uses the through model's
                # default manager to get a queryset when looking at which
                # objects already exist in the database.
                # In order to restrict the query to the current versions when
                # that is done, we temporarily replace the queryset's using
                # method so that the version validity condition can be
                # specified.
                klass = self.through._default_manager.get_queryset().__class__
                using_backup = klass.using

                def using_replacement(self, *args, **kwargs):
                    qs = using_backup(self, *args, **kwargs)
                    return qs.as_of(None)

                klass.using = using_replacement
                try:
                    super(VersionedManyRelatedManager, self).add(*objs)
                finally:
                    # The patch is on the queryset class itself; restore it
                    # even when add() raises, or every later query through
                    # that class would stay pinned to the current version.
                    klass.using = using_backup

            add.alters_data = True

            def add_at(self, timestamp, *objs):
                """
                This function adds an object at a certain point in time
                (timestamp)
                """
                through = self.through
                # Through-classes have an empty constructor, so it can easily
                # be overwritten when needed;
                # This is not the default case, so the overwrite only takes
                # place when we "modify the past"
                original_init = through.__init__

                def _through_init(instance, *args, **kwargs):
                    # Call the captured constructor directly: going through
                    # super(instance.__class__, ...) would recurse forever for
                    # subclasses of the through model.
                    original_init(instance, *args, **kwargs)
                    instance.version_birth_date = timestamp
                    instance.version_start_date = timestamp

                through.__init__ = _through_init
                try:
                    # Do the add operation
                    self.add(*objs)
                finally:
                    # Always put the original constructor back, even on error.
                    through.__init__ = original_init

            add_at.alters_data = True

        if 'remove' in dir(many_related_manager_klass):
            def remove_at(self, timestamp, *objs):
                """
                Performs the act of removing specified relationships at a
                specified time (timestamp);
                So, not the objects at a given time are removed, but their
                relationship!
                """
                self._remove_items_at(timestamp, self.source_field_name,
                                      self.target_field_name, *objs)

                # For consistency, also handle the symmetrical case
                if self.symmetrical:
                    self._remove_items_at(timestamp, self.target_field_name,
                                          self.source_field_name, *objs)

            remove_at.alters_data = True

    return VersionedManyRelatedManager