def test_extend(self):
    """Node.extend() attaches a method to Node until the attribute is deleted."""
    @Node.extend()
    def add(self, lhs, rhs):
        return lhs + rhs

    node = Node()
    total = node.add(4, 2)
    self.assertEqual(total, 6)

    # Once the attribute is removed from the class, instances lose it as well.
    delattr(Node, 'add')
    with self.assertRaises(AttributeError):
        node.add(2, 4)
def test_clone(self):
    """Node.extend(clone=True) applies the extension to a copy, not the receiver."""
    @Node.extend(clone=True)
    def hack(self, alias):
        self._negated = True
        self._alias = alias

    original = Node()
    copied = original.hack('magic!')
    # The receiver keeps its defaults; only the returned node carries the changes.
    self.assertFalse(original._negated)
    self.assertEqual(original._alias, None)
    self.assertTrue(copied._negated)
    self.assertEqual(copied._alias, 'magic!')

    class TestModel(Model):
        data = CharField()

    # The extension is also reachable through a model field.
    hacked = TestModel.data.hack('nugget')
    self.assertFalse(TestModel.data._negated)
    self.assertEqual(TestModel.data._alias, None)
    self.assertTrue(hacked._negated)
    self.assertEqual(hacked._alias, 'nugget')

    # After removal the method is gone from fields too.
    delattr(Node, 'hack')
    with self.assertRaises(AttributeError):
        TestModel.data.hack()
def filter_query(self, query, *args, **kwargs):
    """Return a clone of ``query`` filtered by ``args`` and ``kwargs``.

    Positional ``args`` are cloned and AND-ed together; ``kwargs`` are wrapped
    in a single ``DQ``. Every ``DQ`` in the combined tree is expanded with
    ``self.convert_dict_to_node``. The joins that expansion reports are applied
    with ``query.ensure_join``, and the joined models are added to the
    selection.

    :param query: query to filter; it is cloned and the clone is modified.
    :param args: expression nodes to AND into the filter.
    :param kwargs: keyword lookups passed to ``DQ``.
    :returns: the cloned query with joins, extra selections and the WHERE
        clause applied.
    """
    # normalize args and kwargs into a new expression
    # NOTE(review): args are cloned, presumably so the in-place rewrite below
    # does not touch the caller's nodes -- confirm how deep clone() copies.
    dq_node = Node()
    if args:
        dq_node &= reduce(operator.and_, [a.clone() for a in args])
    if kwargs:
        dq_node &= DQ(**kwargs)

    # dq_node should now be an Expression, lhs = Node(), rhs = ...
    # Breadth-first walk over the expression tree: each DQ child is replaced
    # in place (setattr on its parent) by the AND of the expressions
    # convert_dict_to_node produced for it. Joins are collected along the way.
    q = deque([dq_node])
    dq_joins = list()
    while q:
        curr = q.popleft()
        if not isinstance(curr, Expression):
            continue
        for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)):
            if isinstance(piece, DQ):
                query_part, joins = self.convert_dict_to_node(piece.query)
                dq_joins.extend(joins)
                expression = reduce(operator.and_, query_part)
                # Apply values from the DQ object.
                expression._negated = piece._negated
                expression._alias = piece._alias
                setattr(curr, side, expression)
            else:
                q.append(piece)

    # Drop the placeholder Node() on the left of the outermost AND.
    # NOTE(review): with neither args nor kwargs, dq_node is still a bare
    # Node() here and `.rhs` presumably raises AttributeError. With BOTH args
    # and kwargs, the outermost expression is ((Node() & args) & DQ), so
    # `.rhs` keeps only the DQ part and the args may be dropped -- confirm
    # against Node.__and__.
    dq_node = dq_node.rhs

    selected = list()
    query = query.clone()
    # Each join entry is unpacked as a (field, related model) pair.
    # remove_dupes presumably de-duplicates and returns a list (its result
    # below is used with .insert) -- verify.
    for field, rm in self.remove_dupes(dq_joins):
        selected.append(rm)
        # NOTE(review): if `field` is not a ForeignKeyField, `lm` and `on` are
        # unbound on the first iteration, or stale from the previous one.
        if isinstance(field, ForeignKeyField):
            lm = field.model
            on = field
        # NOTE(review): `on` was just set to a ForeignKeyField, so this check
        # looks unreachable. `rm == <rm's own primary key>` is also an odd
        # join condition. Perhaps `isinstance(rm, ModelAlias)` and
        # `field == getattr(rm, pk)` were intended -- confirm.
        if isinstance(on, ModelAlias):
            on = (rm == getattr(rm, rm._meta.primary_key.name))
        query = query.ensure_join(lm, rm, on)

    # Add the joined models to the selection: append them to an explicit
    # selection, or otherwise select the base model followed by them.
    selected = self.remove_dupes(selected)
    if query._explicit_selection:
        query._select += query._model_shorthand(selected)
    else:
        selected.insert(0, query.model)
        query = query.select(*selected)
    return query.where(dq_node)
def filter(query, filters, alias_map=None):
    """Apply keyword lookup ``filters`` to ``query``, adding the joins they need.

    This is a modified copy of peewee's ``Query.filter``. Together with the
    modified ``convert_dict_to_node`` and ``ensure_join`` helpers, it supports
    ``FieldProxy`` objects generated from aliases, so that no unnecessary joins
    are added (see https://github.com/coleifer/peewee/issues/1338).

    :param query: the query to filter. When ``filters`` is empty it is
        returned unchanged (the same object); otherwise it is cloned before
        joins and the WHERE clause are applied.
    :param filters: mapping of lookup keys to values, wrapped in a single
        ``DQ``.
    :param alias_map: optional mapping forwarded to ``convert_dict_to_node``.
        Defaults to a fresh empty dict on every call.
    :returns: the filtered query.
    """
    if not filters:
        return query
    if alias_map is None:
        # Fresh dict per call: a shared ``{}`` default would leak any mutation
        # made downstream into every later call.
        alias_map = {}

    # normalize filters into a new expression: lhs = Node(), rhs = DQ(...)
    dq_node = Node() & DQ(**filters)

    # Breadth-first walk replacing each DQ (in place, via setattr on its
    # parent) with the AND of the expressions it expands to, while collecting
    # the joins those expressions require.
    q = deque([dq_node])
    # Keep joins de-duplicated but in first-seen order. A set iterates in hash
    # order, so joins that depend on each other (presumably reported in
    # lookup-path order for multi-hop lookups) could be ensured out of order.
    dq_joins = []
    seen_joins = set()
    while q:
        curr = q.popleft()
        if not isinstance(curr, Expression):
            continue
        for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)):
            if isinstance(piece, DQ):
                expressions, joins = convert_dict_to_node(
                    query, piece.query, alias_map)
                for join in joins:
                    if join not in seen_joins:
                        seen_joins.add(join)
                        dq_joins.append(join)
                expression = reduce(operator.and_, expressions)
                # Apply values from the DQ object.
                # NOTE(review): the sibling filter_query reads
                # `piece._alias`; here `piece.alias` is read -- confirm which
                # attribute is intended.
                expression._negated = piece._negated
                expression._alias = piece.alias
                setattr(curr, side, expression)
            else:
                q.append(piece)

    # Drop the placeholder Node() on the left.
    dq_node = dq_node.rhs

    new_query = query.clone()
    for field in dq_joins:
        if isinstance(field, (ForeignKeyField, FieldAlias)):
            if hasattr(field, 'source'):
                # Presumably a FieldAlias: join from the model it is bound to.
                lm, rm = field.source, field.rel_model
            else:
                lm, rm = field.model, field.rel_model
            field_obj = field
        elif isinstance(field, BackrefAccessor):
            lm, rm = field.field.rel_model, field.rel_model
            field_obj = field.field
        # NOTE(review): any other join type leaves lm/rm/field_obj unbound on
        # the first iteration (or stale from the previous one).
        new_query = ensure_join(new_query, lm, rm, field_obj)
    return new_query.where(dq_node)