Example #1
0
    def test_partial_router(self):
        """A router can choose to implement a subset of methods.

        Routing answers are checked twice: once with the routers that are
        active when the test starts, and once after installing the
        ``WriteRouter``/``AuthRouter``/``TestRouter`` chain.
        """
        book = Book.objects.using('other').create(
            title="Dive into Python",
            published=datetime.date(2009, 5, 4),
        )
        models = (User, Book)

        # Baseline: answers given by the routers already in place (their
        # configuration is not set up inside this method).
        for model in models:
            self.assertEqual(router.db_for_read(model), 'other')
        for model in models:
            self.assertEqual(router.db_for_write(model), 'default')
        self.assertTrue(router.allow_relation(book, book))
        for model in models:
            self.assertTrue(router.allow_syncdb('default', model))

        # Swap in the router chain under test and ask the same questions.
        router.routers = [WriteRouter(), AuthRouter(), TestRouter()]

        for model, expected_alias in ((User, 'default'), (Book, 'other')):
            self.assertEqual(router.db_for_read(model), expected_alias)
        for model in models:
            self.assertEqual(router.db_for_write(model), 'writer')
        self.assertTrue(router.allow_relation(book, book))

        # Only User is refused for syncdb on 'default'; Book is still allowed.
        self.assertFalse(router.allow_syncdb('default', User))
        self.assertTrue(router.allow_syncdb('default', Book))
        def add(self, *objs):
            """
            Relate ``objs`` to ``self.instance`` by creating ``self.through`` rows.

            Each entry of ``objs`` is either an instance of ``self.model`` (its
            ``pk`` is used) or any non-model value, which is used as the target
            id as-is. Nothing happens when ``objs`` is empty.

            Raises ValueError when the database router refuses the relation
            between an object and ``self.instance``, and TypeError when a model
            instance of a different class is passed.
            """
            from django.db.models import Model

            if not objs:
                return

            new_ids = set()
            for obj in objs:
                if isinstance(obj, self.model):
                    if not router.allow_relation(obj, self.instance):
                        raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                                         (obj, self.instance._state.db, obj._state.db))
                    new_ids.add(obj.pk)
                elif isinstance(obj, Model):
                    raise TypeError("'%s' instance expected" % self.target.field.rel.to._meta.object_name)
                else:
                    new_ids.add(obj)

            # Routers are asked about the intermediate model class itself;
            # ``self.through.__class__`` would hand them its metaclass instead.
            db = router.db_for_write(self.through, instance=self.instance)
            manager = self.through._default_manager.using(db)

            # NOTE(review): this lookup reads the *source* column of the rows
            # that belong to ``self.instance``, so the values removed from
            # ``new_ids`` below are source-side values, not ids of targets that
            # are already linked. Looks unintended -- confirm against the
            # through model's schema before changing it.
            vals = manager.values_list(self.source_field_name, flat=True)
            vals = vals.filter(**{self.source_field_name: self.instance})
            new_ids = new_ids - set(vals)

            # Add the ones that aren't there already
            for obj_id in new_ids:
                manager.create(**{
                    self.target.ct_field: self.content_type,
                    self.target.fk_field: obj_id,
                    self.source_field_name: self.instance,
                })
    def __set__(self, instance, value):
        """
        Assign ``value`` as the object related to ``instance`` through the
        reverse side of the relation described by ``self.related``.

        On success the related field on ``value`` is pointed at ``instance``
        and the relation caches of both objects are seeded. Assigning
        ``None`` (only accepted when the related field is nullable) drops the
        cached related object, if there is one, and clears that object's
        field.

        Raises AttributeError when accessed without an instance, and
        ValueError when ``None`` is assigned to a non-nullable relation, when
        ``value`` is not a ``self.related.model`` instance, or when the
        database router refuses the relation.
        """
        if instance is None:
            raise AttributeError("%s must be accessed via instance" % self.related.opts.object_name)

        # The similarity of the code below to the code in
        # ReverseSingleRelatedObjectDescriptor is annoying, but there's a bunch
        # of small differences that would make a common base class convoluted.

        if value is None:
            # If null=True, we can assign null here, but otherwise the value
            # needs to be an instance of the related class.
            if self.related.field.null is False:
                raise ValueError('Cannot assign None: "%s.%s" does not allow null values.' %
                                 (instance._meta.object_name, self.related.get_accessor_name()))
            # There is no object to update, so only detach whatever is cached.
            # (Falling through to the setattr() calls below would fail with an
            # AttributeError on None.)
            rel_obj = getattr(instance, self.cache_name, None)
            if rel_obj is not None:
                delattr(instance, self.cache_name)
                setattr(rel_obj, self.related.field.name, None)
            return

        if not isinstance(value, self.related.model):
            raise ValueError('Cannot assign "%r": "%s.%s" must be a "%s" instance.' %
                             (value, instance._meta.object_name,
                              self.related.get_accessor_name(), self.related.opts.object_name))

        # Whichever side has no database assigned yet gets one from the
        # router; if both already have one, the router must allow the relation.
        if instance._state.db is None:
            instance._state.db = router.db_for_write(instance.__class__, instance=value)
        elif value._state.db is None:
            value._state.db = router.db_for_write(value.__class__, instance=instance)
        elif not router.allow_relation(value, instance):
            raise ValueError('Cannot assign "%r": instance is on database "%s", value is on database "%s"' %
                             (value, instance._state.db, value._state.db))

        # Set the value of the related field to the value of the related object's related field
        setattr(value, self.related.field.attname, getattr(instance, self.related.field.rel.get_related_field().attname))

        # Since we already know what the related object is, seed the related
        # object caches now, too. This avoids another db hit if you get the
        # object you just set.
        setattr(instance, self.cache_name, value)
        setattr(value, self.related.field.get_cache_name(), instance)
def test_router_allow_relation(model):
    """Relations between objects on two different slave aliases are allowed.

    ``model`` is called with no arguments to build each object; presumably
    it is a fixture providing a model class -- confirm against the test setup.
    """
    instances = []
    for alias in ('slave1', 'slave2'):
        instance = model()
        instance._state.db = alias
        instances.append(instance)

    assert django_router.allow_relation(*instances)
Example #5
0
        def _add_items(self, source_field_name, target_field_name, *objs):
            """
            Create one ``self.through`` row per entry of ``objs``.

            source_field_name: the PK fieldname in the join table for the source object
            target_field_name: the PK fieldname in the join table for the target object
            *objs: objects to add; either ``self.model`` instances or primary
                keys of such instances.

            Rows are created in the order the entries were passed, each with
            the default value of the through model's sort field. Unless this
            call is inserting the duplicate row of a symmetrical relation,
            ``m2m_changed`` is sent with ``pre_add`` before and ``post_add``
            after the inserts.

            Raises ValueError when the database router refuses a relation and
            TypeError when a model instance of a different class is passed.
            """
            from django.db.models import Model

            # If there aren't any objects, there is nothing to do.
            if not objs:
                return

            new_ids = []
            for obj in objs:
                if isinstance(obj, self.model):
                    if not router.allow_relation(obj, self.instance):
                        raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                                         (obj, self.instance._state.db, obj._state.db))
                    new_ids.append(obj.pk)
                elif isinstance(obj, Model):
                    raise TypeError("'%s' instance expected" % self.model._meta.object_name)
                else:
                    new_ids.append(obj)

            # Routers are asked about the intermediate model class itself;
            # ``self.through.__class__`` would hand them its metaclass instead.
            db = router.db_for_write(self.through, instance=self.instance)

            # NOTE(review): ids that are already linked, or that were passed
            # more than once, are not filtered out. An order-preserving
            # de-duplication pass used to live here and was switched off
            # ("i want sorted M2M list, not set"); the lookup query that fed it
            # was never evaluated and has been dropped. Confirm that allowing
            # repeated rows is intended.
            new_ids_set = set(new_ids)

            # Don't send the signal when we are inserting the
            # duplicate data row for symmetrical reverse entries.
            send_signal = self.reverse or source_field_name == self.source_field_name
            if send_signal:
                signals.m2m_changed.send(sender=rel.through, action='pre_add',
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=new_ids_set)

            sort_field_name = self.through._sort_field_name
            sort_field = self.through._meta.get_field_by_name(sort_field_name)[0]
            manager = self.through._default_manager.using(db)
            for obj_id in new_ids:
                manager.create(**{
                    '%s_id' % source_field_name: self._pk_val,
                    '%s_id' % target_field_name: obj_id,
                    # Asked for on every row rather than once up front, since
                    # a field default may differ between calls.
                    sort_field_name: sort_field.get_default(),
                })

            if send_signal:
                signals.m2m_changed.send(sender=rel.through, action='post_add',
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=new_ids_set)
Example #6
0
    def __set__(self, instance, value):
        """
        Assign ``value`` to ``instance`` through the reverse side of the relation.

        For ``place.restaurant = restaurant``, where ``Restaurant`` holds the
        foreign key to ``Place``:

        - ``self`` is the descriptor behind the ``restaurant`` attribute
        - ``instance`` is ``place``
        - ``value`` is ``restaurant``, or ``None`` to detach the cached one

        Raises ValueError when ``value`` is neither ``None`` nor an instance
        of ``self.related.related_model``, or when the database router
        refuses the relation.
        """
        # This deliberately resembles ForwardManyToOneDescriptor.__set__; the
        # small differences make a shared base class more trouble than it is
        # worth.
        related = self.related

        if value is None:
            # Detach: forget the cached ``restaurant`` on ``place`` (if any)
            # and reset its ``place`` field.
            cached = related.get_cached_value(instance, default=None)
            if cached is not None:
                related.delete_cached_value(instance)
                setattr(cached, related.field.name, None)
            return

        if not isinstance(value, related.related_model):
            # An object must be an instance of the related class.
            details = (
                value,
                instance._meta.object_name,
                related.get_accessor_name(),
                related.related_model._meta.object_name,
            )
            raise ValueError('Cannot assign "%r": "%s.%s" must be a "%s" instance.' % details)

        # Whichever side has no database yet gets one from the router; when
        # both have one, the router must allow the relation.
        if instance._state.db is None:
            instance._state.db = router.db_for_write(instance.__class__, instance=value)
        elif value._state.db is None:
            value._state.db = router.db_for_write(value.__class__, instance=instance)
        elif (value._state.db is not None and instance._state.db is not None
                and not router.allow_relation(value, instance)):
            raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value)

        # Copy the referenced values from ``instance`` onto the matching
        # local fields of ``value``.
        rel_field = related.field
        pk_values = tuple(getattr(instance, f.attname) for f in rel_field.foreign_related_fields)
        for position, local_field in enumerate(rel_field.local_related_fields):
            setattr(value, local_field.attname, pk_values[position])

        # Seed both caches so that neither accessor needs a query afterwards.
        related.set_cached_value(instance, value)
        rel_field.set_cached_value(value, instance)
                def add_items(self, source_field_name, target_field_name, *objs):
                    """
                    Create join-table rows linking ``self.instance`` to each of ``objs``.

                    source_field_name: the PK fieldname in the join table for the source object
                    target_field_name: the PK fieldname in the join table for the target object
                    *objs: objects to add; either ``self.model`` instances or
                        primary keys of such instances.

                    Ids already present for this source are skipped. Unless the
                    call is inserting the duplicate row of a symmetrical
                    relation, ``m2m_changed`` is sent with ``pre_add`` before and
                    ``post_add`` after the inserts.

                    Raises ValueError when the database router refuses a relation
                    and TypeError when a model instance of another class is passed.
                    """
                    from django.db.models import Model

                    if not objs:
                        # Nothing to add.
                        return

                    candidate_ids = OrderedSet()
                    for item in objs:
                        if not isinstance(item, self.model):
                            if isinstance(item, Model):
                                raise TypeError("'%s' instance expected" % self.model._meta.object_name)
                            # Anything that is not a model instance is taken as a primary key.
                            candidate_ids.add(item)
                            continue
                        if not router.allow_relation(item, self.instance):
                            raise ValueError(
                                'Cannot add "%r": instance is on database "%s", value is on database "%s"'
                                % (item, self.instance._state.db, item._state.db)
                            )
                        candidate_ids.add(item.pk)

                    db = router.db_for_write(self.through, instance=self.instance)
                    existing = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
                    existing = existing.filter(
                        **{source_field_name: self._pk_val, "%s__in" % target_field_name: candidate_ids}
                    )
                    pending_ids = candidate_ids - set(existing)

                    def notify(action):
                        # Don't send the signal when we are inserting the
                        # duplicate data row for symmetrical reverse entries.
                        if self.reverse or source_field_name == self.source_field_name:
                            signals.m2m_changed.send(
                                sender=rel.through,
                                action=action,
                                instance=self.instance,
                                reverse=self.reverse,
                                model=self.model,
                                pk_set=pending_ids,
                                using=db,
                            )

                    notify("pre_add")
                    # Add the ones that aren't there already
                    for pending_id in pending_ids:
                        self.through._default_manager.using(db).create(
                            **{"%s_id" % source_field_name: self._pk_val, "%s_id" % target_field_name: pending_id}
                        )
                    notify("post_add")
    def __set__(self, instance, value):
        """
        Assign ``value`` as the object ``instance`` refers to through ``self.field``.

        The field's attname on ``instance`` is set to the related-field value
        read from ``value`` (or to ``None``), and the relation cache on
        ``instance`` is seeded with ``value``. When ``None`` is assigned, the
        reverse cache on a previously cached related object is dropped.

        Raises AttributeError when accessed without an instance, and
        ValueError when ``None`` is assigned to a non-nullable field, when
        ``value`` is not an instance of ``self.field.rel.to``, or when the
        database router refuses the relation.
        """
        if instance is None:
            # Every other reference in this method is ``self.field``; the
            # ``_field`` spelling failed before the message could be built.
            raise AttributeError("%s must be accessed via instance" % self.field.name)

        if value is None:
            # If null=True, we can assign null here, but otherwise the value
            # needs to be an instance of the related class.
            if self.field.null is False:
                raise ValueError('Cannot assign None: "%s.%s" does not allow null values.' %
                                 (instance._meta.object_name, self.field.name))

            # If we're setting the value of a OneToOneField to None, we need to
            # clear out the cache on any old related object. Otherwise, deleting
            # the previously-related object will also cause this object to be
            # deleted, which is wrong.
            #
            # Use the cache directly, instead of the accessor; if the cache is
            # not populated there is nothing to invalidate, so there's no need
            # to populate it just to expire it again.
            related = getattr(instance, self.field.get_cache_name(), None)

            # Compared against None rather than by truthiness, so an object
            # that happens to evaluate as false still has its cache cleared.
            if related is not None:
                cache_name = self.field.related.get_cache_name()
                try:
                    delattr(related, cache_name)
                except AttributeError:
                    # The reverse cache was never populated.
                    pass
        elif not isinstance(value, self.field.rel.to):
            raise ValueError('Cannot assign "%r": "%s.%s" must be a "%s" instance.' %
                             (value, instance._meta.object_name,
                              self.field.name, self.field.rel.to._meta.object_name))
        # Whichever side has no database assigned yet gets one from the
        # router; if both already have one, the router must allow the relation.
        elif instance._state.db is None:
            instance._state.db = router.db_for_write(instance.__class__, instance=value)
        elif value._state.db is None:
            value._state.db = router.db_for_write(value.__class__, instance=instance)
        elif not router.allow_relation(value, instance):
            raise ValueError('Cannot assign "%r": instance is on database "%s", value is on database "%s"' %
                             (value, instance._state.db, value._state.db))

        # Set the value of the related field
        try:
            val = getattr(value, self.field.rel.get_related_field().attname)
        except AttributeError:
            # Covers ``value`` being None.
            val = None
        setattr(instance, self.field.attname, val)

        # Since we already know what the related object is, seed the related
        # object cache now, too. This avoids another db hit if you get the
        # object you just set.
        setattr(instance, self.field.get_cache_name(), value)
    def __set__(self, instance, value):
        """
        Set the related instance through the reverse relation.

        With the example above, when setting ``place.restaurant = restaurant``:

        - ``self`` is the descriptor managing the ``restaurant`` attribute
        - ``instance`` is the ``place`` instance
        - ``value`` is the ``restaurant`` instance on the right of the equal sign

        Keep in mind that ``Restaurant`` holds the foreign key to ``Place``.

        Assigning ``None`` (only accepted when the related field is nullable)
        drops the cached related instance, if there is one, and clears that
        instance's field.

        Raises ValueError when ``None`` is assigned to a non-nullable
        relation, when ``value`` is not an instance of
        ``self.related.related_model``, or when the database router refuses
        the relation.
        """
        # The similarity of the code below to the code in
        # ForwardManyToOneDescriptor is annoying, but there's a bunch
        # of small differences that would make a common base class convoluted.

        if value is None:
            # If null=True, we can assign null here, but otherwise the value
            # needs to be an instance of the related class.
            if self.related.field.null is False:
                raise ValueError(
                    'Cannot assign None: "%s.%s" does not allow null values.' % (
                        instance._meta.object_name,
                        self.related.get_accessor_name(),
                    )
                )
            # There is no object to update, so only detach whatever is cached.
            # (Falling through to the setattr() calls below would fail with an
            # AttributeError on None.)
            rel_obj = getattr(instance, self.cache_name, None)
            if rel_obj is not None:
                delattr(instance, self.cache_name)
                setattr(rel_obj, self.related.field.name, None)
            return

        if not isinstance(value, self.related.related_model):
            raise ValueError(
                'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (
                    value,
                    instance._meta.object_name,
                    self.related.get_accessor_name(),
                    self.related.related_model._meta.object_name,
                )
            )

        # Whichever side has no database assigned yet gets one from the
        # router; if both already have one, the router must allow the relation.
        if instance._state.db is None:
            instance._state.db = router.db_for_write(instance.__class__, instance=value)
        elif value._state.db is None:
            value._state.db = router.db_for_write(value.__class__, instance=instance)
        elif not router.allow_relation(value, instance):
            raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value)

        related_pk = tuple(getattr(instance, field.attname) for field in self.related.field.foreign_related_fields)
        # Set the value of the related field to the value of the related object's related field
        for index, field in enumerate(self.related.field.local_related_fields):
            setattr(value, field.attname, related_pk[index])

        # Set the related instance cache used by __get__ to avoid a SQL query
        # when accessing the attribute we just set.
        setattr(instance, self.cache_name, value)

        # Set the forward accessor cache on the related object to the current
        # instance to avoid an extra SQL query if it's accessed later on.
        setattr(value, self.related.field.get_cache_name(), instance)
Example #10
0
        def _add_items(self, source_field_name, target_field_name, *objs, **kwargs):
            """
            Bulk-insert ``self.through`` rows linking the source to each of ``objs``.

            source_field_name: the PK fieldname in join table for the source object
            target_field_name: the PK fieldname in join table for the target object
            *objs: objects to add; either ``self.model`` instances or primary
                keys of such instances.
            **kwargs: only ``through_defaults`` is read (passed by Django >= 2.2);
                its entries are used as extra column values for every new row.

            Ids already linked to this source are skipped, and an id passed
            more than once is inserted once, at the position of its first
            occurrence. New rows get consecutive sort values starting one
            above the current maximum for this source (or at 1 if there is
            none). Unless the call is inserting the duplicate row of a
            symmetrical relation, ``m2m_changed`` is sent with ``pre_add``
            before and ``post_add`` after the insert, inside the same atomic
            block.

            Raises ValueError when the database router refuses a relation or
            the related value of an object is None, and TypeError when a model
            instance of another class is passed.
            """
            through_defaults = kwargs.get("through_defaults") or {}

            from django.db.models import Max, Model

            # If there aren't any objects, there is nothing to do.
            if not objs:
                return

            # Django uses a set here, we need to use a list to keep the
            # correct ordering.
            new_ids = []
            for obj in objs:
                if isinstance(obj, self.model):
                    if not router.allow_relation(obj, self.instance):
                        raise ValueError(
                            'Cannot add "%r": instance is on database "%s", value is on database "%s"'
                            % (obj, self.instance._state.db, obj._state.db)
                        )

                    fk_val = self.through._meta.get_field(
                        target_field_name
                    ).get_foreign_related_value(obj)[0]

                    if fk_val is None:
                        raise ValueError(
                            'Cannot add "%r": the value for field "%s" is None'
                            % (obj, target_field_name)
                        )

                    new_ids.append(fk_val)
                elif isinstance(obj, Model):
                    raise TypeError(
                        "'%s' instance expected, got %r"
                        % (self.model._meta.object_name, obj)
                    )
                else:
                    new_ids.append(obj)

            db = router.db_for_write(self.through, instance=self.instance)
            manager = self.through._default_manager.using(db)
            vals = manager.values_list(target_field_name, flat=True).filter(
                **{
                    source_field_name: self.related_val[0],
                    "%s__in" % target_field_name: new_ids,
                }
            )

            # Ids that still have to be linked (this is what the signals report).
            new_ids_set = set(new_ids)
            new_ids_set.difference_update(vals)

            # Order-preserving difference: keep the first occurrence of every
            # id that is still to be linked. Without the ``remaining``
            # bookkeeping an id passed twice would be inserted twice.
            remaining = set(new_ids_set)
            ordered_ids = []
            for obj_id in new_ids:
                if obj_id in remaining:
                    remaining.discard(obj_id)
                    ordered_ids.append(obj_id)
            new_ids = ordered_ids

            # Don't send the signal when we are inserting the
            # duplicate data row for symmetrical reverse entries.
            send_signal = self.reverse or source_field_name == self.source_field_name

            # Add the ones that aren't there already
            with transaction.atomic(using=db, savepoint=False):
                if send_signal:
                    signals.m2m_changed.send(
                        sender=self.through,
                        action="pre_add",
                        instance=self.instance,
                        reverse=self.reverse,
                        model=self.model,
                        pk_set=new_ids_set,
                        using=db,
                    )

                rel_source_fk = self.related_val[0]
                rel_through = self.through
                sort_field_name = rel_through._sort_field_name

                # Use the max of all indices as start index...
                # maybe an autoincrement field should do the job more efficiently ?
                source_queryset = manager.filter(
                    **{"%s_id" % source_field_name: rel_source_fk}
                )
                sort_value_max = (
                    source_queryset.aggregate(max=Max(sort_field_name))["max"] or 0
                )

                # The explicit keys are applied after ``through_defaults``, so
                # they win if a default uses the same column name.
                bulk_data = [
                    dict(
                        through_defaults,
                        **{
                            "%s_id" % source_field_name: rel_source_fk,
                            "%s_id" % target_field_name: obj_id,
                            sort_field_name: obj_idx,
                        }
                    )
                    for obj_idx, obj_id in enumerate(new_ids, sort_value_max + 1)
                ]

                manager.bulk_create([rel_through(**data) for data in bulk_data])

                if send_signal:
                    signals.m2m_changed.send(
                        sender=self.through,
                        action="post_add",
                        instance=self.instance,
                        reverse=self.reverse,
                        model=self.model,
                        pk_set=new_ids_set,
                        using=db,
                    )
Example #11
0
    def __set__(self, instance, value):
        """
        Assign ``value`` to ``instance`` through the forward side of the relation.

        For ``child.parent = parent``:

        - ``self`` is the descriptor behind the ``parent`` attribute
        - ``instance`` is ``child``
        - ``value`` is ``parent``, or ``None`` to clear the relation

        Raises ValueError when ``value`` is neither ``None`` nor an instance
        of the remote model's concrete model, or when the database router
        refuses the relation.
        """
        field = self.field

        if value is not None:
            # An object must be an instance of the related class.
            if not isinstance(value, field.remote_field.model._meta.concrete_model):
                details = (
                    value,
                    instance._meta.object_name,
                    field.name,
                    field.remote_field.model._meta.object_name,
                )
                raise ValueError('Cannot assign "%r": "%s.%s" must be a "%s" instance.' % details)

            # Whichever side has no database yet gets one from the router;
            # when both have one, the router must allow the relation.
            if instance._state.db is None:
                instance._state.db = router.db_for_write(instance.__class__, instance=value)
            elif value._state.db is None:
                value._state.db = router.db_for_write(value.__class__, instance=instance)
            elif (value._state.db is not None and instance._state.db is not None
                    and not router.allow_relation(value, instance)):
                raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value)

        remote_field = field.remote_field

        if value is None:
            # Clearing a OneToOneField must also reset the cache held by the
            # previously related object; otherwise deleting that object later
            # would wrongly delete this one too. Only the cache is consulted:
            # if it was never populated there is nothing to invalidate.
            previous = field.get_cached_value(instance, default=None)
            if previous is not None:
                remote_field.set_cached_value(previous, None)

        # Write the local column values: copied from ``value``, or None when
        # the relation is being cleared.
        for lh_field, rh_field in field.related_fields:
            column_value = None if value is None else getattr(value, rh_field.attname)
            setattr(instance, lh_field.attname, column_value)

        # Seed the cache read by __get__ so that accessing the attribute we
        # just set costs no query.
        field.set_cached_value(instance, value)

        # For a one-to-one relation, seed the reverse accessor cache on the
        # related object as well.
        if value is not None and not remote_field.multiple:
            remote_field.set_cached_value(value, instance)
Example #12
0
        def _add_items(self, source_field_name, target_field_name, *objs):
            """
            Bulk-insert ``self.through`` rows linking the source to each of ``objs``.

            source_field_name: the PK fieldname in join table for the source object
            target_field_name: the PK fieldname in join table for the target object
            *objs: objects to add; either ``self.model`` instances or primary
                keys of such instances.

            Ids already linked to this source are skipped. Unless the call is
            inserting the duplicate row of a symmetrical relation,
            ``m2m_changed`` is sent with ``pre_add`` before and ``post_add``
            after the insert, inside the same atomic block.

            Raises ValueError when the database router refuses a relation or
            the related value of an object is None, and TypeError when a model
            instance of another class is passed.
            """
            from django.db.models import Model

            if not objs:
                # Nothing to link.
                return

            through = self.through

            new_ids = set()
            for item in objs:
                if isinstance(item, self.model):
                    if not router.allow_relation(item, self.instance):
                        raise ValueError(
                            'Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                            (item, self.instance._state.db, item._state.db)
                        )
                    target_field = through._meta.get_field(target_field_name)
                    fk_val = target_field.get_foreign_related_value(item)[0]
                    if fk_val is None:
                        raise ValueError(
                            'Cannot add "%r": the value for field "%s" is None' %
                            (item, target_field_name)
                        )
                    new_ids.add(fk_val)
                elif isinstance(item, Model):
                    raise TypeError(
                        "'%s' instance expected, got %r" %
                        (self.model._meta.object_name, item)
                    )
                else:
                    # Anything that is not a model instance is taken as a primary key.
                    new_ids.add(item)

            db = router.db_for_write(through, instance=self.instance)
            manager = through._default_manager.using(db)
            already_linked = manager.values_list(target_field_name, flat=True).filter(**{
                source_field_name: self.related_val[0],
                '%s__in' % target_field_name: new_ids,
            })
            # Shrinks ``new_ids`` in place to the ids that still need a row.
            new_ids.difference_update(already_linked)

            def notify(action):
                # Don't send the signal when we are inserting the
                # duplicate data row for symmetrical reverse entries.
                if self.reverse or source_field_name == self.source_field_name:
                    signals.m2m_changed.send(
                        sender=through, action=action,
                        instance=self.instance, reverse=self.reverse,
                        model=self.model, pk_set=new_ids, using=db,
                    )

            with transaction.atomic(using=db, savepoint=False):
                notify('pre_add')

                # Add the ones that aren't there already
                rows = [
                    through(**{
                        '%s_id' % source_field_name: self.related_val[0],
                        '%s_id' % target_field_name: obj_id,
                    })
                    for obj_id in new_ids
                ]
                manager.bulk_create(rows)

                notify('post_add')
Example #13
0
        def _add_items(self, source_field_name, target_field_name, *objs):
            """
            Add ``objs`` to the relation, preserving the order they were given in.

            ``source_field_name`` and ``target_field_name`` are the join-table
            field names for the source object and for the target objects.
            ``objs`` may mix ``self.model`` instances and raw primary keys.

            Ids that already have a join row for this source, and repeated ids
            within ``objs``, are skipped. One through row is created per
            remaining id, each with the sort field's default value, and
            ``m2m_changed`` is sent with ``pre_add``/``post_add`` around the
            inserts.

            Raises ``ValueError`` if the router disallows the relation or an
            instance's foreign-key value is ``None``; raises ``TypeError`` for
            a model instance that is not a ``self.model``.
            """
            # source_field_name: the PK fieldname in join table for the source object
            # target_field_name: the PK fieldname in join table for the target object
            # *objs - objects to add. Either object instances, or primary keys of object instances.

            # If there aren't any objects, there is nothing to do.
            from django.db.models import Model
            if objs:
                # Django uses a set here, we need to use a list to keep the
                # correct ordering.
                new_ids = []
                for obj in objs:
                    if isinstance(obj, self.model):
                        if not router.allow_relation(obj, self.instance):
                            raise ValueError(
                                'Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                                (obj, self.instance._state.db, obj._state.db)
                            )
                        if hasattr(self, '_get_fk_val'):  # Django>=1.5
                            fk_val = self._get_fk_val(obj, target_field_name)
                            if fk_val is None:
                                raise ValueError('Cannot add "%r": the value for field "%s" is None' %
                                                 (obj, target_field_name))
                            new_ids.append(fk_val)
                        else:  # Django<1.5
                            new_ids.append(obj.pk)
                    elif isinstance(obj, Model):
                        raise TypeError(
                            "'%s' instance expected, got %r" %
                            (self.model._meta.object_name, obj)
                        )
                    else:
                        new_ids.append(obj)

                db = router.db_for_write(self.through, instance=self.instance)
                vals = (self.through._default_manager.using(db)
                        .values_list(target_field_name, flat=True)
                        .filter(**{
                            source_field_name: self._fk_val,
                            '%s__in' % target_field_name: new_ids,
                        }))

                # Keep the first occurrence of every id that is not linked yet.
                # Filtering against a set of the existing ids drops *every*
                # occurrence of an already-linked id; list.remove() would only
                # drop the first one and let a repeated id be inserted again.
                existing_ids = set(vals)
                new_ids_set = set()
                ordered_ids = []
                for pk in new_ids:
                    if pk in existing_ids or pk in new_ids_set:
                        continue
                    new_ids_set.add(pk)
                    ordered_ids.append(pk)
                new_ids = ordered_ids

                with atomic(using=db, savepoint=False):
                    # NOTE(review): ``rel`` is not defined in this method;
                    # presumably it comes from an enclosing scope - confirm.
                    if self.reverse or source_field_name == self.source_field_name:
                        # Don't send the signal when we are inserting the
                        # duplicate data row for symmetrical reverse entries.
                        signals.m2m_changed.send(sender=rel.through, action='pre_add',
                            instance=self.instance, reverse=self.reverse,
                            model=self.model, pk_set=new_ids_set, using=db)
                    # Add the ones that aren't there already
                    sort_field_name = self.through._sort_field_name
                    sort_field = self.through._meta.get_field_by_name(sort_field_name)[0]
                    for obj_id in new_ids:
                        # get_default() is deliberately called once per row
                        # rather than hoisted: it may return a different
                        # value for each new row.
                        self.through._default_manager.using(db).create(**{
                            '%s_id' % source_field_name: self._fk_val,  # Django 1.5 compatibility
                            '%s_id' % target_field_name: obj_id,
                            sort_field_name: sort_field.get_default(),
                        })
                    if self.reverse or source_field_name == self.source_field_name:
                        # Don't send the signal when we are inserting the
                        # duplicate data row for symmetrical reverse entries.
                        signals.m2m_changed.send(sender=rel.through, action='post_add',
                            instance=self.instance, reverse=self.reverse,
                            model=self.model, pk_set=new_ids_set, using=db)
            def _add_items(self, source_field_name, target_field_name, *objs):
                """
                By default, auto_created through objects from form instances are saved using
                Manager.bulk_create(). Manager.bulk_create() is passed a list containing
                instances of the through model with the target and source foreign keys defined.

                In order to set the position field we need to tweak this logic (the modified
                lines are marked out with comments below).

                This method is added to ManyRelatedManager below in
                SortableDescriptorMixin.related_manager_cls

                ``objs`` is validated like in the stock implementation (instances of
                ``self.model`` or raw primary keys), but the custom section reads
                ``obj.pk`` and the sort attribute from every entry, so in practice each
                entry has to be an object carrying both.

                Raises ``ValueError`` if the router disallows the relation and
                ``TypeError`` for a model instance that is not a ``self.model``.
                """
                # source_field_name: the PK fieldname in join table for the source object
                # target_field_name: the PK fieldname in join table for the target object
                # *objs - objects to add. Either object instances, or primary keys of object instances.

                # If there aren't any objects, there is nothing to do.
                from django.db.models import Model
                if objs:
                    new_ids = set()
                    for obj in objs:
                        if isinstance(obj, self.model):
                            if not router.allow_relation(obj, self.instance):
                                raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                                                   (obj, self.instance._state.db, obj._state.db))
                            new_ids.add(obj.pk)
                        elif isinstance(obj, Model):
                            raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
                        else:
                            new_ids.add(obj)
                    db = router.db_for_write(self.through, instance=self.instance)
                    vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
                    vals = vals.filter(**{
                        source_field_name: self._pk_val,
                        '%s__in' % target_field_name: new_ids,
                    })
                    # new_ids (ids without a join row yet) is only used as the
                    # pk_set reported by the signals below.
                    new_ids = new_ids - set(vals)

                    if self.reverse or source_field_name == self.source_field_name:
                        # Don't send the signal when we are inserting the
                        # duplicate data row for symmetrical reverse entries.
                        signals.m2m_changed.send(sender=self.through, action='pre_add',
                            instance=self.instance, reverse=self.reverse,
                            model=self.model, pk_set=new_ids, using=db)

                    ######################################################################
                    # This is where we modify the default logic for _add_items().
                    # We use get_or_create for ALL objects. Typically it calls bulk_create
                    # ONLY on ids which have not yet been created.
                    ######################################################################
                    # sort_field = self.field.sort_field
                    sort_field_attname = self.field.sort_field.attname
                    for obj in objs:
                        sort_position = getattr(obj, sort_field_attname)
                        new_obj, _created = self.through._default_manager.using(db).get_or_create(**{
                            sort_field_attname: sort_position,
                            '%s_id' % source_field_name: self._pk_val,
                            '%s_id' % target_field_name: obj.pk,
                        })
                        # Compare by value: an identity check reports equal
                        # positions held by distinct objects as different and
                        # triggers a needless save().
                        if getattr(new_obj, sort_field_attname) != sort_position:
                            setattr(new_obj, sort_field_attname, sort_position)
                            new_obj.save()
                    ######################################################################
                    # End custom logic
                    ######################################################################

                    if self.reverse or source_field_name == self.source_field_name:
                        # Don't send the signal when we are inserting the
                        # duplicate data row for symmetrical reverse entries.
                        signals.m2m_changed.send(sender=self.through, action='post_add',
                            instance=self.instance, reverse=self.reverse,
                            model=self.model, pk_set=new_ids, using=db)
Example #15
0
        def _add_items(self,
                       source_field_name,
                       target_field_name,
                       *objs,
                       through_defaults=None):
            """
            Append ``objs`` to the relation, preserving the order they were given in.

            ``source_field_name`` and ``target_field_name`` are the join-table
            field names for the source object and for the target objects.
            ``objs`` may mix ``self.model`` instances and raw primary keys.
            ``through_defaults`` supplies extra keyword arguments for every
            through instance that is created.

            Ids that already have a join row for this source, and repeated ids
            within ``objs``, are skipped. The remaining ids are bulk-created
            with sort values continuing after the current maximum for this
            source (or after 0 when there is none). ``m2m_changed`` is sent with
            ``pre_add`` before and ``post_add`` after the insert.

            Raises ``ValueError`` if the router disallows the relation or an
            instance's related value is ``None``; raises ``TypeError`` for a
            model instance that is not a ``self.model``.
            """
            # source_field_name: the PK fieldname in join table for the source object
            # target_field_name: the PK fieldname in join table for the target object
            # *objs - objects to add. Either object instances, or primary keys of object instances.
            through_defaults = through_defaults or {}

            # If there aren't any objects, there is nothing to do.
            from django.db.models import Max, Model
            if objs:
                # Django uses a set here, we need to use a list to keep the
                # correct ordering.
                new_ids = []
                for obj in objs:
                    if isinstance(obj, self.model):
                        if not router.allow_relation(obj, self.instance):
                            raise ValueError(
                                'Cannot add "%r": instance is on database "%s", value is on database "%s"'
                                %
                                (obj, self.instance._state.db, obj._state.db))
                        fk_val = self.through._meta.get_field(
                            target_field_name).get_foreign_related_value(
                                obj)[0]
                        if fk_val is None:
                            raise ValueError(
                                'Cannot add "%r": the value for field "%s" is None'
                                % (obj, target_field_name))
                        new_ids.append(fk_val)
                    elif isinstance(obj, Model):
                        raise TypeError("'%s' instance expected, got %r" %
                                        (self.model._meta.object_name, obj))
                    else:
                        new_ids.append(obj)

                db = router.db_for_write(self.through, instance=self.instance)
                manager = self.through._default_manager.using(db)
                vals = (manager.values_list(
                    target_field_name, flat=True).filter(
                        **{
                            source_field_name: self._fk_val,
                            '%s__in' % target_field_name: new_ids,
                        }))

                # Keep the first occurrence of every id that is not linked yet.
                # Filtering against a set of the existing ids drops *every*
                # occurrence of an already-linked id; list.remove() would only
                # drop the first one and let a repeated id be inserted again.
                existing_ids = set(vals)
                new_ids_set = set()
                ordered_ids = []
                for pk in new_ids:
                    if pk in existing_ids or pk in new_ids_set:
                        continue
                    new_ids_set.add(pk)
                    ordered_ids.append(pk)
                new_ids = ordered_ids

                # NOTE(review): ``rel`` is not defined in this method;
                # presumably it comes from an enclosing scope - confirm.
                if self.reverse or source_field_name == self.source_field_name:
                    # Don't send the signal when we are inserting the
                    # duplicate data row for symmetrical reverse entries.
                    signals.m2m_changed.send(sender=rel.through,
                                             action='pre_add',
                                             instance=self.instance,
                                             reverse=self.reverse,
                                             model=self.model,
                                             pk_set=new_ids_set,
                                             using=db)

                # Add the ones that aren't there already
                with atomic(using=db):
                    fk_val = self._fk_val
                    source_queryset = manager.filter(
                        **{'%s_id' % source_field_name: fk_val})
                    sort_field_name = self.through._sort_field_name
                    # Highest sort value currently stored for this source;
                    # falls back to 0 when the aggregate yields nothing.
                    sort_value_max = source_queryset.aggregate(
                        max=Max(sort_field_name))['max'] or 0

                    manager.bulk_create([
                        self.through(
                            **through_defaults, **{
                                '%s_id' % source_field_name: fk_val,
                                '%s_id' % target_field_name: pk,
                                sort_field_name: sort_value_max + i + 1,
                            }) for i, pk in enumerate(new_ids)
                    ])

                if self.reverse or source_field_name == self.source_field_name:
                    # Don't send the signal when we are inserting the
                    # duplicate data row for symmetrical reverse entries.
                    signals.m2m_changed.send(sender=rel.through,
                                             action='post_add',
                                             instance=self.instance,
                                             reverse=self.reverse,
                                             model=self.model,
                                             pk_set=new_ids_set,
                                             using=db)
Example #16
0
        def _add_items(self, source_field_name, target_field_name, *objs):
            """
            Add ``objs`` to the relation, preserving the order they were given in.

            ``source_field_name`` and ``target_field_name`` are the join-table
            field names for the source object and for the target objects.
            ``objs`` may mix ``self.model`` instances and raw primary keys.

            Ids that already have a join row for this source, and repeated ids
            within ``objs``, are skipped. On Django < 1.6 one through row is
            created per remaining id with the sort field's default; otherwise
            the rows are bulk-created inside a transaction with sort values
            ``default + index``. ``m2m_changed`` is sent with ``pre_add`` before
            and ``post_add`` after the insert.

            Raises ``ValueError`` if the router disallows the relation or an
            instance's foreign-key value is ``None``; raises ``TypeError`` for
            a model instance that is not a ``self.model``.
            """
            # source_field_name: the PK fieldname in join table for the source object
            # target_field_name: the PK fieldname in join table for the target object
            # *objs - objects to add. Either object instances, or primary keys of object instances.

            # If there aren't any objects, there is nothing to do.
            from django.db.models import Model
            if objs:
                # Django uses a set here, we need to use a list to keep the
                # correct ordering.
                new_ids = []
                for obj in objs:
                    if isinstance(obj, self.model):
                        if not router.allow_relation(obj, self.instance):
                            raise ValueError(
                                'Cannot add "%r": instance is on database "%s", value is on database "%s"'
                                %
                                (obj, self.instance._state.db, obj._state.db))
                        if hasattr(self, '_get_fk_val'):  # Django>=1.5
                            fk_val = self._get_fk_val(obj, target_field_name)
                            if fk_val is None:
                                raise ValueError(
                                    'Cannot add "%r": the value for field "%s" is None'
                                    % (obj, target_field_name))
                            new_ids.append(fk_val)
                        else:  # Django<1.5
                            new_ids.append(obj.pk)
                    elif isinstance(obj, Model):
                        raise TypeError("'%s' instance expected, got %r" %
                                        (self.model._meta.object_name, obj))
                    else:
                        new_ids.append(obj)

                db = router.db_for_write(self.through, instance=self.instance)
                vals = (self.through._default_manager.using(db).values_list(
                    target_field_name, flat=True).filter(
                        **{
                            source_field_name: self._fk_val,
                            '%s__in' % target_field_name: new_ids,
                        }))

                # Keep the first occurrence of every id that is not linked yet.
                # Filtering against a set of the existing ids drops *every*
                # occurrence of an already-linked id; list.remove() would only
                # drop the first one and let a repeated id be inserted again.
                existing_ids = set(vals)
                new_ids_set = set()
                ordered_ids = []
                for pk in new_ids:
                    if pk in existing_ids or pk in new_ids_set:
                        continue
                    new_ids_set.add(pk)
                    ordered_ids.append(pk)
                new_ids = ordered_ids

                # NOTE(review): ``rel`` is not defined in this method;
                # presumably it comes from an enclosing scope - confirm.
                if self.reverse or source_field_name == self.source_field_name:
                    # Don't send the signal when we are inserting the
                    # duplicate data row for symmetrical reverse entries.
                    signals.m2m_changed.send(sender=rel.through,
                                             action='pre_add',
                                             instance=self.instance,
                                             reverse=self.reverse,
                                             model=self.model,
                                             pk_set=new_ids_set,
                                             using=db)
                # Add the ones that aren't there already
                sort_field_name = self.through._sort_field_name
                sort_field = self.through._meta.get_field_by_name(
                    sort_field_name)[0]
                if django.VERSION < (1, 6):
                    for obj_id in new_ids:
                        # get_default() is deliberately called once per row
                        # rather than hoisted: it may return a different
                        # value for each new row.
                        self.through._default_manager.using(db).create(
                            **{
                                '%s_id' % source_field_name:
                                self._fk_val,  # Django 1.5 compatibility
                                '%s_id' % target_field_name: obj_id,
                                sort_field_name: sort_field.get_default(),
                            })
                else:
                    # The transaction has to be opened on the connection that
                    # is written to; a bare atomic() targets the default
                    # database alias instead of ``db``.
                    with transaction.atomic(using=db):
                        sort_field_default = sort_field.get_default()
                        self.through._default_manager.using(db).bulk_create([
                            self.through(
                                **{
                                    '%s_id' % source_field_name: self._fk_val,
                                    '%s_id' % target_field_name: v,
                                    sort_field_name: sort_field_default + i,
                                }) for i, v in enumerate(new_ids)
                        ])
                if self.reverse or source_field_name == self.source_field_name:
                    # Don't send the signal when we are inserting the
                    # duplicate data row for symmetrical reverse entries.
                    signals.m2m_changed.send(sender=rel.through,
                                             action='post_add',
                                             instance=self.instance,
                                             reverse=self.reverse,
                                             model=self.model,
                                             pk_set=new_ids_set,
                                             using=db)
        def _add_items(self, source_field_name, target_field_name, *objs):
            """
            Create the join rows that relate ``self.instance`` to ``objs``.

            ``source_field_name`` and ``target_field_name`` are the join-table
            field names for the source object and for the target objects.
            ``objs`` may mix ``self.model`` instances and raw key values.
            Targets that already have a row for this source are left alone.

            ``m2m_changed`` is sent with ``pre_add`` and ``post_add`` around the
            insert, except when the call only writes the duplicate rows of a
            symmetrical relation.

            Raises ``ValueError`` if the router disallows the relation or the
            related value of an instance is ``None``; raises ``TypeError`` for
            a model instance that is not a ``self.model``.
            """
            from django.db.models import Model

            # Nothing was passed in, so there is nothing to add.
            if not objs:
                return

            def _target_id(candidate):
                # Translate one entry of ``objs`` into the value stored in the
                # join table, validating model instances on the way.
                if isinstance(candidate, self.model):
                    if not router.allow_relation(candidate, self.instance):
                        raise ValueError(
                            'Cannot add "%r": instance is on database "%s", value is on database "%s"'
                            % (candidate, self.instance._state.db,
                               candidate._state.db))
                    join_field = self.through._meta.get_field(target_field_name)
                    value = join_field.get_foreign_related_value(candidate)[0]
                    if value is None:
                        raise ValueError(
                            'Cannot add "%r": the value for field "%s" is None'
                            % (candidate, target_field_name))
                    return value
                if isinstance(candidate, Model):
                    raise TypeError("'%s' instance expected, got %r" %
                                    (self.model._meta.object_name, candidate))
                # Not a model instance: treat it as a raw key.
                return candidate

            requested_ids = {_target_id(candidate) for candidate in objs}

            db = router.db_for_write(self.through, instance=self.instance)
            existing_lookup = {
                source_field_name: self.related_val[0],
                '%s__in' % target_field_name: requested_ids,
            }
            linked_ids = (
                self.through._default_manager.using(db)
                .values_list(target_field_name, flat=True)
                .filter(**existing_lookup)
            )
            # Only targets without a join row are inserted and reported.
            missing_ids = requested_ids.difference(linked_ids)

            def _notify(action):
                # Don't send the signal when we are inserting the
                # duplicate data row for symmetrical reverse entries.
                if self.reverse or source_field_name == self.source_field_name:
                    signals.m2m_changed.send(sender=self.through,
                                             action=action,
                                             instance=self.instance,
                                             reverse=self.reverse,
                                             model=self.model,
                                             pk_set=missing_ids,
                                             using=db)

            source_attname = '%s_id' % source_field_name
            target_attname = '%s_id' % target_field_name

            with transaction.atomic(using=db, savepoint=False):
                _notify('pre_add')

                # Add the ones that aren't there already
                writer = self.through._default_manager.using(db)
                writer.bulk_create([
                    self.through(**{
                        source_attname: self.related_val[0],
                        target_attname: target_id,
                    })
                    for target_id in missing_ids
                ])

                _notify('post_add')
            def _add_items(self, source_field_name, target_field_name, *objs):
                """
                By default, auto_created through objects from form instances are saved using
                Manager.bulk_create(). Manager.bulk_create() is passed a list containing
                instances of the through model with the target and source foreign keys defined.

                In order to set the position field we need to tweak this logic (the modified
                lines are marked out with comments below).

                This method is added to ManyRelatedManager below in
                SortableDescriptorMixin.related_manager_cls

                ``objs`` is validated like in the stock implementation (instances of
                ``self.model`` or raw primary keys), but the custom section reads
                ``obj.pk`` and the sort attribute from every entry, so in practice each
                entry has to be an object carrying both.

                Raises ``ValueError`` if the router disallows the relation or an
                instance's foreign-key value is ``None``; raises ``TypeError`` for a
                model instance that is not a ``self.model``.
                """
                # source_field_name: the PK fieldname in join table for the source object
                # target_field_name: the PK fieldname in join table for the target object
                # *objs - objects to add. Either object instances, or primary keys of object instances.

                # If there aren't any objects, there is nothing to do.
                from django.db.models import Model
                if objs:
                    new_ids = set()
                    for obj in objs:
                        if isinstance(obj, self.model):
                            if not router.allow_relation(obj, self.instance):
                                raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                                                   (obj, self.instance._state.db, obj._state.db))
                            # _get_fk_val wasn't introduced until django 1.4.2
                            if hasattr(self, '_get_fk_val'):
                                fk_val = self._get_fk_val(obj, target_field_name)
                            else:
                                fk_val = obj.pk
                            if fk_val is None:
                                raise ValueError('Cannot add "%r": the value for field "%s" is None' %
                                                 (obj, target_field_name))
                            new_ids.add(fk_val)
                        elif isinstance(obj, Model):
                            raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
                        else:
                            new_ids.add(obj)
                    db = router.db_for_write(self.through, instance=self.instance)

                    # The manager exposes the source value under different
                    # attribute names; use whichever one is present and fall
                    # back to the instance's pk. Resolved once, since it is
                    # needed for the lookup and for every row below.
                    source_val = getattr(self, '_pk_val', getattr(self, '_fk_val', self.instance.pk))

                    vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
                    vals = vals.filter(**{
                        source_field_name: source_val,
                        '%s__in' % target_field_name: new_ids,
                    })
                    # new_ids (ids without a join row yet) is only used as the
                    # pk_set reported by the signals below.
                    new_ids = new_ids - set(vals)

                    if self.reverse or source_field_name == self.source_field_name:
                        # Don't send the signal when we are inserting the
                        # duplicate data row for symmetrical reverse entries.
                        signals.m2m_changed.send(sender=self.through, action='pre_add',
                            instance=self.instance, reverse=self.reverse,
                            model=self.model, pk_set=new_ids, using=db)

                    ######################################################################
                    # This is where we modify the default logic for _add_items().
                    # We use get_or_create for ALL objects. Typically it calls bulk_create
                    # ONLY on ids which have not yet been created.
                    ######################################################################
                    # sort_field = self.field.sort_field
                    sort_field_attname = self.field.sort_field.attname
                    for obj in objs:
                        sort_position = getattr(obj, sort_field_attname)
                        new_obj, _created = self.through._default_manager.using(db).get_or_create(**{
                            sort_field_attname: sort_position,
                            '%s_id' % source_field_name: source_val,
                            '%s_id' % target_field_name: obj.pk,
                        })
                        # Compare by value: an identity check reports equal
                        # positions held by distinct objects as different and
                        # triggers a needless save().
                        if getattr(new_obj, sort_field_attname) != sort_position:
                            setattr(new_obj, sort_field_attname, sort_position)
                            new_obj.save()
                    ######################################################################
                    # End custom logic
                    ######################################################################

                    if self.reverse or source_field_name == self.source_field_name:
                        # Don't send the signal when we are inserting the
                        # duplicate data row for symmetrical reverse entries.
                        signals.m2m_changed.send(sender=self.through, action='post_add',
                            instance=self.instance, reverse=self.reverse,
                            model=self.model, pk_set=new_ids, using=db)
    def __set__(self, instance, value):
        """
        Assign ``value`` to ``instance`` through the forward relation.

        For example, when setting ``child.parent = parent``:

        - ``self`` is the descriptor managing the ``parent`` attribute
        - ``instance`` is the ``child`` instance
        - ``value`` is the ``parent`` instance on the right of the equal sign

        Raise ValueError if ``value`` is None and the field does not allow
        nulls, if ``value`` is not an instance of the related model, or if
        the database router refuses the relation.
        """
        field = self.field

        if value is None:
            # None may only be assigned when the field is nullable.
            if field.null is False:
                raise ValueError(
                    'Cannot assign None: "%s.%s" does not allow null values.' %
                    (instance._meta.object_name, field.name))

            # Clearing a OneToOneField: the previously related object must
            # forget about this instance, otherwise deleting that object
            # later would wrongly delete this one as well. Read the cache
            # attribute directly rather than going through the accessor; if
            # nothing was cached there is nothing to invalidate, and there
            # is no point loading the object just to expire it.
            previous = getattr(instance, self.cache_name, None)
            if previous is not None:
                setattr(previous, field.remote_field.get_cache_name(), None)

            # Null out every local column backing the relation.
            for local_field, _ in field.related_fields:
                setattr(instance, local_field.attname, None)
        else:
            # A non-None value has to be an instance of the related class.
            remote_model = field.remote_field.model
            if not isinstance(value, remote_model._meta.concrete_model):
                raise ValueError(
                    'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (
                        value,
                        instance._meta.object_name,
                        field.name,
                        remote_model._meta.object_name,
                    ))

            # Whichever side has no database yet gets one from the router
            # (instance first); when both already have one, the router is
            # asked whether the relation is allowed.
            if instance._state.db is None:
                instance._state.db = router.db_for_write(
                    instance.__class__, instance=value)
            elif value._state.db is None:
                value._state.db = router.db_for_write(
                    value.__class__, instance=instance)
            elif not router.allow_relation(value, instance):
                raise ValueError(
                    'Cannot assign "%r": the current database router prevents this relation.'
                    % value)

            # Copy the related object's key values onto the local columns.
            for local_field, remote_column in field.related_fields:
                setattr(instance, local_field.attname,
                        getattr(value, remote_column.attname))

        # Cache the related instance so that __get__ can return it without
        # issuing a SQL query.
        setattr(instance, self.cache_name, value)

        # For a one-to-one relation, also prime the reverse accessor cache on
        # the related object so reading it later costs no extra query.
        if value is not None and not field.remote_field.multiple:
            setattr(value, field.remote_field.get_cache_name(), instance)
    def __set__(self, instance, value):
        """
        Assign ``value`` to ``instance`` through the reverse relation.

        For example, when setting ``place.restaurant = restaurant``:

        - ``self`` is the descriptor managing the ``restaurant`` attribute
        - ``instance`` is the ``place`` instance
        - ``value`` is the ``restaurant`` instance on the right of the equal sign

        Keep in mind that ``Restaurant`` holds the foreign key to ``Place``.

        Raise ValueError if ``value`` is None and the underlying field does
        not allow nulls, if ``value`` is not an instance of the related
        model, or if the database router refuses the relation.
        """
        # This deliberately mirrors ForwardManyToOneDescriptor; the many small
        # differences would make a shared base class convoluted.
        related = self.related

        if value is None:
            # None may only be assigned when the underlying field is nullable.
            if not related.field.null:
                raise ValueError(
                    'Cannot assign None: "%s.%s" does not allow null values.' %
                    (
                        instance._meta.object_name,
                        related.get_accessor_name(),
                    ))

            # If a related instance is cached, drop it from the cache and
            # clear its forward reference; otherwise there is nothing to do.
            try:
                cached_obj = getattr(instance, self.cache_name)
            except AttributeError:
                return
            delattr(instance, self.cache_name)
            setattr(cached_obj, related.field.name, None)
            return

        # A non-None value has to be an instance of the related class.
        if not isinstance(value, related.related_model):
            raise ValueError(
                'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (
                    value,
                    instance._meta.object_name,
                    related.get_accessor_name(),
                    related.related_model._meta.object_name,
                ))

        # Whichever side has no database yet gets one from the router
        # (instance first); when both already have one, the router is asked
        # whether the relation is allowed.
        if instance._state.db is None:
            instance._state.db = router.db_for_write(
                instance.__class__, instance=value)
        elif value._state.db is None:
            value._state.db = router.db_for_write(
                value.__class__, instance=instance)
        elif not router.allow_relation(value, instance):
            raise ValueError(
                'Cannot assign "%r": the current database router prevents this relation.'
                % value)

        # Read all key values from the instance first, then write them onto
        # the related object's foreign key columns position by position.
        fk_field = related.field
        key_values = [
            getattr(instance, target.attname)
            for target in fk_field.foreign_related_fields
        ]
        for position, local_field in enumerate(fk_field.local_related_fields):
            setattr(value, local_field.attname, key_values[position])

        # Cache the related instance so that __get__ can return it without
        # issuing a SQL query.
        setattr(instance, self.cache_name, value)

        # Prime the forward accessor cache on the related object with the
        # current instance so reading it later costs no extra query.
        setattr(value, related.field.get_cache_name(), instance)