def _validate(self, value, **kwargs):
        """Run the generic field checks, then the field specific ``validate``.

        The value is first checked against ``self.choices`` (when set), then
        against the user supplied ``self.validation`` callable (when set) and
        finally handed to ``self.validate(value, **kwargs)``.

        :param value: the value to check
        :param kwargs: forwarded unchanged to ``self.validate``
        :raises ValueError: if ``self.validation`` is set but is not callable
        """
        Document = _import_class('Document')
        EmbeddedDocument = _import_class('EmbeddedDocument')

        # check choices
        if self.choices:
            # Document instances are compared to the choices by their class,
            # every other value is compared directly.
            is_cls = isinstance(value, (Document, EmbeddedDocument))
            value_to_check = value.__class__ if is_cls else value
            err_msg = 'an instance' if is_cls else 'one'
            if isinstance(self.choices[0], (list, tuple)):
                # Choices given as pairs: only the first item of each pair
                # is an acceptable value.
                option_keys = [k for k, v in self.choices]
                if value_to_check not in option_keys:
                    msg = ('Value must be %s of %s' %
                           (err_msg, str(option_keys)))
                    self.error(msg)
            elif value_to_check not in self.choices:
                msg = ('Value must be %s of %s' % (err_msg, str(self.choices)))
                self.error(msg)

        # check validation argument
        if self.validation is not None:
            # ``collections.Callable`` no longer exists on Python 3.10+; the
            # ``callable`` builtin performs the same check on every version.
            if not callable(self.validation):
                raise ValueError('validation argument for "%s" must be a '
                                 'callable.' % self.name)
            if not self.validation(value):
                self.error('Value does not match custom validation method')

        self.validate(value, **kwargs)
    def __setattr__(self, name, value):
        """Assign ``value`` to the attribute ``name``.

        On an unlocked dynamic document, unknown public names are registered
        as dynamic fields and recorded as changed. Changing a shard key while
        ``_created`` is false raises ``OperationError``. Assigning the id
        field of an initialised document whose ``_created`` flag is set
        clears that flag.
        """
        is_private = name.startswith('_')

        # Handle dynamic data only if an initialised dynamic document
        if self._dynamic and not self._dynamic_lock:
            is_unknown = not hasattr(self, name)
            if is_unknown and not is_private:
                # First assignment of this name: create a field for it
                dynamic_field_cls = _import_class("DynamicField")
                dynamic_field = dynamic_field_cls(db_field=name)
                dynamic_field.name = name
                self._dynamic_fields[name] = dynamic_field

            if not is_private:
                value = self.__expand_dynamic_values(name, value)

            # Handle marking data as changed
            if name in self._dynamic_fields:
                self._data[name] = value
                if hasattr(self, '_changed_fields'):
                    self._mark_as_changed(name)

        if self._is_document and not self._created:
            shard_key = self._meta.get('shard_key', tuple())
            if name in shard_key and self._data.get(name) != value:
                error_cls = _import_class('OperationError')
                raise error_cls(
                    "Shard Keys are immutable. Tried to update %s" % name)

        # Check if the user has created a new instance of a class
        id_assigned = (self._is_document and self._initialised
                       and self._created
                       and name == self._meta['id_field'])
        if id_assigned:
            super(BaseDocument, self).__setattr__('_created', False)

        super(BaseDocument, self).__setattr__(name, value)
    def _get_changed_fields(self, key='', inspected=None):
        """Returns a list of all fields that have explicitly been changed.

        The result starts with this object's own ``_changed_fields`` and is
        extended with dotted paths for changes found in embedded documents
        and in list / dict values that expose ``_get_changed_fields``.

        :param key: accepted for the recursive call signature; it is
            reassigned for every field and its incoming value is not used
        :param inspected: set of ids already visited, shared between the
            recursive calls so that no object is walked twice
        :return: list of changed field paths
        """
        EmbeddedDocument = _import_class("EmbeddedDocument")
        DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
        _changed_fields = list(getattr(self, '_changed_fields', []))

        # Only create a new set when none was supplied: ``inspected or set()``
        # would drop an empty set handed down by the caller and the nested
        # calls would stop sharing their bookkeeping.
        if inspected is None:
            inspected = set()
        if hasattr(self, 'id'):
            if self.id in inspected:
                return _changed_fields
            inspected.add(self.id)

        field_list = self._fields.copy()
        if self._dynamic:
            field_list.update(self._dynamic_fields)

        for field_name in field_list:

            db_field_name = self._db_field_map.get(field_name, field_name)
            key = '%s.' % db_field_name
            field = self._data.get(field_name, None)
            if hasattr(field, 'id'):
                if field.id in inspected:
                    continue
                inspected.add(field.id)

            # A field already reported as changed as a whole is not
            # inspected any deeper.
            if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument))
                    and db_field_name not in _changed_fields):
                # Find all embedded fields that have been changed
                changed = field._get_changed_fields(key, inspected)
                _changed_fields += ["%s%s" % (key, k) for k in changed if k]
            elif (isinstance(field, (list, tuple, dict))
                  and db_field_name not in _changed_fields):
                # Loop list / dict fields as they contain documents
                # Determine the iterator to use
                if not hasattr(field, 'items'):
                    iterator = enumerate(field)
                else:
                    iterator = iter(field.items())
                for index, value in iterator:
                    if not hasattr(value, '_get_changed_fields'):
                        continue
                    list_key = "%s%s." % (key, index)
                    changed = value._get_changed_fields(list_key, inspected)
                    _changed_fields += [
                        "%s%s" % (list_key, k) for k in changed if k
                    ]
        return _changed_fields
    def validate(self, clean=True):
        """Check every field value and the presence of required fields.

        :param clean: when true, ``self.clean()`` is run first and the flag
            is forwarded to embedded document fields
        :raises ValidationError: carrying an ``errors`` mapping when at least
            one problem was found
        """
        errors = {}
        if clean:
            try:
                self.clean()
            except ValidationError as error:
                errors[NON_FIELD_ERRORS] = error

        # Pair every declared (and dynamic) field with its current value
        pairs = [(fld, self._data.get(fld_name))
                 for fld_name, fld in self._fields.items()]
        if self._dynamic:
            pairs.extend((fld, self._data.get(fld_name))
                         for fld_name, fld in self._dynamic_fields.items())

        embedded_field_types = (
            _import_class("EmbeddedDocumentField"),
            _import_class("GenericEmbeddedDocumentField"),
        )

        for fld, current in pairs:
            if current is None:
                # Missing value: only an error for required fields that are
                # not flagged ``_auto_gen``
                if fld.required and not getattr(fld, '_auto_gen', False):
                    errors[fld.name] = ValidationError('Field is required',
                                                       field_name=fld.name)
                continue

            try:
                if isinstance(fld, embedded_field_types):
                    fld._validate(current, clean=clean)
                else:
                    fld._validate(current)
            except ValidationError as error:
                errors[fld.name] = error.errors or error
            except (ValueError, AttributeError, AssertionError) as error:
                errors[fld.name] = error

        if not errors:
            return

        if hasattr(self, 'pk'):
            pk = self.pk
        elif self._instance:
            pk = self._instance.pk
        else:
            pk = "None"
        message = "ValidationError (%s:%s) " % (self._class_name, pk)
        raise ValidationError(message, errors=errors)
    def __init__(self, cls):
        """ Construct the no_dereference context manager.

        Stores ``cls`` and the names of its reference, generic reference and
        complex fields in ``self.deref_fields``.

        :param cls: the class to turn dereferencing off on
        """
        self.cls = cls

        dereferenceable = tuple(
            _import_class(type_name)
            for type_name in ('ReferenceField', 'GenericReferenceField',
                              'ComplexBaseField'))

        deref_fields = []
        for field_name, field in self.cls._fields.items():
            if isinstance(field, dereferenceable):
                deref_fields.append(field_name)
        self.deref_fields = deref_fields
    def __getitem__(self, *args, **kwargs):
        """Return the stored item; an embedded document that has no
        ``_instance`` yet receives this container's ``_instance``.
        """
        item = super(BaseList, self).__getitem__(*args, **kwargs)

        embedded_cls = _import_class('EmbeddedDocument')
        if not isinstance(item, embedded_cls):
            return item
        if item._instance is None:
            item._instance = self._instance
        return item
    def _geo_indices(cls, inspected=None, parent_field=None):
        """Collect the geo index specifications declared on this class.

        Fields exposing ``document_type`` are followed recursively; every
        other geo field with a truthy ``_geo_index`` contributes one entry.

        :param inspected: list of classes already visited; a class found in
            it is not walked again
        :param parent_field: dotted db path of the field through which this
            class was reached, or ``None`` for the top-level call
        :return: list of ``{'fields': [(dotted_db_path, index_type)]}`` dicts
        """
        inspected = inspected or []
        geo_indices = []
        inspected.append(cls)

        geo_field_type_names = [
            "EmbeddedDocumentField", "GeoPointField", "PointField",
            "LineStringField", "PolygonField"
        ]
        geo_field_types = tuple(
            _import_class(type_name) for type_name in geo_field_type_names)

        for field in cls._fields.values():
            if not isinstance(field, geo_field_types):
                continue

            # Path of this field relative to the outermost document
            field_name = field.db_field
            if parent_field:
                field_name = "%s.%s" % (parent_field, field_name)

            if hasattr(field, 'document_type'):
                field_cls = field.document_type
                if field_cls in inspected:
                    continue
                if hasattr(field_cls, '_geo_indices'):
                    # Hand down the complete path: passing only this field's
                    # own db name would drop the outer prefix for geo fields
                    # nested two or more levels deep.
                    geo_indices += field_cls._geo_indices(
                        inspected, parent_field=field_name)
            elif field._geo_index:
                geo_indices.append(
                    {'fields': [(field_name, field._geo_index)]})
        return geo_indices
    def _lookup_field(cls, parts):
        """Lookup a field based on its attribute and return a list containing
        the field's parents and the field.

        :param parts: a single field name, or a list / tuple of names that
            form a path
        :return: list with one entry per part: the resolved field object, or
            the part itself (a string) for numeric indexes and for members
            of a ``ComplexBaseField`` that have no field of their own
        :raises LookUpError: if a part cannot be resolved, if the path would
            traverse a reference field, or if the path starts with a numeric
            index
        """
        if not isinstance(parts, (list, tuple)):
            parts = [parts]
        fields = []
        field = None

        for field_name in parts:
            # Handle ListField indexing:
            if field_name.isdigit():
                if field is None:
                    # An index needs a preceding field; without this guard
                    # the lookup below fails with an unrelated
                    # AttributeError on ``None``.
                    raise LookUpError('Cannot resolve field "%s"' % field_name)
                # The value is not used afterwards; the attribute access
                # still fails with AttributeError when the preceding field
                # has no inner ``field``.
                new_field = field.field
                fields.append(field_name)
                continue

            if field is None:
                # Look up first field from the document
                if field_name == 'pk':
                    # Deal with "primary key" alias
                    field_name = cls._meta['id_field']
                if field_name in cls._fields:
                    field = cls._fields[field_name]
                elif cls._dynamic:
                    DynamicField = _import_class('DynamicField')
                    field = DynamicField(db_field=field_name)
                else:
                    raise LookUpError('Cannot resolve field "%s"' % field_name)
            else:
                ReferenceField = _import_class('ReferenceField')
                GenericReferenceField = _import_class('GenericReferenceField')
                if isinstance(field, (ReferenceField, GenericReferenceField)):
                    raise LookUpError('Cannot perform join in mongoDB: %s' %
                                      '__'.join(parts))
                if hasattr(getattr(field, 'field', None), 'lookup_member'):
                    # Container field: resolve the member on its inner field
                    new_field = field.field.lookup_member(field_name)
                else:
                    # Look up subfield on the previous field
                    new_field = field.lookup_member(field_name)
                if not new_field and isinstance(field, ComplexBaseField):
                    # Unresolved member of a complex field: keep the raw name
                    fields.append(field_name)
                    continue
                elif not new_field:
                    raise LookUpError('Cannot resolve field "%s"' % field_name)
                field = new_field  # update field to the new field type
            fields.append(field)
        return fields
    def __get__(self, instance, owner):
        """Descriptor to automatically dereference references.

        Returns the descriptor itself when accessed on the class. On an
        instance, the stored value is optionally dereferenced, wrapped in
        ``BaseList`` / ``BaseDict`` when it is a plain list, tuple or dict,
        written back to ``instance._data`` and returned.

        :param instance: the object the attribute is read from, or ``None``
            for class level access
        :param owner: the owning class, forwarded to the parent ``__get__``
        """
        if instance is None:
            # Document class being used rather than a document object
            return self

        ReferenceField = _import_class('ReferenceField')
        GenericReferenceField = _import_class('GenericReferenceField')
        # Computed from the ``_auto_dereference`` value this descriptor held
        # *before* it is refreshed from the instance on the next statement.
        dereference = (self._auto_dereference
                       and (self.field is None or isinstance(
                           self.field,
                           (GenericReferenceField, ReferenceField))))

        self._auto_dereference = instance._fields[self.name]._auto_dereference
        # NOTE(review): ``self.__dereference`` is name-mangled private state,
        # presumably a cached dereferencer - confirm against the class body.
        if not self.__dereference and instance._initialised and dereference:
            instance._data[self.name] = self._dereference(instance._data.get(
                self.name),
                                                          max_depth=1,
                                                          instance=instance,
                                                          name=self.name)

        value = super(ComplexBaseField, self).__get__(instance, owner)

        # Convert lists / values so we can watch for any changes on them
        if (isinstance(value, (list, tuple))
                and not isinstance(value, BaseList)):
            value = BaseList(value, instance, self.name)
            instance._data[self.name] = value
        elif isinstance(value, dict) and not isinstance(value, BaseDict):
            value = BaseDict(value, instance, self.name)
            instance._data[self.name] = value

        # Dereference wrapped containers once; ``_dereferenced`` marks the
        # result so later reads skip this step.
        if (self._auto_dereference and instance._initialised
                and isinstance(value, (BaseList, BaseDict))
                and not value._dereferenced):
            value = self._dereference(value,
                                      max_depth=1,
                                      instance=instance,
                                      name=self.name)
            value._dereferenced = True
            instance._data[self.name] = value

        return value
 def _clear_changed_fields(self):
     """Reset ``_changed_fields`` on this object and, recursively, on the
     embedded documents it holds directly or inside complex fields.
     """
     self._changed_fields = []
     embedded_field_cls = _import_class("EmbeddedDocumentField")
     for attr_name, fld in self._fields.items():
         holds_many = (isinstance(fld, ComplexBaseField)
                       and isinstance(fld.field, embedded_field_cls))
         if not holds_many and not isinstance(fld, embedded_field_cls):
             continue

         current = getattr(self, attr_name, None)
         if not current:
             continue

         if not holds_many:
             current._clear_changed_fields()
             continue

         # Containers are walked by key / position and read through
         # subscription, exactly one lookup per element.
         if isinstance(current, dict):
             lookups = current
         else:
             lookups = range(len(current))
         for lookup in lookups:
             current[lookup]._clear_changed_fields()
Exemple #11
0
    def to_python(self, value):
        """Convert a MongoDB-compatible type to a Python type.

        Strings and non-iterable values are returned unchanged, objects with
        a ``to_python`` method convert themselves, and the members of
        sequences and mappings are converted one by one.
        """
        document_cls = _import_class('Document')

        if isinstance(value, str):
            return value

        if hasattr(value, 'to_python'):
            return value.to_python()

        # Sequences are handled as index -> item mappings and turned back
        # into a list at the end
        was_sequence = not hasattr(value, 'items')
        if was_sequence:
            try:
                value = dict(enumerate(value))
            except TypeError:  # Not iterable return the value
                return value

        if self.field:
            converted = {key: self.field.to_python(item)
                         for key, item in value.items()}
        else:
            converted = {}
            for key, item in value.items():
                if isinstance(item, document_cls):
                    # We need the id from the saved object to create the DBRef
                    if item.pk is None:
                        self.error('You can only reference documents once they'
                                   ' have been saved to the database')
                    converted[key] = DBRef(item._get_collection_name(),
                                           item.pk)
                elif hasattr(item, 'to_python'):
                    converted[key] = item.to_python()
                else:
                    converted[key] = self.to_python(item)

        if not was_sequence:
            return converted

        # Convert back to a list
        ordered = sorted(converted.items(), key=operator.itemgetter(0))
        return [item for _, item in ordered]
Exemple #12
0
    def __get__(self, instance, owner):
        """Descriptor for retrieving a value from a field in a document. Do
        any necessary conversion between Python and MongoDB types.

        :param instance: the document the value is read from, or ``None``
            for class level access
        :param owner: the owning class (unused)
        :return: the descriptor itself for class level access, otherwise the
            stored value or, when that is ``None``, the field default
        """
        if instance is None:
            # Document class being used rather than a document object
            return self
        # Get value from document instance if available, if not use default
        value = instance._data.get(self.name)

        if value is None:
            value = self.default
            # Allow callable default values. ``collections.Callable`` no
            # longer exists on Python 3.10+, so use the ``callable`` builtin.
            if callable(value):
                value = value()

        EmbeddedDocument = _import_class('EmbeddedDocument')
        if isinstance(value, EmbeddedDocument) and value._instance is None:
            # Link the embedded document to its owner through a weak proxy
            value._instance = weakref.proxy(instance)
        return value
 def _import_classes(cls):
     """Return the ``(Document, EmbeddedDocument, DictField)`` classes."""
     wanted = ('Document', 'EmbeddedDocument', 'DictField')
     return tuple(_import_class(class_name) for class_name in wanted)
Exemple #14
0
def query(_doc_cls=None, _field_operation=False, **query):
    """Transform a query from Django-style format to Mongo format.

    Every keyword is split on ``__`` into field parts, an optional trailing
    operator and an optional ``not`` modifier. Numeric parts are kept as
    positional indexes and re-inserted into the final dotted key.

    :param _doc_cls: optional document class used to translate field names
        to db field names and to prepare the query values
    :param _field_operation: not used by this function
    :param query: the Django-style lookups; the value stored under the
        special key ``__raw__`` is merged into the result unchanged
    :return: the query as a dict in Mongo format
    :raises InvalidQueryError: if the field lookup on ``_doc_cls`` fails
    :raises NotImplementedError: for a custom operator other than ``match``
    """
    mongo_query = {}
    merge_query = defaultdict(list)
    for key, value in sorted(query.items()):
        if key == "__raw__":
            mongo_query.update(value)
            continue

        parts = key.split('__')
        # Remember positional indexes so they can be put back afterwards
        indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
        parts = [part for part in parts if not part.isdigit()]
        # Check for an operator and transform to mongo-style if there is
        op = None
        if parts[-1] in MATCH_OPERATORS:
            op = parts.pop()

        negate = False
        if parts[-1] == 'not':
            parts.pop()
            negate = True

        if _doc_cls:
            # Switch field names to proper names [set in Field(name='foo')]
            try:
                fields = _doc_cls._lookup_field(parts)
            except Exception as e:
                raise InvalidQueryError(e)
            parts = []

            cleaned_fields = []
            for field in fields:
                if isinstance(field, str):
                    # Unresolved part: used as is, no field object to keep
                    parts.append(field)
                else:
                    parts.append(field.db_field)
                    cleaned_fields.append(field)

            # Convert value to proper value
            field = cleaned_fields[-1]

            singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
            singular_ops += STRING_OPERATORS
            if op in singular_ops:
                if isinstance(field, str):
                    if (op in STRING_OPERATORS and isinstance(value, str)):
                        StringField = _import_class('StringField')
                        value = StringField.prepare_query_value(op, value)
                    else:
                        value = field
                else:
                    value = field.prepare_query_value(op, value)
            elif op in ('in', 'nin', 'all',
                        'near') and not isinstance(value, dict):
                # 'in', 'nin' and 'all' require a list of values
                value = [field.prepare_query_value(op, v) for v in value]

        # if op and op not in COMPARISON_OPERATORS:
        if op:
            if op in GEO_OPERATORS:
                value = _geo_operator(field, op, value)
            elif op in CUSTOM_OPERATORS:
                if op == 'match':
                    value = {"$elemMatch": value}
                else:
                    # The exception used to be built but never raised, so an
                    # unsupported operator went through unnoticed.
                    raise NotImplementedError("Custom method '%s' has not "
                                              "been implemented" % op)
            elif op not in STRING_OPERATORS:
                value = {'$' + op: value}

        if negate:
            value = {'$not': value}

        for i, part in indices:
            parts.insert(i, part)
        key = '.'.join(parts)
        if op is None or key not in mongo_query:
            mongo_query[key] = value
        elif isinstance(mongo_query[key], dict):
            mongo_query[key].update(value)
            # $maxDistance needs to come last - convert to SON
            if '$maxDistance' in mongo_query[key]:
                value_dict = mongo_query[key]
                value_son = SON()
                for k, v in value_dict.items():
                    if k == '$maxDistance':
                        continue
                    value_son[k] = v
                value_son['$maxDistance'] = value_dict['$maxDistance']
                mongo_query[key] = value_son
        else:
            # Store for manually merging later
            merge_query[key].append(value)

    # The queryset has been filter in such a way we must manually merge
    for k, v in list(merge_query.items()):
        merge_query[k].append(mongo_query[k])
        del mongo_query[k]
        if isinstance(v, list):
            value = [{k: val} for val in v]
            if '$and' in mongo_query:
                # ``$and`` holds a flat list of clauses: extend it rather
                # than appending the new list as a single nested element.
                mongo_query['$and'].extend(value)
            else:
                mongo_query['$and'] = value

    return mongo_query
    def __init__(self, *args, **values):
        """
        Initialise a document or embedded document

        Positional arguments are matched, in order, to the names in
        ``self._fields_ordered`` and merged into ``values``. The ``pre_init``
        signal is sent before any data is set and ``post_init`` once the
        instance is flagged as initialised.

        :param __auto_convert: when true (the default), values of known
            fields of a non-dynamic document are passed through the field's
            ``to_python`` before being set (``FileField`` values excepted)
        :param values: A dictionary of values for the document
        :raises TypeError: if a field receives both a positional and a
            keyword value
        """
        if args:
            # Combine positional arguments with named arguments.
            # We only want named arguments.
            field = iter(self._fields_ordered)
            for value in args:
                name = next(field)
                if name in values:
                    raise TypeError("Multiple values for keyword argument '" +
                                    name + "'")
                values[name] = value
        # Removed from ``values`` before the signal handlers see them
        __auto_convert = values.pop("__auto_convert", True)
        signals.pre_init.send(self.__class__, document=self, values=values)

        self._data = {}

        # Assign default values to instance
        for key, field in self._fields.items():
            # Fields that were passed in are set further below
            if self._db_field_map.get(key, key) in values:
                continue
            value = getattr(self, key, None)
            setattr(self, key, value)

        # Set passed values after initialisation
        if self._dynamic:
            self._dynamic_fields = {}
            # Unknown keys are collected here and only set at the end, after
            # ``_dynamic_lock`` has been released
            dynamic_data = {}
            for key, value in values.items():
                if key in self._fields or key == '_id':
                    setattr(self, key, value)
                elif self._dynamic:
                    dynamic_data[key] = value
        else:
            FileField = _import_class('FileField')
            for key, value in values.items():
                if key == '__auto_convert':
                    continue
                # Translate db field names back to attribute names
                key = self._reverse_db_field_map.get(key, key)
                if key in self._fields or key in ('id', 'pk', '_cls'):
                    if __auto_convert and value is not None:
                        field = self._fields.get(key)
                        if field and not isinstance(field, FileField):
                            value = field.to_python(value)
                    setattr(self, key, value)
                else:
                    # Unknown keys are stored as raw data without a field
                    self._data[key] = value

        # Set any get_fieldname_display methods
        self.__set_field_display()

        if self._dynamic:
            self._dynamic_lock = False
            for key, value in dynamic_data.items():
                setattr(self, key, value)

        # Flag initialised
        self._initialised = True
        signals.post_init.send(self.__class__, document=self)
Exemple #16
0
    def to_mongo(self, value):
        """Convert a Python type to a MongoDB-compatible type.

        Strings and non-iterable values are returned unchanged, objects with
        a ``to_mongo`` method convert themselves, and the members of
        sequences and mappings are converted one by one.
        """
        document_cls = _import_class("Document")
        embedded_cls = _import_class("EmbeddedDocument")
        generic_ref_cls = _import_class("GenericReferenceField")

        if isinstance(value, str):
            return value

        if hasattr(value, 'to_mongo'):
            if isinstance(value, document_cls):
                return generic_ref_cls().to_mongo(value)
            encoded = value.to_mongo()
            # If we its a document thats not inherited add _cls
            if isinstance(value, embedded_cls):
                encoded['_cls'] = value.__class__.__name__
            return encoded

        # Sequences are handled as index -> item mappings and turned back
        # into a list at the end
        was_sequence = not hasattr(value, 'items')
        if was_sequence:
            try:
                value = dict(enumerate(value))
            except TypeError:  # Not iterable return the value
                return value

        if self.field:
            converted = {key: self.field.to_mongo(item)
                         for key, item in value.items()}
        else:
            converted = {}
            for key, item in value.items():
                if isinstance(item, document_cls):
                    # We need the id from the saved object to create the DBRef
                    if item.pk is None:
                        self.error('You can only reference documents once they'
                                   ' have been saved to the database')

                    # If its a document that is not inheritable it won't have
                    # any _cls data so make it a generic reference allows
                    # us to dereference
                    meta = getattr(item, '_meta', {})
                    inheritable = meta.get('allow_inheritance',
                                           ALLOW_INHERITANCE) is True
                    if inheritable or self.field:
                        converted[key] = DBRef(item._get_collection_name(),
                                               item.pk)
                    else:
                        converted[key] = generic_ref_cls().to_mongo(item)
                elif hasattr(item, 'to_mongo'):
                    encoded = item.to_mongo()
                    # If we its a document thats not inherited add _cls
                    if isinstance(item, (document_cls, embedded_cls)):
                        encoded['_cls'] = item.__class__.__name__
                    converted[key] = encoded
                else:
                    converted[key] = self.to_mongo(item)

        if not was_sequence:
            return converted

        # Convert back to a list
        ordered = sorted(converted.items(), key=operator.itemgetter(0))
        return [item for _, item in ordered]
Exemple #17
0
 def _dereference(self):
     """Return the ``DeReference`` instance, creating and caching it on
     first use.
     """
     cached = self.__dereference
     if cached:
         return cached
     dereference_cls = _import_class("DeReference")
     self.__dereference = dereference_cls()  # Cached
     return self.__dereference