Ejemplo n.º 1
0
    def _prepare(self, model):
        if self.order_with_respect_to:
            # The app registry will not be ready at this point, so we cannot
            # use get_field().
            query = self.order_with_respect_to
            try:
                self.order_with_respect_to = next(
                    f for f in self._get_fields(reverse=False)
                    if f.name == query or f.attname == query)
            except StopIteration:
                raise FieldDoesNotExist("%s has no field named '%s'" %
                                        (self.object_name, query))

            self.ordering = ('_order', )
            if not any(
                    isinstance(field, OrderWrt)
                    for field in model._meta.local_fields):
                model.add_to_class('_order', OrderWrt())
        else:
            self.order_with_respect_to = None

        if self.pk is None:
            if self.parents:
                # Promote the first parent link in lieu of adding yet another
                # field.
                field = next(iter(self.parents.values()))
                # Look for a local field with the same name as the
                # first parent link. If a local field has already been
                # created, use it instead of promoting the parent
                already_created = [
                    fld for fld in self.local_fields if fld.name == field.name
                ]
                if already_created:
                    field = already_created[0]
                field.primary_key = True
                self.setup_pk(field)
                if not field.remote_field.parent_link:
                    raise ImproperlyConfigured(
                        'Add parent_link=True to %s.' % field, )
            else:
                auto = AutoField(verbose_name='ID',
                                 primary_key=True,
                                 auto_created=True)
                model.add_to_class('id', auto)
Ejemplo n.º 2
0
 def rename_field(self, app_label, model_name, old_name, new_name):
     """Rename field ``old_name`` to ``new_name`` on one model in this state.

     The field is popped from ``model_state.fields`` and re-inserted under
     ``new_name`` (for a dict this places it last). Every field's
     ``from_fields``, the model's ``index_together``/``unique_together``
     options, and the ``field_name``/``to_fields`` of references returned by
     ``get_references`` are rewritten to the new name. The model is reloaded
     with rendering delayed only when no such reference exists.

     Raises:
         FieldDoesNotExist: the model has no field named ``old_name``.
     """
     def renamed(name):
         # Map the old name to the new one, leaving every other name alone.
         return new_name if name == old_name else name

     model_key = (app_label, model_name)
     model_state = self.models[model_key]
     fields = model_state.fields
     try:
         moved = fields.pop(old_name)
     except KeyError:
         raise FieldDoesNotExist(
             f"{app_label}.{model_name} has no field named '{old_name}'")
     fields[new_name] = moved

     for candidate in fields.values():
         from_fields = getattr(candidate, 'from_fields', None)
         if from_fields:
             candidate.from_fields = tuple(map(renamed, from_fields))

     options = model_state.options
     for option in ('index_together', 'unique_together'):
         if option in options:
             options[option] = [
                 list(map(renamed, together)) for together in options[option]
             ]

     # Any reference to the renamed field means rendering cannot be delayed.
     delay = True
     for *_, referencing_field, reference in get_references(
             self, model_key, (old_name, moved)):
         delay = False
         if not reference.to:
             continue
         remote_field, to_fields = reference.to
         if getattr(remote_field, 'field_name', None) == old_name:
             remote_field.field_name = new_name
         if to_fields:
             referencing_field.to_fields = tuple(map(renamed, to_fields))
     self.reload_model(app_label, model_name, delay=delay)
Ejemplo n.º 3
0
 def get_field_by_name(self, name):
     """Look up a document field by name.

     Returns a 4-tuple ``(field_object, model, direct, m2m)``:

     * for a ``ReferenceField``: ``(field, field.document_type, False, False)``;
     * for any other field: ``(field, None, True, False)``.

     ``m2m`` is always ``False`` here.

     Raises:
         FieldDoesNotExist: ``self.document`` has no field called ``name``.
     """
     document_fields = self.document._fields
     if name not in document_fields:
         raise FieldDoesNotExist('%s has no field named %r' %
                                 (self.object_name, name))
     field = document_fields[name]
     if not isinstance(field, ReferenceField):
         return (field, None, True, False)
     return (field, field.document_type, False, False)
Ejemplo n.º 4
0
def _get_non_gfk_field(opts, name):
    """
    For historical reasons, the admin app relies on GenericForeignKeys as being
    "not found" by get_field(). This could likely be cleaned up.

    Reverse relations should also be excluded as these aren't attributes of the
    model (rather something like `foo_set`).
    """
    field = opts.get_field(name)
    if (field.is_relation and
            # Generic foreign keys OR reverse relations
            ((field.many_to_one and not field.related_model) or field.one_to_many)):
        raise FieldDoesNotExist()

    # Avoid coercing <FK>_id fields to FK
    if field.is_relation and hasattr(field, 'attname') and field.attname == name:
        raise FieldIsAForeignKeyColumnName()

    return field
    def validate_mapping(self):
        """Check that the user-supplied mapping is usable.

        Validates, in this order: every CSV header named as a value of
        ``self.mapping`` appears in ``self.headers``; every model field named
        as a key of ``self.mapping`` is found by ``self.get_field``; every key
        of ``self.static_mapping`` is found the same way.

        Returns ``None`` when everything checks out.

        Raises:
            ValueError: a mapped CSV header is missing, or a
                ``static_mapping`` field is not found on the model.
            FieldDoesNotExist: a ``mapping`` field is not found on the model.
        """
        for csv_header in self.mapping.values():
            if csv_header not in self.headers:
                raise ValueError("Header '{}' not found in CSV file".format(csv_header))

        for model_field in self.mapping:
            if not self.get_field(model_field):
                raise FieldDoesNotExist("Model does not include {} field".format(model_field))

        for model_field in self.static_mapping:
            if not self.get_field(model_field):
                raise ValueError("Model does not include {} field".format(model_field))
Ejemplo n.º 6
0
    def get_field_by_name(self, name):
        """Return the ``(field_object, model, direct, m2m)`` entry for ``name``.

        ``field_object`` is the Field for ``name``; ``model`` is the model
        containing it (``None`` for local fields); ``direct`` is True when the
        field exists on this model; ``m2m`` is True for many-to-many
        relations. When ``direct`` is False, ``field_object`` is the
        corresponding ForeignObjectRel, since such a field has no instance of
        its own.

        Entries come from ``self._name_map``; when that attribute is missing,
        the map returned by ``self.init_name_map()`` is used instead
        (presumably that call also caches it — not visible here).

        Raises:
            FieldDoesNotExist: ``name`` is not in the name map.
        """
        try:
            try:
                name_map = self._name_map
            except AttributeError:
                name_map = self.init_name_map()
            return name_map[name]
        except KeyError:
            raise FieldDoesNotExist('%s has no field named %r' %
                                    (self.object_name, name))
Ejemplo n.º 7
0
    def from_db(cls, attributes, objects, many=False):
        """Build model instances from raw LDAP search results.

        ``attributes`` lists the LDAP attribute names to load; each must be a
        key of ``cls._meta.attribute_to_field_name_map``, which names the
        model field it populates. ``objects`` is one raw result or a list of
        them; each result is indexed as ``obj[0]`` (stored as ``_dn``) and
        ``obj[1]`` (a dict of attribute values). Results whose ``obj[1]`` is
        not exactly a ``dict`` are skipped. Raw values are converted with the
        target field's ``from_db_value`` (found via ``cls._meta.fields_map``);
        attributes missing from a result are simply left out.

        Returns a list of instances when ``many`` is true, otherwise the
        first (and only) instance.

        Raises:
            RuntimeError: ``many`` is false but several objects were given.
            FieldDoesNotExist: an attribute has no corresponding model field.
        """
        batch = objects if isinstance(objects, list) else [objects]
        if not many and len(batch) > 1:
            raise RuntimeError('Called {}.from_db() with many=False but len(objects) > 1'.format(cls._meta.object_name))

        attr_to_field = cls._meta.attribute_to_field_name_map
        fields_by_name = cls._meta.fields_map
        for attr in attributes:
            if attr not in attr_to_field:
                raise FieldDoesNotExist(
                    'No field on model {} corresponding to LDAP attribute "{}"'.format(cls._meta.object_name, attr)
                )

        instances = []
        for raw in batch:
            payload = raw[1]
            if type(payload) is not dict:
                continue
            # LDAP attribute names are case-insensitive but dict keys are not:
            # remember how LDAP spelled each key, indexed by its lower-case form.
            spelled_as = {key.lower(): key for key in payload}
            init_kwargs = {'_dn': raw[0]}
            for attr in attributes:
                field_name = attr_to_field[attr]
                try:
                    raw_value = payload[spelled_as[attr.lower()]]
                except KeyError:
                    # Entries lacking this attribute omit it from the response.
                    continue
                init_kwargs[field_name] = fields_by_name[field_name].from_db_value(raw_value)
            instances.append(cls(**init_kwargs))
        return instances if many else instances[0]
Ejemplo n.º 8
0
 def state_forwards(self, app_label, state):
     """Apply the field rename to the in-memory migration ``state``.

     The first ``(name, field)`` pair named ``self.old_name`` in the model's
     field list becomes ``(self.new_name, field)``, and any
     ``index_together``/``unique_together`` entries are rewritten to the new
     name. The model is then reloaded, with rendering delayed only when the
     renamed field is not a relation.

     Raises:
         FieldDoesNotExist: the model has no field named ``self.old_name``.
     """
     model_state = state.models[app_label, self.model_name_lower]
     fields = model_state.fields
     position = next(
         (i for i, (field_name, _) in enumerate(fields)
          if field_name == self.old_name),
         None,
     )
     if position is None:
         raise FieldDoesNotExist(
             "%s.%s has no field named '%s'" %
             (app_label, self.model_name, self.old_name))
     renamed_field = fields[position][1]
     fields[position] = (self.new_name, renamed_field)
     delay = not renamed_field.is_relation

     options = model_state.options
     for option in ('index_together', 'unique_together'):
         if option not in options:
             continue
         options[option] = [
             [self.new_name if member == self.old_name else member
              for member in group]
             for group in options[option]
         ]
     state.reload_model(app_label, self.model_name_lower, delay=delay)
Ejemplo n.º 9
0
    def _get_challenge(cls, relations):
        """Return the challenge model related to this device.

        The relation is located through the mixin class rather than the
        Challenge model to prevent circular imports.

        :param relations: The relations to search.
        :type relations: tuple
        :return: The single related model that subclasses ``ChallengeMixin``.
        :rtype: type of rest_multi_factor.models.mixins.ChallengeMixin
        :raises FieldDoesNotExist: when no relation points at a challenge.
        :raises FieldError: when more than one relation points at a challenge.
        """
        candidates = [
            relation.related_model
            for relation in relations
            if issubclass(relation.related_model, ChallengeMixin)
        ]

        if not candidates:  # pragma: no cover
            raise FieldDoesNotExist("No reverse relation found to a challenge")

        if len(candidates) > 1:  # pragma: no cover
            raise FieldError("Multiple relations to challenges found")

        (challenge,) = candidates
        return challenge
Ejemplo n.º 10
0
    def test_validate_model_field(self, mock_attr, mock_type):
        """Exercise ``validate_model_field`` of ``DjangoRestAdapter``.

        NOTE(review): ``mock_attr`` / ``mock_type`` are presumably patches of
        ``_validate_model_attribute`` / ``_validate_model_type`` injected by
        decorators outside this view — confirm.

        A: ``extract_model`` returns ``None`` -> ``DRFAdapterException`` and
           neither validator is called.
        B: ``_validate_model_type`` passes -> its result is returned with
           ``automated=True``.
        C: ``get_field`` raises ``FieldDoesNotExist`` -> falls back to
           ``_validate_model_attribute`` with ``automated=False``, looking the
           field up by ``source``.
        """
        mock_adapter = create_mock_object(
            DjangoRestAdapter, ['validate_model_field', 'ADAPTER_CONF'])
        field_type = 'foo'
        mock_adapter.extract_model.return_value = None
        mock_meta = mock.Mock()
        mock_model = mock.Mock(_meta=mock_meta)

        # Case A: Extracted model is `None`.
        self.assertRaises(utils.DRFAdapterException,
                          mock_adapter.validate_model_field, mock_adapter,
                          None, 'foo', self.loc, field_type, None)
        # These assertions must be *called*: merely accessing
        # ``mock.assert_not_called`` checks nothing.
        mock_attr.assert_not_called()
        mock_type.assert_not_called()

        # Case B: `_validate_model_type` pass
        mock_type.return_value = '_validate_model_type_called'
        mock_adapter.extract_model.return_value = mock_model
        field, model, automated = mock_adapter.validate_model_field(
            mock_adapter, None, 'foo', self.loc, field_type, None)
        self.assertEqual(field, '_validate_model_type_called')
        self.assertEqual(model, mock_model)
        self.assertTrue(automated)
        mock_type.assert_called_once_with('foo', mock.ANY, field_type)
        mock_attr.assert_not_called()

        # Case C: `_validate_model_type` fails and `_validate_model_attribute`
        # is called.
        mock_meta.get_field.side_effect = FieldDoesNotExist()
        mock_attr.return_value = '_validate_model_attribute_called'
        field, model, automated = mock_adapter.validate_model_field(
            mock_adapter, None, 'foo', self.loc, field_type, source='source')
        self.assertEqual(field, '_validate_model_attribute_called')
        self.assertEqual(model, mock_model)
        self.assertFalse(automated)
        mock_attr.assert_called_once_with('foo', mock_model, 'source')
        mock_meta.get_field.assert_called_with('source')
Ejemplo n.º 11
0
def get_form_for_models(
    *args,
    fields=(),
    translations=None,
    required=(),
    skip_missing=False,
    formfieldkwargs=None,
):
    """Return a form class built from fields of one or more models.

    For each name in ``fields``, the name is first mapped through
    ``translations`` (defaulting to itself). The models in ``args`` are then
    searched in order, and the first model that has a field with that name
    supplies the form field. The form field is created with
    ``formfieldkwargs[translated_name]`` as keyword arguments, if given. A
    one-to-one model field whose ``formfield()`` is ``None`` gets a
    ``ModelChoiceField`` instead. Names listed in ``required`` get
    ``required=True``.

    :param args: model classes to search, in priority order.
    :param fields: form field names to include, in order.
    :param translations: optional mapping from form field name to model field name.
    :param required: form field names to force as required.
    :param skip_missing: silently omit fields no model provides instead of raising.
    :param formfieldkwargs: optional mapping from model field name to kwargs
        for ``formfield()`` / ``ModelChoiceField``.
    :return: a new ``forms.Form`` subclass named ``ModelMuxForm``.
    :raises FieldDoesNotExist: a field was not found (or produced no form
        field) and ``skip_missing`` is false.
    """
    # None defaults avoid sharing mutable default objects between calls.
    translations = {} if translations is None else translations
    formfieldkwargs = {} if formfieldkwargs is None else formfieldkwargs
    required = set(required)
    field_dict = {}
    for f in fields:
        translated_field = translations.get(f, f)
        k = formfieldkwargs.get(translated_field, {})
        # Initialised before the model loop so an empty ``args`` cannot leave
        # it unbound.
        formfield = None
        for m in args:
            try:
                modelfield = m._meta.get_field(translated_field)
            except FieldDoesNotExist:
                # Not on this model; try the next one.
                continue
            formfield = modelfield.formfield(**k)
            if formfield is None and modelfield.one_to_one:
                if "queryset" in k:
                    formfield = forms.ModelChoiceField(**k)
                else:
                    formfield = forms.ModelChoiceField(
                        queryset=modelfield.target_field.model.objects.all(), **k
                    )
            if formfield is not None:
                apply_limit_choices_to_to_formfield(formfield)
                if f in required:
                    formfield.required = True
            break
        if formfield is not None:
            field_dict[f] = formfield
        elif not skip_missing:
            raise FieldDoesNotExist("The field '%s' was not found" % translated_field)
    return type("ModelMuxForm", (forms.Form,), field_dict)
Ejemplo n.º 12
0
def get_verbose_name(an_object, field_name, title_cap=True):
    """Return the verbose_name of a model's field.

    ``an_object`` may be a model class or a model instance; the field is
    looked up with ``an_object._meta.get_field(field_name)``.

    When ``title_cap`` is true (the default), each whitespace-separated word
    is capitalized, which also lower-cases the rest of the word. The words
    are then re-joined with single spaces, which reads better in labels.

    Raises:
        FieldDoesNotExist: ``field_name`` is not a field of the model
            (propagated from ``get_field``), or is unhashable.
    """
    try:
        field = an_object._meta.get_field(field_name)
    except TypeError:
        # An unhashable field_name such as {} or [] makes get_field raise
        # TypeError; report it as a missing field so callers handle a single
        # exception type.
        raise FieldDoesNotExist("No field named {}".format(str(field_name)))

    label = field.verbose_name
    if not title_cap:
        return label

    return " ".join([word.capitalize() for word in label.split()])
Ejemplo n.º 13
0
    def __new__(cls, name, bases, attrs):
        """Create a resource class and add resource fields derived from ``Meta.model``.

        Defaults ``_meta.instance_loader_class`` to ``ModelInstanceLoader``.
        When ``_meta.model`` is set:

        * every field and many-to-many field of the model (iterated in
          ``sorted`` order) becomes a writable resource field. Fields are
          skipped when ``_meta.fields`` / ``_meta.exclude`` rule them out or
          when they are already declared.
        * every ``_meta.fields`` entry containing ``__`` that is not already
          declared is followed across relations. It becomes a read-only
          resource field for the final model field.

        Raises:
            FieldDoesNotExist: a ``__`` path names a field that does not exist.
            KeyError: an intermediate ``__`` path segment is not a relation.
        """
        new_class = super().__new__(cls, name, bases, attrs)
        opts = new_class._meta

        if not opts.instance_loader_class:
            opts.instance_loader_class = ModelInstanceLoader

        if not opts.model:
            return new_class

        model_opts = opts.model._meta
        # Same dict object as ``new_class.fields``: fields added by the first
        # pass below also count as declared for the second pass.
        declared_fields = new_class.fields

        def include(model_field):
            # True when a direct model field should get an auto-generated
            # resource field.
            if opts.fields is not None and model_field.name not in opts.fields:
                return False
            if opts.exclude and model_field.name in opts.exclude:
                return False
            return model_field.name not in declared_fields

        direct_fields = [
            (model_field.name,
             new_class.field_from_django_field(model_field.name, model_field,
                                               readonly=False))
            for model_field in sorted(model_opts.fields + model_opts.many_to_many)
            if include(model_field)
        ]
        new_class.fields.update(OrderedDict(direct_fields))

        if opts.fields is None:
            return new_class

        def resolve(path):
            # Walk an ``a__b__c`` path starting at ``opts.model`` and return
            # the Django field it finally points at.
            model = opts.model
            segments = path.split('__')
            for depth, segment in enumerate(segments):
                verbose_path = ".".join([opts.model.__name__] + segments[:depth + 1])
                try:
                    target = model._meta.get_field(segment)
                except FieldDoesNotExist as e:
                    logger.debug(e, exc_info=e)
                    raise FieldDoesNotExist(
                        "%s: %s has no field named '%s'" %
                        (verbose_path, model.__name__, segment))
                if depth < len(segments) - 1:
                    # Not at the last segment yet: step into the related
                    # model. Forward fields must be relations; reverse
                    # relations (ForeignObjectRel) are followed as-is.
                    related_model = get_related_model(target)
                    if related_model is None and not isinstance(target, ForeignObjectRel):
                        raise KeyError('%s is not a relation' % verbose_path)
                    model = related_model
            if isinstance(target, ForeignObjectRel):
                target = target.field
            return target

        followed_fields = [
            (path, new_class.field_from_django_field(path, resolve(path),
                                                     readonly=True))
            for path in opts.fields
            if path not in declared_fields and '__' in path
        ]
        new_class.fields.update(OrderedDict(followed_fields))
        return new_class
Ejemplo n.º 14
0
    def _parse(self):

        if self.group_by:
            self.group_by_field = [
                x for x in self.report_model._meta.fields
                if x.name == self.group_by
            ][0]
            self.group_by_model = self.group_by_field.related_model

        self.parsed_columns = []
        for col in self.columns:
            # import pdb; pdb.set_trace()
            attr = getattr(self, col, None)
            if attr:
                col_data = {
                    'name': col,
                    'verbose_name': getattr(attr, 'verbose_name', col),
                    # 'type': 'method',
                    'ref': attr,
                    'type': 'text'
                }
            elif col.startswith('__'):
                # a magic field
                if col in ['__time_series__', '__crosstab__']:
                    #     These are placeholder not real computation field
                    continue

                magic_field_class = field_registry.get_field_by_name(col)
                col_data = {
                    'name': col,
                    'verbose_name': magic_field_class.verbose_name,
                    'source': 'magic_field',
                    'ref': magic_field_class,
                    'type': magic_field_class.type
                }
            else:
                # A database field
                model_to_use = self.group_by_model if self.group_by else self.report_model
                try:
                    if '__' in col:
                        # A traversing link order__client__email
                        field = get_field_from_query_text(col, model_to_use)
                    else:
                        field = model_to_use._meta.get_field(col)
                except FieldDoesNotExist:
                    raise FieldDoesNotExist(
                        f'Field "{col}" not found as an attribute to the generator class, nor as computation field, nor as a database column for the model "{model_to_use._meta.model_name}"'
                    )

                col_data = {
                    'name': col,
                    'verbose_name': field.verbose_name,
                    'source': 'database',
                    'ref': field,
                    'type': field.get_internal_type()
                }
            self.parsed_columns.append(col_data)

            self._parsed_columns = list(self.parsed_columns)
            self._time_series_parsed_columns = self.get_time_series_parsed_columns(
            )
            self._crosstab_parsed_columns = self.get_crosstab_parsed_columns()
Ejemplo n.º 15
0
    def _sanitize_discriminator(cls, name,
                                attrs) -> Union[Discriminator, None]:
        """Validate and normalise the ``discriminator`` declared in ``attrs``.

        Returns ``None`` when ``attrs["discriminator"]`` is ``None``.
        Otherwise, with ``model = attrs["Meta"].model``:

        * the discriminator field must exist on ``model``;
        * mapping values given as a tuple/list of field names are replaced by
          an instance of an auto-generated ``ModelSerializer`` over those
          fields;
        * when ``discriminator.group_field`` is set, each non-``None``
          serializer is wrapped in a ``ModelSerializer`` that nests it under
          that field name;
        * if the discriminator field has ``choices``, a warning is logged for
          choice values missing from the mapping.

        :param name: name of the class being created (used in the warning).
        :param attrs: class attributes; must contain ``"discriminator"`` and
            ``"Meta"``.
        :return: the (mutated) discriminator, or ``None``.
        :raises FieldDoesNotExist: the discriminator field is not on the model.
        """
        discriminator = attrs["discriminator"]
        if discriminator is None:
            return None

        model = attrs["Meta"].model

        try:
            field = model._meta.get_field(discriminator.discriminator_field)
        except FieldDoesNotExist as exc:
            raise FieldDoesNotExist(
                f"The discriminator field '{discriminator.discriminator_field}' "
                f"does not exist on the model '{model._meta.label}'") from exc

        values_seen = set()

        for value, fields in discriminator.mapping.items():
            # construct a serializer instance if a tuple/list of fields is passed
            if isinstance(fields, (tuple, list)):
                # Separate variable: reusing ``name`` here would clobber the
                # class name reported by the warning below.
                serializer_name = f"{value}{model._meta.object_name}Serializer"

                Meta = type("Meta", (), {
                    "model": model,
                    "fields": tuple(fields)
                })

                serializer_class = type(serializer_name,
                                        (serializers.ModelSerializer, ),
                                        {"Meta": Meta})

                discriminator.mapping[value] = serializer_class()

            values_seen.add(value)

            serializer = discriminator.mapping[value]

            if serializer is None:
                continue

            # rewrite it to nested serializer
            if discriminator.group_field:
                group_name = (
                    f"{discriminator.group_field}_{serializer.__class__.__name__}"
                )
                group_meta = type("Meta", (), {
                    "model": model,
                    "fields": (discriminator.group_field, )
                })

                # Find the source field for the nested serializer; if several
                # related fields match, the last one wins.
                source = None
                related_fields = model._meta.fields_map
                for field_name, field_type in related_fields.items():
                    if field_type.related_model == serializer.Meta.model:
                        source = field_name

                group_field = serializer.__class__(
                    source=source,
                    required=False,
                    label=discriminator.group_field)

                group_serializer_class = type(
                    group_name,
                    (serializers.ModelSerializer, ),
                    {
                        "Meta": group_meta,
                        discriminator.group_field: group_field
                    },
                )
                discriminator.mapping[value] = group_serializer_class()

        if field.choices:
            values = {choice[0] for choice in field.choices}
            difference = values - values_seen
            if difference:
                # ``Logger.warn`` is a deprecated alias of ``warning``.
                logger.warning(
                    "'%s': not all possible values map to a serializer. Missing %s",
                    name,
                    difference,
                )

        return discriminator
Ejemplo n.º 16
0
    def get_field(self, field: str):
        """Return ``field`` unchanged if it is a supported field name.

        Only ``"key"`` and ``"name"`` are supported.

        Raises:
            django.core.exceptions.FieldDoesNotExist: for any other name.
        """
        from django.core.exceptions import FieldDoesNotExist

        known_fields = {"key", "name"}
        if field in known_fields:
            return field
        raise FieldDoesNotExist("not exist")
Ejemplo n.º 17
0
    def check_columns(
        cls,
        columns,
        group_by,
        report_model,
    ):
        """Parse ``columns`` into column descriptors, raising on unknown items.

        Each entry of ``columns`` is classified as, in priority order:

        * an attribute of ``cls`` (string entries only) -> ``'type': 'text'``;
        * a computation ("magic") field: a ``SlickReportField`` subclass given
          directly, or an entry ``field_registry.get_field_by_name`` resolves;
        * a database field of ``report_model``, or of the group-by model when
          ``group_by`` contains no ``__``. ``a__b`` paths are followed.

        The ``'__time_series__'`` and ``'__crosstab__'`` placeholders are
        skipped.

        :param columns: list of column specs (strings or field classes).
        :param group_by: optional group-by lookup; its first ``__`` segment
            must name a field of ``report_model``.
        :param report_model: the report model.
        :return: list of dicts, one per non-placeholder column, describing it.
        :raises IndexError: ``group_by``'s first segment matches no field.
        :raises FieldDoesNotExist: a column cannot be identified at all.
        """
        group_by_model = None
        if group_by:
            root_name = group_by.split('__')[0]
            matching = [
                candidate for candidate in report_model._meta.get_fields()
                if candidate.name == root_name
            ]
            group_by_field = matching[0]
            group_by_model = (
                group_by_field.related_model if group_by_field.is_relation
                else report_model
            )

        placeholders = ('__time_series__', '__crosstab__')
        parsed_columns = []
        for col in columns:
            if col in placeholders:
                continue

            is_name = type(col) is str
            attr = getattr(cls, col, None) if is_name else None
            magic_field_class = None
            if not is_name and issubclass(col, SlickReportField):
                magic_field_class = col
            if not magic_field_class:
                try:
                    magic_field_class = field_registry.get_field_by_name(col)
                except KeyError:
                    magic_field_class = None

            if attr:
                # todo Add testing here
                col_data = {
                    'name': col,
                    'verbose_name': getattr(attr, 'verbose_name', col),
                    'ref': attr,
                    'type': 'text'
                }
            elif magic_field_class:
                col_data = {
                    'name': magic_field_class.name,
                    'verbose_name': magic_field_class.verbose_name,
                    'source': 'magic_field',
                    'ref': magic_field_class,
                    'type': magic_field_class.type,
                    'is_summable': magic_field_class.is_summable
                }
            else:
                # A database column: nested group_by lookups fall back to the
                # report model itself.
                model_to_use = (
                    report_model if not group_by or '__' in group_by
                    else group_by_model
                )
                try:
                    if '__' in col:
                        # A traversing link such as order__client__email
                        field = get_field_from_query_text(col, model_to_use)
                    else:
                        field = model_to_use._meta.get_field(col)
                except FieldDoesNotExist:
                    raise FieldDoesNotExist(
                        f'Field "{col}" not found either as an attribute to the generator class {cls}, '
                        f'or a computation field, or a database column for the model "{model_to_use}"'
                    )

                col_data = {
                    'name': col,
                    'verbose_name': getattr(field, 'verbose_name', col),
                    'source': 'database',
                    'ref': field,
                    'type': field.get_internal_type()
                }
            parsed_columns.append(col_data)
        return parsed_columns
Ejemplo n.º 18
0
 def state_forwards(self, app_label, state):
     """Apply the field rename to the in-memory migration ``state``.

     * The first field named ``self.old_name`` is renamed to ``self.new_name``.
     * Every field of the model has ``from_fields`` rewritten to the new name.
     * ``index_together``/``unique_together`` entries are rewritten.
     * In every model of ``state``, relations whose target resolves to this
       model get ``remote_field.field_name`` and ``to_fields`` rewritten.

     Rendering is delayed only if no field of the model is a relation or is
     referenced by a foreign key (per ``is_referenced_by_foreign_key``).

     Raises:
         FieldDoesNotExist: the model has no field named ``self.old_name``.
     """
     old_name, new_name = self.old_name, self.new_name

     def renamed(field_name):
         # Map the old name to the new one, leaving every other name alone.
         return new_name if field_name == old_name else field_name

     model_state = state.models[app_label, self.model_name_lower]
     fields = model_state.fields
     found = False
     delay = True
     for index, (name, field) in enumerate(fields):
         if not found and name == old_name:
             fields[index] = (new_name, field)
             found = True
         # Fix from_fields to refer to the new field.
         from_fields = getattr(field, "from_fields", None)
         if from_fields:
             field.from_fields = tuple(renamed(n) for n in from_fields)
         # Delay rendering only while every field seen so far is neither
         # relational nor referenced by a foreign key.
         if delay:
             delay = not field.is_relation and not is_referenced_by_foreign_key(
                 state, self.model_name_lower, field, self.name
             )
     if not found:
         raise FieldDoesNotExist(
             "%s.%s has no field named '%s'"
             % (app_label, self.model_name, old_name)
         )

     options = model_state.options
     for option in ("index_together", "unique_together"):
         if option in options:
             options[option] = [
                 [renamed(n) for n in together] for together in options[option]
             ]

     # Fix field_name / to_fields of relations that point at this model.
     model_tuple = app_label, self.model_name_lower
     for (other_app_label, other_model_name), other_state in state.models.items():
         for _, other_field in other_state.fields:
             remote_field = other_field.remote_field
             if not remote_field:
                 continue
             target_tuple = self._get_model_tuple(
                 remote_field.model, other_app_label, other_model_name
             )
             if target_tuple != model_tuple:
                 continue
             if getattr(remote_field, "field_name", None) == old_name:
                 remote_field.field_name = new_name
             to_fields = getattr(other_field, "to_fields", None)
             if to_fields:
                 other_field.to_fields = tuple(renamed(n) for n in to_fields)
     state.reload_model(app_label, self.model_name_lower, delay=delay)
Ejemplo n.º 19
0
    def vector_tile_geom_name(self) -> str:
        """Return the name of the first ``GeometryField`` on ``self.model``.

        Fields are scanned in ``self.model._meta.get_fields()`` order.

        Raises:
            FieldDoesNotExist: the model has no ``GeometryField``.
        """
        geometry_names = (
            model_field.name
            for model_field in self.model._meta.get_fields()
            if isinstance(model_field, GeometryField)
        )
        for geom_name in geometry_names:
            return geom_name
        raise FieldDoesNotExist()
Ejemplo n.º 20
0
def bulk_sync(
    new_models,
    key_fields,
    filters,
    batch_size=None,
    fields=None,
    exclude_fields=None,
    skip_creates=False,
    skip_updates=False,
    skip_deletes=False,
    db_class=None,
):
    """ Combine bulk create, update, and delete.  Make the DB match a set of in-memory objects.

    `new_models`: Django ORM objects that are the desired state.  They may or may not have `id` set.
    `key_fields`: Identifying attribute name(s) to match up `new_models` items with database rows.  If a foreign key
            is being used as a key field, be sure to pass the `fieldname_id` rather than the `fieldname`.
    `filters`: Q() filters specifying the subset of the database to work in. Use `None` or `[]` if you want to sync
            against the entire table.
    `batch_size`: (optional) passes through to Django `bulk_create.batch_size` and `bulk_update.batch_size`, and controls
            how many objects are created/updated per SQL query.
    `fields`: (optional) list of fields to update. If not set, will sync all fields that are editable and not
            auto-created (the primary key is never included).
    `exclude_fields`: (optional) list of fields to exclude from updates. Subtracts from the passed-in `fields` or
            default-calculated `fields` (see `fields` documentation above.)
    `skip_creates`: If truthy, will not perform any object creations needed to fully sync. Defaults to not skip.
    `skip_updates`: If truthy, will not perform any object updates needed to fully sync. Defaults to not skip.
    `skip_deletes`: If truthy, will not perform any object deletions needed to fully sync. Defaults to not skip.
    `db_class`: (optional) Model class to operate on. If new_models always contains at least one object, this can
            be set automatically so is optional.

    Returns `{"stats": {"created": int, "updated": int, "deleted": int}}`; a skipped operation reports 0.

    Raises `RuntimeError` if the model class cannot be determined, and `FieldDoesNotExist` if `exclude_fields`
    names a field the model does not have.
    """

    if db_class is None:
        try:
            db_class = new_models[0].__class__
        except IndexError:
            # Empty input: a queryset such as `Model.objects.none()` still knows its model.
            db_class = getattr(new_models, "model", None)

    if db_class is None:
        raise RuntimeError(
            "Unable to identify model to sync. Need to provide at least one object in `new_models`, provide "
            "`db_class`, or set `new_models` with a queryset like `db_class.objects.none()`."
        )

    if fields is None:
        # Default to the fields a user could edit: skip the PK, auto-created fields and
        # non-editable ones (e.g. `auto_now_add` timestamps) when calling bulk_update.
        fields = [
            field.name
            for field in db_class._meta.fields
            if not field.primary_key and not field.auto_created and field.editable
        ]

    if exclude_fields is not None:
        model_fields = {field.name for field in db_class._meta.fields}
        fields_to_exclude = set(exclude_fields)

        # Check that we're not attempting to exclude non-existent fields
        if not fields_to_exclude <= model_fields:
            raise FieldDoesNotExist(
                f'model "{db_class.__name__}" has no field(s) {fields_to_exclude - model_fields}'
            )

        # Keep the caller's field order (deduplicated) instead of the arbitrary order a set
        # difference would produce.
        fields = [name for name in dict.fromkeys(fields) if name not in fields_to_exclude]

    using = router.db_for_write(db_class)
    with transaction.atomic(using=using):
        objs = db_class.objects.all()
        if filters:
            objs = objs.filter(filters)
        objs = objs.only("pk", *key_fields).select_for_update()

        # In-memory key values may not be in their canonical Python type (e.g. "5" vs 5), so
        # they are normalised with the field's `to_python` before matching DB rows.
        # Key by the name the caller passed: `get_field("author_id")` returns the FK whose
        # `.name` is "author", and keying by `field.name` would make get_key() miss it.
        prep_functions = {}
        for key_field in key_fields:
            field = db_class._meta.get_field(key_field)
            if hasattr(field, "to_python"):
                prep_functions[key_field] = field.to_python

        def get_key(obj, prep_values=False):
            """Return the tuple of `key_fields` values identifying `obj`."""
            key = []
            for k in key_fields:
                value = getattr(obj, k)
                prep = prep_functions.get(k) if prep_values else None
                key.append(prep(value) if prep is not None else value)
            return tuple(key)

        obj_dict = {get_key(obj): obj for obj in objs}

        new_objs = []
        existing_objs = []
        for new_obj in new_models:
            # Popping leaves only DB rows with no in-memory counterpart in obj_dict; those
            # are the stale rows deleted below.
            old_obj = obj_dict.pop(get_key(new_obj, prep_values=True), None)
            if old_obj is None:
                # This is a new object, so create it.
                new_objs.append(new_obj)
            else:
                new_obj.pk = old_obj.pk
                existing_objs.append(new_obj)

        if not skip_creates:
            db_class.objects.bulk_create(new_objs, batch_size=batch_size)

        if not skip_updates:
            db_class.objects.bulk_update(existing_objs, fields=fields, batch_size=batch_size)

        if not skip_deletes:
            # delete stale objects
            objs.filter(pk__in=[obj.pk for obj in obj_dict.values()]).delete()

        stats = {
            "created": 0 if skip_creates else len(new_objs),
            "updated": 0 if skip_updates else len(existing_objs),
            "deleted": 0 if skip_deletes else len(obj_dict),
        }

        logger.debug(
            "%s: %s created, %s updated, %s deleted.",
            db_class.__name__,
            stats["created"],
            stats["updated"],
            stats["deleted"],
        )

    return {"stats": stats}
Ejemplo n.º 21
0
        def calc_field_names(rel):
            """Work out which through-model fields play which GM2M role.

            Builds a dict mapping the roles ``'src'``, ``'tgt'``, ``'tgt_ct'``
            and ``'tgt_fk'`` to field names on ``rel.through`` and stores it
            on ``rel.through._meta._field_names``. Except on the early-return
            path for fake through models, the four names are also recorded, in
            that role order, via ``rel.set_init('through_fields', ...)``.

            Raises:
                FieldDoesNotExist: ``rel.through_fields`` names a target that
                    is not a private field of the through model.
                ValueError: the source or target role could not be resolved.
            """
            # Extract field names from through model and stores them in
            # rel.through_fields (so that they are sent on deconstruct and
            # passed to ModelState instances)

            tf_dict = {}

            # NOTE(review): is_fake_model() is defined elsewhere; presumably it
            # detects a through model rebuilt from migration state - confirm.
            if is_fake_model(rel.through):
                # we populate the through field dict using rel.through_fields
                # that was either provided or computed beforehand with the
                # actual model
                for f, k in zip(rel.through_fields,
                                ('src', 'tgt', 'tgt_ct', 'tgt_fk')):
                    tf_dict[k] = f
                rel.through._meta._field_names = tf_dict
                # Early return: names are taken as-is, without the src/tgt
                # check below and without calling rel.set_init().
                return

            if rel.through_fields:
                # Explicit through_fields: the first two entries name the
                # source field and the target (a private field on the
                # through model, looked up by name below).
                tf_dict['src'], tf_dict['tgt'] = \
                    rel.through_fields[:2]
                for gfk in rel.through._meta.private_fields:
                    if gfk.name == tf_dict['tgt']:
                        break
                else:
                    # for/else: no private field matched the target name.
                    raise FieldDoesNotExist(
                        'Generic foreign key "%s" does not exist in through '
                        'model "%s"' %
                        (tf_dict['tgt'], rel.through._meta.model_name))
                # The matched private field supplies the content-type and
                # object-id field names.
                tf_dict['tgt_ct'] = gfk.ct_field
                tf_dict['tgt_fk'] = gfk.fk_field
            else:
                # No explicit through_fields: source is the first concrete
                # field pointing back at rel.field.model.
                for f in rel.through._meta.fields:
                    try:
                        remote_field = f.remote_field
                    except AttributeError:
                        # Field type without a remote_field attribute; skip.
                        continue
                    # remote_field.model may be the model class or a string
                    # "app_label.ObjectName" (presumably a not-yet-resolved
                    # lazy reference), so both forms are compared.
                    if remote_field and (remote_field.model == rel.field.model
                                         or remote_field.model == '%s.%s' %
                                         (rel.field.model._meta.app_label,
                                          rel.field.model._meta.object_name)):
                        tf_dict['src'] = f.name
                        break
                # Target is the first private field that is a generic
                # foreign key (ct is presumably
                # django.contrib.contenttypes.fields - confirm at import).
                for f in rel.through._meta.private_fields:
                    if isinstance(f, ct.GenericForeignKey):
                        tf_dict['tgt'] = f.name
                        tf_dict['tgt_ct'] = f.ct_field
                        tf_dict['tgt_fk'] = f.fk_field
                        break

            if not set(tf_dict.keys()).issuperset(('src', 'tgt')):
                raise ValueError('Bad through model for GM2M relationship.')

            rel.through._meta._field_names = tf_dict

            # save the result in rel.through_fields so that it appears
            # in the deconstruction. Without that there would be no way for
            # a ModelState constructed from a migration to know which fields
            # have which function, as all virtual fields are stripped
            # Indexing all four keys is safe here: every path above that sets
            # 'tgt' also sets 'tgt_ct' and 'tgt_fk'.
            tf = []
            for f in ('src', 'tgt', 'tgt_ct', 'tgt_fk'):
                tf.append(tf_dict[f])
            rel.set_init('through_fields', tf)