class Updator(object):
    """Parse an update document once and apply it to documents.

    ``spec`` (e.g. ``{"$set": {"a": 1}}``) is parsed at construction time
    into one callable per field path.  Calling the instance with a
    fieldwalker applies every parsed operation to it.
    """

    def __init__(self, spec, array_filters=None):
        """
        Args:
            spec: Update document whose top-level keys are ``$`` operators.
            array_filters: Optional list of filter documents used by the
                ``$[<identifier>]`` positional operator.

        Raises:
            ValueError: If ``spec`` is empty or its first key is not a
                ``$`` operator.
            WriteError: If the update document or array filters are invalid.
        """
        self.update_ops = {
            # field update ops
            "$inc": parse_inc,
            "$min": parse_min,
            "$max": parse_max,
            "$mul": parse_mul,
            "$rename": parse_rename,
            "$set": parse_set,
            "$setOnInsert": self.parse_set_on_insert,
            "$unset": parse_unset,
            "$currentDate": parse_currentDate,
            # array update ops
            # $ implemented in FieldWalker
            # $[] implemented in FieldWalker
            # $[<identifier>] implemented in FieldWalker
            "$addToSet": parse_add_to_set,
            "$pop": parse_pop,
            "$pull": parse_pull,
            "$push": parse_push,
            "$pullAll": parse_pull_all,
            # No top-level parser for these (None); `parser` rejects them
            # as unknown modifiers when they appear at the top level.
            "$each": None,
            "$position": None,
            "$slice": None,
            "$sort": None,
        }
        # Raw inputs, kept only for __repr__.
        self._spec = spec
        self._array_filters_spec = array_filters

        self.fields_to_update = []
        self.array_filters = self.array_filter_parser(array_filters or [])
        # `parser` returns {field_path: operation}; sort by field path so
        # operations are applied in a deterministic order.
        self.operations = SON(sorted(self.parser(spec).items()))
        self.__insert = None
        self.__fieldwalker = None

    def __repr__(self):
        return "{}(spec={!r}, array_filters={!r})".format(
            type(self).__name__, self._spec, self._array_filters_spec)

    def __call__(self, fieldwalker, do_insert=False):
        """Apply the parsed update to ``fieldwalker``.

        Args:
            fieldwalker: Object the operations act on; it is used as a
                context manager and finalized with ``commit()``.
            do_insert: Whether this update is part of an insert (upsert).
                It is read by ``$setOnInsert`` operations.

        Returns:
            The result of ``fieldwalker.commit()``: a bool indicating
            whether the document was changed.
        """
        self.__fieldwalker = fieldwalker
        self.__insert = do_insert

        with fieldwalker:
            for operator in self.operations.values():
                operator(fieldwalker)

            return fieldwalker.commit()

    @property
    def fieldwalker(self):
        """The fieldwalker passed to the most recent call (None before)."""
        return self.__fieldwalker

    def array_filter_parser(self, array_filters):
        """Build a ``{identifier: QueryFilter}`` map from ``array_filters``.

        Each filter document must reference exactly one top-level
        identifier, e.g. ``{"elem.a": 1, "elem.b": 2}`` for identifier
        ``elem``.  Identifiers must be unique across filters.

        Raises:
            WriteError: (code 9) for an empty filter, a filter naming more
                than one top-level identifier, or an identifier reused
                across filters.
        """
        filters = {}
        for filter_ in array_filters:
            if not filter_:
                # An empty filter has no identifier.  Without this check it
                # would register under "" and be treated as "used" by "$[]".
                msg = ("Cannot use an expression without a top-level field "
                       "name in arrayFilters")
                raise WriteError(msg, code=9)

            top = ""
            conds = {}
            for identifier, cond in filter_.items():
                id_s = identifier.split(".", 1)
                if not top and id_s[0] in filters:
                    msg = ("Found multiple array filters with the same "
                           "top-level field name {}".format(id_s[0]))
                    raise WriteError(msg, code=9)

                if top and id_s[0] != top:
                    msg = ("Error parsing array filter: Expected a single "
                           "top-level field name, found {0!r} and {1!r}"
                           "".format(top, id_s[0]))
                    raise WriteError(msg, code=9)

                top = id_s[0]
                conds.update({identifier: cond})

            filters[top] = QueryFilter(conds)

        return filters

    def parser(self, spec):
        """Parse ``spec`` into ``{field_path: operation_callable}``.

        Raises:
            ValueError: If ``spec`` is empty or its first key is not a
                ``$`` operator.
            WriteError: On unknown modifiers (code 9), non-document operator
                arguments (code 9), updates to ``_id`` (code 66), conflicting
                paths (code 40), or array filters that are never used
                (code 9).
        """
        if not spec:
            raise ValueError("update cannot be empty")
        if not next(iter(spec)).startswith("$"):
            raise ValueError("update only works with $ operators")

        update_stack = {}
        # Identifiers still waiting to be referenced by some "$[<id>]" path.
        idnt_tops = list(self.array_filters)

        for op, cmd_doc in spec.items():
            parse_op = self.update_ops.get(op)
            if parse_op is None:
                # Covers both unknown operators and the modifier-only keys
                # mapped to None, which would otherwise fail with TypeError.
                raise WriteError("Unknown modifier: {}".format(op), code=9)

            if not is_duckument_type(cmd_doc):
                msg = ("Modifiers operate on fields but we found type {0} "
                       "instead. For example: {{$mod: {{<field>: ...}}}} "
                       "not {1}".format(type(cmd_doc).__name__, spec))
                raise WriteError(msg, code=9)

            for field, value in cmd_doc.items():
                if field == "_id":
                    msg = ("Performing an update on the path '_id' would "
                           "modify the immutable field '_id'")
                    raise WriteError(msg, code=66)

                for top in list(idnt_tops):
                    if "$[{}]".format(top) in field:
                        idnt_tops.remove(top)

                update_stack[field] = parse_op(field, value,
                                               self.array_filters)

                self.check_conflict(field)
                if op == "$rename":
                    # The rename target is also written, so it must not
                    # overlap with any other updated path either.
                    self.check_conflict(value)

        if idnt_tops:
            msg = ("The array filter for identifier {0!r} was not "
                   "used in the update {1}".format(idnt_tops[0], spec))
            raise WriteError(msg, code=9)

        return update_stack

    def check_conflict(self, field):
        """Stage ``field``, rejecting it if it overlaps a staged path.

        Two paths conflict when they are equal or one is an ancestor of the
        other.  The comparison is per dot-separated component, so ``"a"``
        conflicts with ``"a.b"`` but not with ``"ab"``.

        Raises:
            WriteError: (code 40) when ``field`` conflicts with a path
                already staged; ``field`` is not recorded in that case.
        """
        parts = field.split(".")
        for staged in self.fields_to_update:
            staged_parts = staged.split(".")
            common = min(len(parts), len(staged_parts))
            if parts[:common] == staged_parts[:common]:
                # staged[:len(field)] is the shorter of the two paths here.
                msg = ("Updating the path {0!r} would create a "
                       "conflict at {1!r}".format(field, staged[:len(field)]))
                raise WriteError(msg, code=40)

        self.fields_to_update.append(field)

    def parse_set_on_insert(self, field, value, array_filters):
        """Return an operation that behaves like ``$set`` only on insert.

        The returned callable checks the ``do_insert`` flag recorded by the
        most recent ``__call__``.  It does nothing when that flag is false.
        """
        def _set_on_insert(fieldwalker):
            if self.__insert:
                parse_set(field, value, array_filters)(fieldwalker)

        return _set_on_insert