Esempio n. 1
0
 def test_add(self):
     """add() returns the added child and grows the node by exactly one."""
     grown = Node(self.node1_children)
     extra = ("c", 3)
     returned = grown.add(extra, Node.default)
     self.assertEqual(returned, extra)
     # One extra child compared with node1, which was built from the same list.
     self.assertEqual(len(self.node1) + 1, len(grown))
     self.assertEqual(str(grown), "(DEFAULT: ('a', 1), ('b', 2), ('c', 3))")
Esempio n. 2
0
 def __init__(self, model):
     # Model this query operates on.
     self.model = model
     # Filter tree; starts empty.
     self.where = Node()
     # Explicit ordering, plus whether the model's Meta.ordering may apply.
     self.order_by = []
     self.default_ordering = True
     # Selected field names (empty set means "no explicit projection").
     self.fields = set()
     self.select_related = True  # Hack for admin-interface: django/contrib/admin/views/main.py line 210
     # Per-instance cache of field name -> column name.
     self._field_names = {}
     self.target_collection_prefix = None
Esempio n. 3
0
 def test_hash(self):
     """Hash depends on children, negation and connector."""
     negated = Node(self.node1_children, negated=True)
     other_connector = Node(self.node1_children, connector='OTHER')
     twin = Node(self.node1_children)
     base = hash(self.node1)
     for different in (self.node2, negated, other_connector):
         self.assertNotEqual(base, hash(different))
     self.assertEqual(base, hash(twin))
     self.assertEqual(hash(self.node2), hash(Node()))
Esempio n. 4
0
 def test_add(self):
     """add() returns the added child and grows the node by exactly one."""
     grown = Node(self.node1_children)
     extra = ('c', 3)
     returned = grown.add(extra, Node.default)
     self.assertEqual(returned, extra)
     # One extra child compared with node1, which was built from the same list.
     self.assertEqual(len(self.node1) + 1, len(grown))
     self.assertEqual(str(grown), "(DEFAULT: ('a', 1), ('b', 2), ('c', 3))")
Esempio n. 5
0
 def test_copy(self):
     """copy.copy() yields an equal node that shares its children list."""
     original = Node([Node(["a", "b"], OR), "c"], AND)
     duplicate = copy.copy(original)
     self.assertEqual(original, duplicate)
     # The children list itself is shared, not copied.
     self.assertIs(original.children, duplicate.children)
     for mine, theirs in zip(original.children, duplicate.children):
         # Nested Node children are the very same objects.
         if isinstance(mine, Node):
             self.assertIs(mine, theirs)
         self.assertEqual(mine, theirs)
Esempio n. 6
0
class NodeTests(unittest.TestCase):
    """Tests for Node's rendering, container protocol, add(), negate() and
    copy semantics."""

    def setUp(self):
        children = [('a', 1), ('b', 2)]
        self.node1_children = children
        self.node1 = Node(children)
        self.node2 = Node()

    def test_str(self):
        cases = (
            (self.node1, "(DEFAULT: ('a', 1), ('b', 2))"),
            (self.node2, "(DEFAULT: )"),
        )
        for node, rendered in cases:
            self.assertEqual(str(node), rendered)

    def test_repr(self):
        cases = (
            (self.node1, "<Node: (DEFAULT: ('a', 1), ('b', 2))>"),
            (self.node2, "<Node: (DEFAULT: )>"),
        )
        for node, rendered in cases:
            self.assertEqual(repr(node), rendered)

    def test_len(self):
        for node, size in ((self.node1, 2), (self.node2, 0)):
            self.assertEqual(len(node), size)

    def test_bool(self):
        # A node with children is truthy; an empty node is falsy.
        self.assertTrue(self.node1)
        self.assertFalse(self.node2)

    def test_contains(self):
        child = ('a', 1)
        self.assertIn(child, self.node1)
        self.assertNotIn(child, self.node2)

    def test_add(self):
        grown = Node(self.node1_children)
        extra = ('c', 3)
        # add() hands back the data it was given.
        returned = grown.add(extra, Node.default)
        self.assertEqual(returned, extra)
        # Exactly one more child than node1, which shares the initial list.
        self.assertEqual(len(self.node1) + 1, len(grown))
        self.assertEqual(str(grown), "(DEFAULT: ('a', 1), ('b', 2), ('c', 3))")

    def test_negate(self):
        node = self.node1
        # Nodes start out non-negated; each negate() call flips the flag.
        self.assertFalse(node.negated)
        node.negate()
        self.assertTrue(node.negated)
        node.negate()
        self.assertFalse(node.negated)

    def test_deepcopy(self):
        shallow = copy.copy(self.node1)
        deep = copy.deepcopy(self.node1)
        # A shallow copy shares the children list; a deep copy owns a new one.
        self.assertIs(self.node1.children, shallow.children)
        self.assertIsNot(self.node1.children, deep.children)
Esempio n. 7
0
 def clone(self):
     """Return an independent copy of this query.

     The filter tree, ordering and field selection are deep-copied, so
     changing the clone never affects the original. The ordering/select
     flags are carried over as well. Otherwise a query whose default
     ordering was disabled (``clear_ordering(force_empty=True)`` or
     ``add_ordering`` with an empty ordering) would silently regain the
     model's default ordering on every clone.
     """
     obj = self.__class__(self.model)
     obj.where = deepcopy(self.where)
     obj.order_by = deepcopy(self.order_by)
     obj.fields = deepcopy(self.fields)
     obj.default_ordering = self.default_ordering
     obj.select_related = self.select_related
     # Column-name cache: reuse already resolved names (a copy, not shared).
     obj._field_names = dict(self._field_names)
     obj.target_collection_prefix = self.target_collection_prefix
     return obj
Esempio n. 8
0
 def __init__(self, model):
     # Model this query operates on.
     self.model = model
     # Filter tree; starts empty.
     self.where = Node()
     # Explicit ordering, plus whether the model's Meta.ordering may apply.
     self.order_by = []
     self.default_ordering = True
     # Selected field names (empty set means "no explicit projection").
     self.fields = set()
     self.select_related = True  # Hack for admin-interface: django/contrib/admin/views/main.py line 210
     # Per-instance cache of field name -> column name.
     self._field_names = {}
     self.target_collection_prefix = None
Esempio n. 9
0
 def need_having(self, obj):  # pragma: no cover
     """Tell older Django versions whether ``obj`` belongs in HAVING.

     A queryable property may add an aggregate-based annotation while its
     filter is applied. So any condition that references such a property
     answers True. Everything else is delegated to the base implementation,
     whose name differs between Django versions (NEED_HAVING_METHOD_NAME).
     """
     # The checker works on a Node tree, so wrap a bare condition.
     tree = obj if isinstance(obj, Node) else Node([obj])
     if aggregate_property_checker.check_leaves(tree, model=self.model):
         return True
     parent = super(QueryablePropertiesQueryMixin, self)
     return getattr(parent, NEED_HAVING_METHOD_NAME)(obj)
Esempio n. 10
0
 def test_hash(self):
     """Hash depends on children, negation and connector, not sequence type."""
     negated = Node(self.node1_children, negated=True)
     other_connector = Node(self.node1_children, connector="OTHER")
     twin = Node(self.node1_children)
     list_children = Node([["a", 1], ["b", 2]])
     list_value = Node([("a", [1, 2])])
     tuple_value = Node([("a", (1, 2))])
     base = hash(self.node1)
     for different in (self.node2, negated, other_connector):
         self.assertNotEqual(base, hash(different))
     for same in (twin, list_children):
         self.assertEqual(base, hash(same))
     self.assertEqual(hash(self.node2), hash(Node()))
     self.assertEqual(hash(list_value), hash(tuple_value))
Esempio n. 11
0
 def test_hash(self):
     """Hash depends on children, negation and connector, not sequence type."""
     negated = Node(self.node1_children, negated=True)
     other_connector = Node(self.node1_children, connector='OTHER')
     twin = Node(self.node1_children)
     list_children = Node([['a', 1], ['b', 2]])
     list_value = Node([('a', [1, 2])])
     tuple_value = Node([('a', (1, 2))])
     base = hash(self.node1)
     for different in (self.node2, negated, other_connector):
         self.assertNotEqual(base, hash(different))
     for same in (twin, list_children):
         self.assertEqual(base, hash(same))
     self.assertEqual(hash(self.node2), hash(Node()))
     self.assertEqual(hash(list_value), hash(tuple_value))
Esempio n. 12
0
 def test_eq_negated(self):
     """Nodes differing only in negation are not equal."""
     plain, flipped = Node(negated=False), Node(negated=True)
     self.assertNotEqual(flipped, plain)
Esempio n. 13
0
 def test_eq_connector(self):
     """Equality takes the connector into account."""
     custom, default = Node(connector="NEW"), Node(connector="DEFAULT")
     # An explicit "DEFAULT" connector matches a plain empty Node().
     self.assertEqual(default, self.node2)
     self.assertNotEqual(default, custom)
Esempio n. 14
0
 def test_eq_children(self):
     """Equality takes the children into account."""
     rebuilt = Node(self.node1_children)
     self.assertEqual(rebuilt, self.node1)
     self.assertNotEqual(rebuilt, self.node2)
Esempio n. 15
0
class MongoQuery(object):
    """Query object that collects filters, ordering and field selection for a
    Django model and runs them against the model's MongoDB collection.

    Filters are kept as a ``Node`` tree in ``self.where`` (built by ``add_q``)
    and translated into a MongoDB query document by the ``spec`` property.
    """
    AND = 'AND'
    OR = 'OR'
    COLON = ':'
    DEFAULT = 'DEFAULT'

    def __init__(self, model):
        self.model = model
        self.where = Node()
        self.order_by = []
        self.fields = set()
        self.default_ordering = True
        self.select_related = True  # Hack for admin-interface: django/contrib/admin/views/main.py line 210
        self._field_names = {}  # cache: field name -> column name
        self.target_collection_prefix = None

    def _setTargetCollectionPrefix(self, prefix_str=None):
        """
        This method enables redirecting the entire model onto another collection.
        Note: this should be used VERY carefully as it will redirect all reading and writing.
        Use prefix_str=None to cancel the redirection.
        E.g. if prefix_str='acme_' and the db name is 'customers', a new db will be used, called 'acme_customers'.
        """
        self.target_collection_prefix = prefix_str

    def _column_name(self, field_name):
        """Return the column name for ``field_name``, or a list of names when
        given a list/tuple. ``'pk'`` resolves to the model's primary key.
        Results are cached in ``self._field_names``."""
        if isinstance(field_name, (list, tuple)):
            return [self._column_name(f) for f in field_name]

        if field_name not in self._field_names:
            opts = self.model._meta
            if field_name == 'pk':
                field = opts.pk
            else:
                field = opts.get_field(field_name)
            # Django's get_attname_column() returns (attname, column).
            self._field_names[field_name] = field.get_attname_column()[1]

        return self._field_names[field_name]

    def clear_ordering(self, force_empty=False):
        """Drop explicit ordering; with ``force_empty`` also disable the
        model's default ordering."""
        self.order_by = []
        if force_empty:
            self.default_ordering = False

    def add_ordering(self, ordering):
        """Append ``ordering`` items; an empty ordering disables the model's
        default ordering."""
        if ordering:
            self.order_by.extend(ordering)
        else:
            self.default_ordering = False

    def can_filter(self):
        """
        The SQL backend returns False if the result was already fetched;
        this implementation always allows further filtering.
        """
        return True

    def set_fields(self, fields):
        """Add ``fields`` to the set of selected field names."""
        self.fields.update(set(fields))

    def get_fields(self):
        """Return the selected field names as a list."""
        return list(self.fields)

    def clone(self):
        """Return an independent copy of this query.

        The filter tree, ordering and field selection are deep-copied. The
        ordering/select flags are carried over too. Otherwise a query whose
        default ordering was disabled (``clear_ordering(force_empty=True)`` or
        ``add_ordering`` with an empty ordering) would regain the model's
        default ordering on every clone.
        """
        obj = self.__class__(self.model)
        obj.where = deepcopy(self.where)
        obj.order_by = deepcopy(self.order_by)
        obj.fields = deepcopy(self.fields)
        obj.default_ordering = self.default_ordering
        obj.select_related = self.select_related
        obj._field_names = dict(self._field_names)
        obj.target_collection_prefix = self.target_collection_prefix
        return obj

    def add_q(self, q_object):
        """
        Adds a Q-object to the current filter.

        Can also be used to add anything that has an 'add_to_query()' method.
        See also the method / property 'spec', that operates on the tree structure created in this method,
            kept in self.where
        """
        if hasattr(q_object, 'add_to_query'):
            # Complex custom objects are responsible for adding themselves.
            q_object.add_to_query(self)
        else:
            if self.where and q_object.connector != self.AND and len(q_object) > 1:
                self.where.start_subtree(self.AND)
                subtree = True
            else:
                subtree = False
            connector = self.AND
            for child in q_object.children:
                self.where.start_subtree(connector)
                if isinstance(child, Node):
                    self.add_q(child)
                else:
                    mq = MongoQ(False, child[0], child[1])
                    self.where.add(mq, self.COLON)
                self.where.end_subtree()
                connector = q_object.connector
            if q_object.negated:
                self.where.negate()
            if subtree:
                self.where.end_subtree()

    def get_count(self):
        """Number of documents matched by this query."""
        return self.get_result().count()

    def has_results(self):
        """True if at least one document matches."""
        return bool(self.get_count())

    def get_result(self):
        """Run the query and return a (possibly sorted) cursor."""
        spec = self.spec
        # No projection (None) when no fields were selected.
        fields = self._column_name(list(self.fields)) if self.fields else None
        cursor = self.collection.find(spec, fields=fields)

        sort = self.sort
        if sort:
            cursor = cursor.sort(sort)

        return cursor

    def update(self, values):
        """``$set`` the given field -> value pairs on all matching documents."""
        data = {}
        for key, val in values.items():
            data[self._column_name(key)] = fix_value(val)

        self.collection.update(self.spec, {"$set": data})

    def delete(self):
        """Remove all matching documents."""
        self.collection.remove(self.spec)

    @property
    def sort(self):
        """List of ``(column, direction)`` pairs. Uses the explicit ordering,
        or the model's Meta.ordering when none was given and default ordering
        is enabled. A leading '-' means descending."""
        output = []

        default_ordering = self.model._meta.ordering

        if not self.order_by and self.default_ordering and default_ordering:
            order_by = default_ordering
        else:
            order_by = self.order_by

        for item in order_by:
            if item.startswith('-'):
                output.append((self._column_name(item[1:]), pymongo.DESCENDING))
            else:
                output.append((self._column_name(item), pymongo.ASCENDING))

        return output

    def _process_where_node(self, oq, spec):
        """Translate the where-tree node ``oq`` into a MongoDB query dict.

        COLON nodes hold one MongoQ leaf whose condition is merged into
        ``spec``; two dict conditions on one column are combined, otherwise
        MongoQuerySetExeption is raised. AND nodes merge their children's
        specs into one dict. OR nodes build ``{'$or': [...]}`` (possibly
        folded into ``$in``). Negated AND/OR results go through
        ``negate_where_clause``. Any other connector returns ``spec`` as is.
        """
        if oq.connector == self.COLON:
            q = oq.children[0]
            # Result unused; the lookup presumably validates the field name
            # and warms the column-name cache.
            self._column_name(q._get_field_first_part())
            column = q.getDotNotationFieldName(self.model)
            qspec = q.spec(self.model)

            if column not in spec:
                spec[column] = qspec
            else:
                existing = spec[column]
                if not isinstance(existing, dict) or not isinstance(qspec, dict):
                    raise MongoQuerySetExeption('Does not support = and <,>,=<,=> in one query (column:%s)' % column)
                existing.update(qspec)
        elif oq.connector == self.AND:
            merged = {}
            # Shared by all children so conditions on the same column combine.
            shared = {}
            for a_item in oq.children:
                if a_item.connector == self.DEFAULT and len(a_item.children) == 1:
                    a_item = a_item.children[0]
                if a_item.connector != self.DEFAULT:
                    merged.update(self._process_where_node(a_item, shared))
            spec = merged
            if oq.negated:
                spec = negate_where_clause(spec)
        elif oq.connector == self.OR:
            clauses = []
            for a_item in oq.children:
                if a_item.connector != self.DEFAULT:
                    clauses.append(self._process_where_node(a_item, {}))
            spec = self._or_operation_optimization({'$or': clauses})
            if oq.negated:
                spec = negate_where_clause(spec)
        return spec

    def _or_operation_optimization(self, spec):
        """
        This method creates a very simple optimization to the case where there is a complete $or where statement,
        where all the items are in a form that can be replaced with an $in statement.
        For example, the following can be optimized:
          {'$or': [{'time_on_site': 10.0}, {'time_on_site': 25.0}, {'time_on_site': 275.0}, {'time_on_site': 148.0}]}
        However, the following CAN NOT be optimized:
          {'$or': [{'time_on_site': 10.0}, {'time_on_site': 25.0}, {'time_on_site': 275.0}, {'visit_count': 12.0}]}
          (because not all parameters are identical)
        Nor can items keyed by an operator (e.g. a nested {'$or': [...]}), items whose value is an
        operator dict, items with several keys, or an empty $or list.
        Returns the original spec if it can't be optimized, or a new $in where statement in the format:
          { x : { $in : [ a, b ] } }
        See also http://www.mongodb.org/display/DOCS/OR+operations+in+query+expressions for more information
        """
        param = None
        values = []
        for item in spec['$or']:
            if len(item) != 1:
                return spec
            key = next(iter(item))
            # Operator keys ($or, $and, $nor, $where, ...) are not field names.
            if key.startswith('$'):
                return spec
            if param is None:
                param = key
            elif key != param:
                return spec
            if isinstance(item[key], dict):
                return spec
            values.append(item[key])

        if param is None:
            # Empty $or: nothing to fold.
            return spec
        return {param: {'$in': values}}

    @property
    def spec(self):
        """MongoDB query document for the current filters, or None if the
        query is unfiltered."""
        if not self.where:
            return None
        return self._process_where_node(self.where, {})

    @property
    def collection(self):
        """Collection for this model, honouring the target collection prefix."""
        return get_collection(self.model, self.target_collection_prefix)

    def _prepare_before_insert(self, doc):
        """Rename a document's (or list of documents') keys from field names
        to column names."""
        if isinstance(doc, (list, tuple)):
            return [self._prepare_before_insert(d) for d in doc]

        data = {}
        for name, value in doc.items():
            data[self._column_name(name)] = value
        return data

    def insert(self, docs):
        """Insert one document or a list of documents (``safe=True`` write)."""
        return self.collection.insert(self._prepare_before_insert(docs), safe=True)
Esempio n. 16
0
 def setUp(self):
     """Build a two-child node and an empty node for each test."""
     children = [("a", 1), ("b", 2)]
     self.node1_children = children
     self.node1 = Node(children)
     self.node2 = Node()
Esempio n. 17
0
 def setUp(self):
     """Build a two-child node and an empty node for each test."""
     children = [('a', 1), ('b', 2)]
     self.node1_children = children
     self.node1 = Node(children)
     self.node2 = Node()
Esempio n. 18
0
class MongoQuery(object):
    """Query object that collects filters, ordering and field selection for a
    Django model and runs them against the model's MongoDB collection.

    Filters are kept as a ``Node`` tree in ``self.where`` (built by ``add_q``)
    and translated into a MongoDB query document by the ``spec`` property.
    """
    AND = 'AND'
    OR = 'OR'
    COLON = ':'
    DEFAULT = 'DEFAULT'

    def __init__(self, model):
        self.model = model
        self.where = Node()
        self.order_by = []
        self.fields = set()
        self.default_ordering = True
        self.select_related = True  #Hack for admin-interface: django/contrib/admin/views/main.py line 210
        self._field_names = {}  # cache: field name -> column name
        self.target_collection_prefix = None

    def _setTargetCollectionPrefix(self, prefix_str=None):
        """
        This method enables redirecting the entire model onto another collection.
        Note: this should be used VERY carefully as it will redirect all reading and writing.
        Use prefix_str=None to cancel the redirection.
        E.g. if prefix_str='acme_' and the db name is 'customers', a new db will be used, called 'acme_customers'.
        """
        self.target_collection_prefix = prefix_str

    def _column_name(self, field_name):
        """Return the column name for ``field_name``, or a list of names when
        given a list/tuple. ``'pk'`` resolves to the model's primary key.
        Results are cached in ``self._field_names``."""
        if isinstance(field_name, (list, tuple)):
            return [self._column_name(f) for f in field_name]

        if field_name not in self._field_names:
            opts = self.model._meta
            if field_name == 'pk':
                field = opts.pk
            else:
                field = opts.get_field(field_name)
            # Django's get_attname_column() returns (attname, column).
            self._field_names[field_name] = field.get_attname_column()[1]

        return self._field_names[field_name]

    def clear_ordering(self, force_empty=False):
        """Drop explicit ordering; with ``force_empty`` also disable the
        model's default ordering."""
        self.order_by = []
        if force_empty:
            self.default_ordering = False

    def add_ordering(self, ordering):
        """Append ``ordering`` items; an empty ordering disables the model's
        default ordering."""
        if ordering:
            self.order_by.extend(ordering)
        else:
            self.default_ordering = False

    def can_filter(self):
        """
        The SQL backend returns False if the result was already fetched;
        this implementation always allows further filtering.
        """
        return True

    def set_fields(self, fields):
        """Add ``fields`` to the set of selected field names."""
        self.fields.update(set(fields))

    def get_fields(self):
        """Return the selected field names as a list."""
        return list(self.fields)

    def clone(self):
        """Return an independent copy of this query.

        The filter tree, ordering and field selection are deep-copied. The
        ordering/select flags are carried over too. Otherwise a query whose
        default ordering was disabled (``clear_ordering(force_empty=True)``
        or ``add_ordering`` with an empty ordering) would regain the model's
        default ordering on every clone.
        """
        obj = self.__class__(self.model)
        obj.where = deepcopy(self.where)
        obj.order_by = deepcopy(self.order_by)
        obj.fields = deepcopy(self.fields)
        obj.default_ordering = self.default_ordering
        obj.select_related = self.select_related
        obj._field_names = dict(self._field_names)
        obj.target_collection_prefix = self.target_collection_prefix
        return obj

    def add_q(self, q_object):
        """
        Adds a Q-object to the current filter.

        Can also be used to add anything that has an 'add_to_query()' method.
        See also the method / property 'spec', that operates on the tree structure created in this method,
            kept in self.where
        """
        if hasattr(q_object, 'add_to_query'):
            # Complex custom objects are responsible for adding themselves.
            q_object.add_to_query(self)
        else:
            if (self.where and q_object.connector != self.AND
                    and len(q_object) > 1):
                self.where.start_subtree(self.AND)
                subtree = True
            else:
                subtree = False
            connector = self.AND
            for child in q_object.children:
                self.where.start_subtree(connector)
                if isinstance(child, Node):
                    self.add_q(child)
                else:
                    mq = MongoQ(False, child[0], child[1])
                    self.where.add(mq, self.COLON)
                self.where.end_subtree()
                connector = q_object.connector
            if q_object.negated:
                self.where.negate()
            if subtree:
                self.where.end_subtree()

    def get_count(self):
        """Number of documents matched by this query."""
        return self.get_result().count()

    def has_results(self):
        """True if at least one document matches."""
        return bool(self.get_count())

    def get_result(self):
        """Run the query and return a (possibly sorted) cursor."""
        spec = self.spec
        # No projection (None) when no fields were selected.
        if self.fields:
            fields = self._column_name(list(self.fields))
        else:
            fields = None
        cursor = self.collection.find(spec, fields=fields)

        sort = self.sort
        if sort:
            cursor = cursor.sort(sort)

        return cursor

    def update(self, values):
        """``$set`` the given field -> value pairs on matching documents."""
        data = {}
        for key, val in values.items():
            data[self._column_name(key)] = fix_value(val)

        self.collection.update(self.spec, {"$set": data})

    def delete(self):
        """Remove all matching documents."""
        self.collection.remove(self.spec)

    @property
    def sort(self):
        """List of ``(column, direction)`` pairs. Uses the explicit ordering,
        or the model's Meta.ordering when none was given and default ordering
        is enabled. A leading '-' means descending."""
        output = []

        default_ordering = self.model._meta.ordering

        if not self.order_by and self.default_ordering and default_ordering:
            order_by = default_ordering
        else:
            order_by = self.order_by

        for item in order_by:
            if item.startswith('-'):
                output.append(
                    (self._column_name(item[1:]), pymongo.DESCENDING))
            else:
                output.append((self._column_name(item), pymongo.ASCENDING))

        return output

    def _process_where_node(self, oq, spec):
        """Translate the where-tree node ``oq`` into a MongoDB query dict.

        COLON nodes hold one MongoQ leaf whose condition is merged into
        ``spec``; two dict conditions on one column are combined, otherwise
        MongoQuerySetExeption is raised. AND nodes merge their children's
        specs into one dict. OR nodes build ``{'$or': [...]}`` (possibly
        folded into ``$in``). Negated AND/OR results go through
        ``negate_where_clause``. Any other connector returns ``spec`` as is.
        """
        if oq.connector == self.COLON:
            q = oq.children[0]
            # Result unused; the lookup presumably validates the field name
            # and warms the column-name cache.
            self._column_name(q._get_field_first_part())
            column = q.getDotNotationFieldName(self.model)
            qspec = q.spec(self.model)

            if column not in spec:
                spec[column] = qspec
            else:
                existing = spec[column]
                if (not isinstance(existing, dict)
                        or not isinstance(qspec, dict)):
                    raise MongoQuerySetExeption(
                        'Does not support = and <,>,=<,=> in one query (column:%s)'
                        % column)
                existing.update(qspec)
        elif oq.connector == self.AND:
            merged = {}
            # Shared by all children so conditions on the same column combine.
            shared = {}
            for a_item in oq.children:
                if (a_item.connector == self.DEFAULT
                        and len(a_item.children) == 1):
                    a_item = a_item.children[0]
                if a_item.connector != self.DEFAULT:
                    merged.update(self._process_where_node(a_item, shared))
            spec = merged
            if oq.negated:
                spec = negate_where_clause(spec)
        elif oq.connector == self.OR:
            clauses = []
            for a_item in oq.children:
                if a_item.connector != self.DEFAULT:
                    clauses.append(self._process_where_node(a_item, {}))
            spec = self._or_operation_optimization({'$or': clauses})
            if oq.negated:
                spec = negate_where_clause(spec)
        return spec

    def _or_operation_optimization(self, spec):
        """
        This method creates a very simple optimization to the case where there is a complete $or where statement,
        where all the items are in a form that can be replaced with an $in statement.
        For example, the following can be optimized:
          {'$or': [{'time_on_site': 10.0}, {'time_on_site': 25.0}, {'time_on_site': 275.0}, {'time_on_site': 148.0}]}
        However, the following CAN NOT be optimized:
          {'$or': [{'time_on_site': 10.0}, {'time_on_site': 25.0}, {'time_on_site': 275.0}, {'visit_count': 12.0}]}
          (because not all parameters are identical)
        Nor can items keyed by an operator (e.g. a nested {'$or': [...]}), items whose value is an
        operator dict, items with several keys, or an empty $or list.
        Returns the original spec if it can't be optimized, or a new $in where statement in the format:
          { x : { $in : [ a, b ] } }
        See also http://www.mongodb.org/display/DOCS/OR+operations+in+query+expressions for more information
        """
        param = None
        values = []
        for item in spec['$or']:
            if len(item) != 1:
                return spec
            key = next(iter(item))
            # Operator keys ($or, $and, $nor, $where, ...) are not fields.
            if key.startswith('$'):
                return spec
            if param is None:
                param = key
            elif key != param:
                return spec
            if isinstance(item[key], dict):
                return spec
            values.append(item[key])

        if param is None:
            # Empty $or: nothing to fold.
            return spec
        return {param: {'$in': values}}

    @property
    def spec(self):
        """MongoDB query document for the current filters, or None if the
        query is unfiltered."""
        if not self.where:
            return None
        return self._process_where_node(self.where, {})

    @property
    def collection(self):
        """Collection for this model, honouring the collection prefix."""
        return get_collection(self.model, self.target_collection_prefix)

    def _prepare_before_insert(self, doc):
        """Rename a document's (or list of documents') keys from field names
        to column names."""
        if isinstance(doc, (list, tuple)):
            return [self._prepare_before_insert(d) for d in doc]

        data = {}
        for name, value in doc.items():
            data[self._column_name(name)] = value
        return data

    def insert(self, docs):
        """Insert one document or a list of documents (``safe=True``)."""
        return self.collection.insert(self._prepare_before_insert(docs),
                                      safe=True)
Esempio n. 19
0
 def test_add_eq_child_mixed_connector(self):
     """Adding under a different connector nests the old node as a child."""
     node = Node(['a', 'b'], 'OR')
     added = node.add('a', 'AND')
     self.assertEqual(added, 'a')
     expected = Node([Node(['a', 'b'], 'OR'), 'a'], 'AND')
     self.assertEqual(node, expected)
Esempio n. 20
0
def shallow_copy_Q(q_object):
    """Return a shallow copy of ``q_object``.

    The copy has the same class, connector and negation flag, and shares
    (does not copy) the original's ``children`` list.
    """
    duplicate = Node(negated=q_object.negated, connector=q_object.connector)
    # Rebrand the plain Node as the source object's (sub)class.
    duplicate.__class__ = q_object.__class__
    duplicate.children = q_object.children
    return duplicate
Esempio n. 21
0
 def test_add_eq_child_mixed_connector(self):
     """Adding under a different connector nests the old node as a child."""
     node = Node(["a", "b"], OR)
     added = node.add("a", AND)
     self.assertEqual(added, "a")
     expected = Node([Node(["a", "b"], OR), "a"], AND)
     self.assertEqual(node, expected)
Esempio n. 22
0
 def test_eq_connector(self):
     """Equality takes the connector into account."""
     custom, default = Node(connector='NEW'), Node(connector='DEFAULT')
     # An explicit 'DEFAULT' connector matches a plain empty Node().
     self.assertEqual(default, self.node2)
     self.assertNotEqual(default, custom)
Esempio n. 23
0
class NodeTests(unittest.TestCase):
    """Tests for Node: rendering, hashing, container protocol, add/negate,
    create/copy/deepcopy and equality."""

    def setUp(self):
        children = [("a", 1), ("b", 2)]
        self.node1_children = children
        self.node1 = Node(children)
        self.node2 = Node()

    def test_str(self):
        cases = (
            (self.node1, "(DEFAULT: ('a', 1), ('b', 2))"),
            (self.node2, "(DEFAULT: )"),
        )
        for node, rendered in cases:
            self.assertEqual(str(node), rendered)

    def test_repr(self):
        cases = (
            (self.node1, "<Node: (DEFAULT: ('a', 1), ('b', 2))>"),
            (self.node2, "<Node: (DEFAULT: )>"),
        )
        for node, rendered in cases:
            self.assertEqual(repr(node), rendered)

    def test_hash(self):
        negated = Node(self.node1_children, negated=True)
        other_connector = Node(self.node1_children, connector="OTHER")
        twin = Node(self.node1_children)
        list_children = Node([["a", 1], ["b", 2]])
        list_value = Node([("a", [1, 2])])
        tuple_value = Node([("a", (1, 2))])
        base = hash(self.node1)
        # Children, negation and connector all feed into the hash...
        for different in (self.node2, negated, other_connector):
            self.assertNotEqual(base, hash(different))
        # ...but list vs. tuple containers do not.
        for same in (twin, list_children):
            self.assertEqual(base, hash(same))
        self.assertEqual(hash(self.node2), hash(Node()))
        self.assertEqual(hash(list_value), hash(tuple_value))

    def test_len(self):
        for node, size in ((self.node1, 2), (self.node2, 0)):
            self.assertEqual(len(node), size)

    def test_bool(self):
        # A node with children is truthy; an empty node is falsy.
        self.assertTrue(self.node1)
        self.assertFalse(self.node2)

    def test_contains(self):
        child = ("a", 1)
        self.assertIn(child, self.node1)
        self.assertNotIn(child, self.node2)

    def test_add(self):
        grown = Node(self.node1_children)
        extra = ("c", 3)
        # add() hands back the data it was given.
        returned = grown.add(extra, Node.default)
        self.assertEqual(returned, extra)
        # Exactly one more child than node1, which shares the initial list.
        self.assertEqual(len(self.node1) + 1, len(grown))
        self.assertEqual(str(grown), "(DEFAULT: ('a', 1), ('b', 2), ('c', 3))")

    def test_add_eq_child_mixed_connector(self):
        node = Node(["a", "b"], OR)
        added = node.add("a", AND)
        self.assertEqual(added, "a")
        # The old OR node becomes a child of the new AND root.
        expected = Node([Node(["a", "b"], OR), "a"], AND)
        self.assertEqual(node, expected)

    def test_negate(self):
        node = self.node1
        # Nodes start out non-negated; each negate() call flips the flag.
        self.assertFalse(node.negated)
        node.negate()
        self.assertTrue(node.negated)
        node.negate()
        self.assertFalse(node.negated)

    def test_create(self):
        SubNode = type("SubNode", (Node,), {})

        source = SubNode([SubNode(["a", "b"], OR), "c"], AND)
        built = SubNode.create(source.children, source.connector, source.negated)
        self.assertEqual(source, built)
        # create() gets its own children list, which compares equal.
        self.assertIsNot(source.children, built.children)
        self.assertEqual(source.children, built.children)
        # Nested Node children are reused as-is.
        for mine, theirs in zip(source.children, built.children):
            if isinstance(mine, Node):
                self.assertIs(mine, theirs)
            self.assertEqual(mine, theirs)

    def test_copy(self):
        original = Node([Node(["a", "b"], OR), "c"], AND)
        duplicate = copy.copy(original)
        self.assertEqual(original, duplicate)
        # A shallow copy shares the children list itself...
        self.assertIs(original.children, duplicate.children)
        # ...and therefore the nested Node objects too.
        for mine, theirs in zip(original.children, duplicate.children):
            if isinstance(mine, Node):
                self.assertIs(mine, theirs)
            self.assertEqual(mine, theirs)

    def test_deepcopy(self):
        original = Node([Node(["a", "b"], OR), "c"], AND)
        duplicate = copy.deepcopy(original)
        self.assertEqual(original, duplicate)
        # A deep copy has its own children list, equal to the original.
        self.assertIsNot(original.children, duplicate.children)
        self.assertEqual(original.children, duplicate.children)
        # Nested Node children are copied as well.
        for mine, theirs in zip(original.children, duplicate.children):
            if isinstance(mine, Node):
                self.assertIsNot(mine, theirs)
            self.assertEqual(mine, theirs)

    def test_eq_children(self):
        rebuilt = Node(self.node1_children)
        self.assertEqual(rebuilt, self.node1)
        self.assertNotEqual(rebuilt, self.node2)

    def test_eq_connector(self):
        custom, default = Node(connector="NEW"), Node(connector="DEFAULT")
        self.assertEqual(default, self.node2)
        self.assertNotEqual(default, custom)

    def test_eq_negated(self):
        plain, flipped = Node(negated=False), Node(negated=True)
        self.assertNotEqual(flipped, plain)
def shallow_copy_Q(q_object):
    """Return a shallow copy of ``q_object``.

    The copy has the same class, connector and negation flag, and shares
    (does not copy) the original's ``children`` list.
    """
    duplicate = Node(negated=q_object.negated, connector=q_object.connector)
    # Rebrand the plain Node as the source object's (sub)class.
    duplicate.__class__ = q_object.__class__
    duplicate.children = q_object.children
    return duplicate
Esempio n. 25
0
class NodeTests(unittest.TestCase):
    """Tests for Node: rendering, hashing, container protocol, add/negate,
    copying and equality."""

    def setUp(self):
        children = [('a', 1), ('b', 2)]
        self.node1_children = children
        self.node1 = Node(children)
        self.node2 = Node()

    def test_str(self):
        cases = (
            (self.node1, "(DEFAULT: ('a', 1), ('b', 2))"),
            (self.node2, "(DEFAULT: )"),
        )
        for node, rendered in cases:
            self.assertEqual(str(node), rendered)

    def test_repr(self):
        cases = (
            (self.node1, "<Node: (DEFAULT: ('a', 1), ('b', 2))>"),
            (self.node2, "<Node: (DEFAULT: )>"),
        )
        for node, rendered in cases:
            self.assertEqual(repr(node), rendered)

    def test_hash(self):
        negated = Node(self.node1_children, negated=True)
        other_connector = Node(self.node1_children, connector='OTHER')
        twin = Node(self.node1_children)
        list_children = Node([['a', 1], ['b', 2]])
        list_value = Node([('a', [1, 2])])
        tuple_value = Node([('a', (1, 2))])
        base = hash(self.node1)
        # Children, negation and connector all feed into the hash...
        for different in (self.node2, negated, other_connector):
            self.assertNotEqual(base, hash(different))
        # ...but list vs. tuple containers do not.
        for same in (twin, list_children):
            self.assertEqual(base, hash(same))
        self.assertEqual(hash(self.node2), hash(Node()))
        self.assertEqual(hash(list_value), hash(tuple_value))

    def test_len(self):
        for node, size in ((self.node1, 2), (self.node2, 0)):
            self.assertEqual(len(node), size)

    def test_bool(self):
        # A node with children is truthy; an empty node is falsy.
        self.assertTrue(self.node1)
        self.assertFalse(self.node2)

    def test_contains(self):
        child = ('a', 1)
        self.assertIn(child, self.node1)
        self.assertNotIn(child, self.node2)

    def test_add(self):
        grown = Node(self.node1_children)
        extra = ('c', 3)
        # add() hands back the data it was given.
        returned = grown.add(extra, Node.default)
        self.assertEqual(returned, extra)
        # Exactly one more child than node1, which shares the initial list.
        self.assertEqual(len(self.node1) + 1, len(grown))
        self.assertEqual(str(grown), "(DEFAULT: ('a', 1), ('b', 2), ('c', 3))")

    def test_negate(self):
        node = self.node1
        # Nodes start out non-negated; each negate() call flips the flag.
        self.assertFalse(node.negated)
        node.negate()
        self.assertTrue(node.negated)
        node.negate()
        self.assertFalse(node.negated)

    def test_deepcopy(self):
        shallow = copy.copy(self.node1)
        deep = copy.deepcopy(self.node1)
        # A shallow copy shares the children list; a deep copy owns a new one.
        self.assertIs(self.node1.children, shallow.children)
        self.assertIsNot(self.node1.children, deep.children)

    def test_eq_children(self):
        rebuilt = Node(self.node1_children)
        self.assertEqual(rebuilt, self.node1)
        self.assertNotEqual(rebuilt, self.node2)

    def test_eq_connector(self):
        custom, default = Node(connector='NEW'), Node(connector='DEFAULT')
        self.assertEqual(default, self.node2)
        self.assertNotEqual(default, custom)

    def test_eq_negated(self):
        plain, flipped = Node(negated=False), Node(negated=True)
        self.assertNotEqual(flipped, plain)
 def negate(self):
     """Wrap the current children in one new child node with the opposite
     negation flag, and reset this node's connector to ``self.default``.

     The wrapper is always a plain ``Node``, even when ``self`` is a subclass.
     """
     wrapper = Node(self.children, self.connector, not self.negated)
     self.children = [wrapper]
     self.connector = self.default