def testMoreFunctions(self):
        """Parse complete_create with user-supplied function tables.

        Each case swaps one token in the statement and passes the matching
        function mapping; all of them must parse.  Finally, a statement that
        references a function missing from the mapping must raise
        UnknownReferenceError.
        """
        def increment(value):
            return value + 1

        def add(left, right):
            return left + right

        base_statement = self.complete_create

        # (token to replace, replacement, function table passed to the parser)
        cases = [
            ('abs(',   'abs123(', {'abs123': increment}),
            ('round(', 'abs321(', {'abs321': add}),
            ('',       '',        {'abs': increment}),
            ('',       '',        {'abs': increment, 'round': add}),
        ]

        for before, after, function_table in cases:
            statement = base_statement.replace(before, after)
            parse_view_spec(statement, function_table)

        # 'abs1' is referenced but only 'abs2' is provided.
        self.assertRaises(UnknownReferenceError,
                          parse_view_spec,
                          base_statement.replace('abs', 'abs1'),
                          {'abs2': increment})
    def testLists(self):
        """Test: list literals in WHERE ... IN (...) and FOREACH ... IN (...)

        Every entry of ``valid_lists`` must parse in both positions; every
        entry of ``invalid_lists`` must raise ParseException in both.
        """
        where = 'WHERE node.a5 in ($list)'
        where_statement = Template(TestStatements.t_view_def.safe_substitute(
            where=where, qos='', loops='', having=''))

        loop = 'FOREACH i IN ($list)'
        loop_statement = Template(TestStatements.t_view_def.safe_substitute(
            where='', qos='', loops=loop, having=''))

        valid_lists = [
            '1',
            '1:5',
            '"1"',
            '1,2,3,5,6,24',
            '1,2,3,5:6,24',
            '1.5,6.5,7.6',
            '"1", "2", "3","4"',
            'True, False, True, True, False',
            ]

        invalid_lists = [
            # The comma after '1,' is required: without it Python joins the
            # two adjacent literals into a single case '1,1,"1"', and neither
            # the trailing-comma case nor '1,"1"' gets tested on its own.
            '1,',
            '1,"1"',
            '1,2,3,5.5:6,7',
            '1.2,5.5:6.1,7.2',
            '"1", "2", "3","4":"10"',
            '1, True',
            ]

        for statement in (where_statement, loop_statement):
            # 'list_spec' avoids shadowing the builtin 'list'; the template
            # placeholder itself is still named $list.
            for list_spec in valid_lists:
                parse_view_spec(statement.substitute(list=list_spec))
            for list_spec in invalid_lists:
                self.assertRaises(ParseException,
                                  parse_view_spec,
                                  statement.substitute(list=list_spec))
    def testValidStatements(self):
        """Apply one textual substitution at a time to complete_create; each
        modified statement must still parse."""
        base_statement = self.complete_create

        # (token to replace, replacement)
        substitutions = [
            ('CREATE VIEW view', 'CREATE VIEW new_view2'),
            #('SELECT node', 'SELECT message, node'), # -> ParseError
            ('WHERE ', 'WHERE not '),
            (' == ', ' eq '),
            (' == ', ' <> '),
            (' < ', ' lt '),
            (' and ', ' and not '),
            (' or ', ' or not '),
            ('abs(', 'tan('),
            ('abs(', 'abs(var1+'),
            ('lowest(', 'highest('),
            ('lowest(', 'closest(3/4^2,'),
            ('"test"', '"test\'s"'),
            ('10 m', '95 d'),
            ('ON (node.a5)', 'ON (node.a5, node.a2)'),
            ('myconfval', 'myconfval=1*54'),
            ('old_view', 'old1, old2, old3'),
        ]

        for before, after in substitutions:
            modified = base_statement.replace(before, after)
            parse_view_spec(modified)
    def testParseBuckets(self):
        """Parse every statement in create_buckets.

        The first parsed statement's WHERE expression must evaluate to True.
        Every later statement must select exactly one object and name the
        same view as the first one.
        """
        parsed = []
        for bucket_def in self.create_buckets:
            try:
                spec = parse_view_spec(bucket_def)
            except ParseException:
                # Show the offending definition before propagating the error.
                print >>sys.stderr, bucket_def
                raise
            parsed.append(spec)

        first = parsed[0]
        self.assertEqual(True, first.where_expression.evaluate())

        for other in parsed[1:]:
            self.assertEqual(1, len(other.object_select))
            self.assertEqual(first.view_object.name,
                             other.view_object.name)
    def testViewCreation(self):
        "Test: creating views"

        view_def = '''
        CREATE VIEW view_name
        AS SELECT node.a1 = node.a3 - node.a1,
                  node.a2 = node.a1 + node.a2,
                  node.a3 = 2 * node.a5 + node.a3,
                  node.a5,
                  node.a_new = 75
        RANKED lowest(2*5, node.a3*(node.a5^2)+1)
        FROM db
        '''

        where_clause = '''
        WHERE node.a3
           or (1+2*round(node.a1,2) = .1E5+4.E-4)
           or node.a1 == 1 and node.a2 == "test"
           or 1+node.a1 <> 4
        '''

        create       = view_def + ';'
        create_where = view_def + where_clause + ';'

        attr_names = [ 'a%s' % i for i in range(1,10) ]

        db = self.db
        for i in range(9):
            db.addAttribute(attr_names[i], int)

        for a5 in range(-100,100):
            node = DBNode(db)
            for i in range(9):
                setattr(node, attr_names[i], i+1)
            node.a5 = a5
            db.addNode(node)

        sview0 = NodeView( parse_view_spec(create.replace('view_name', 'sview0')), self.viewreg )
        self.assertEqual(2*5, len(sview0.getBucket()))

        sview1 = NodeView( parse_view_spec(create_where.replace('view_name', 'sview1')), self.viewreg )
        self.assertEqual(2*5, len(sview1.getBucket()))

        sview0_bucket = sview0.getBucket()
        for node in sview1.getBucket():
            self.assert_(node in sview0_bucket)

        create_subview = create_where.replace('FROM db', 'FROM sview0')
        sview2 = NodeView( parse_view_spec(create_subview.replace('view_name', 'sview2')), self.viewreg )
        self.assertEqual(2*5, len(sview2.getBucket()))

        create_mergeview = create_where.replace('2*5','2*10').replace('FROM db', 'FROM sview0, sview1')
        sview3 = NodeView( parse_view_spec(create_mergeview.replace('view_name', 'sview3')), self.viewreg )

        # both parent nodes are equal, merging them should filter duplicates
        self.assertEqual(2*5, len(sview3.getBucket()))

##         # merging parents without DISTINCT adds duplicates
##         create_mergeview_dup = create_mergeview.replace('DISTINCT','')
##         sview3_dup = NodeView( parse_view_spec(create_mergeview_dup.replace('view_name', 'sview3_dup')), self.viewreg )

##         self.assertEqual(2*10, len(sview3_dup.getBucket()))

        # verify that nodes only appear once
        nodes = []
        for node in sview3.getBucket():
            self.assert_(node not in nodes)
            nodes.append(node)

        # retry inheriting from all known views
        for i in range(10):
            list_of_known_views = ','.join(view_name for view_name in self.viewreg)
            create_mergeview = create_where.replace('2*5', '2*20').replace('FROM db', 'FROM '+list_of_known_views)
            sview4 = NodeView( parse_view_spec(create_mergeview.replace('view_name', 'sview%02d'%i)), self.viewreg )

        # verify that nodes only appear once
        nodes = []
        for node in sview3.getBucket():
            self.assert_(node not in nodes)
            nodes.append(node)

        nodes_a5 = set()
        for node in sview4.getBucket():
            self.assert_(node.a5 not in nodes_a5)
            nodes_a5.add(node.a5)
    def testViewCircle(self):
        """Test: ring topology
        Each node has one or more neighbours in each direction.

        Creates NODE_COUNT db nodes with ids 0..NODE_COUNT-1.  Each node gets
        two views: '_lt' selects the ``ncount`` closest nodes with a lower id
        and '_gt' the ``ncount`` closest with a higher id, where "closest" is
        the distance around a ring and ids wrap at half the ring.  The test
        then checks bucket sizes and mutual neighbour relations for
        ncount = 1 and ncount = 4.
        """

        node_count = self.NODE_COUNT

        db = self.db
        db.addAttribute('id', int)
        nodes = [ DBNode(db, id=n) for n in range(node_count) ]

        self.db.addNodes(nodes)

        # $cmp is 'lt' or 'gt'.  A node qualifies when it is on the $cmp side
        # of $id and within half the ring, or on the opposite side and farther
        # than half the ring (i.e. it is reached on the $cmp side by wrapping
        # around).
        circle_view_def = Template('''
        CREATE VIEW circle_neighbours_${id}_$cmp
        AS SELECT node.id
        RANKED lowest(ncount, dist($id, node.id))
        FROM db
        WITH ncount=1
        WHERE abs(node.id - $id) <= $maxid / 2  and     node.id $cmp $id
           or abs(node.id - $id) >  $maxid / 2  and not node.id $cmp $id
        ;
        ''')

        # Distance around a ring of max_id positions.  max_id is bound below;
        # the closure reads it only when called.
        def ring_dist(n1, n2):
            return min(abs(n1 - n2), max_id - abs(n1 - n2))

        # makes 'dist' in the view definition resolve to ring_dist
        func_def = {'dist':ring_dist}
        max_id   = node_count

        viewreg = self.viewreg
        for node in nodes:
            view_spec_lt = circle_view_def.substitute(maxid=max_id, id=node.id, cmp='lt')
            view_spec_gt = circle_view_def.substitute(maxid=max_id, id=node.id, cmp='gt')

            spec_lt = parse_view_spec(view_spec_lt, func_def)
            spec_gt = parse_view_spec(view_spec_gt, func_def)

            # views are stored as ad-hoc attributes on the DBNode objects
            node._lt = NodeView(spec_lt, viewreg)
            node._gt = NodeView(spec_gt, viewreg)

        for ncount in (1,4):
            for node in nodes:
                node._lt.setVariable('ncount', ncount)
                node._gt.setVariable('ncount', ncount)

            for node in nodes:
                self.assertEqual(ncount, len(node._lt.getBucket()))
                self.assertEqual(ncount, len(node._gt.getBucket()))

            # test if nodes know their neighbours
            # (same value as self.NODE_COUNT; recomputed here)
            node_count = len(nodes)
            # ncount+1 copies of the node list repeated three times; copy i is
            # shifted by i positions.  Zipping them yields, for each k in
            # 0..node_count-1, the window (nodes[k], nodes[k+1], ...,
            # nodes[k+ncount]) with indices wrapping around the ring.
            neighbour_iterators = [ islice(niter, i, node_count+i)
                                    for (i, niter) in enumerate(tee(chain(nodes,nodes,nodes), ncount+1)) ]

            for neighbours in izip(*neighbour_iterators):
                node = neighbours[0]
                # each following node must be on node's 'gt' side, and node
                # must be on the follower's 'lt' side
                for neighbour in neighbours[1:]:
                    self.assert_(neighbour     in node._gt.getBucket())
                    self.assert_(neighbour not in node._lt.getBucket())
                    self.assert_(node      not in neighbour._gt.getBucket())
                    self.assert_(node          in neighbour._lt.getBucket())

            if DRAW_GRAPHS:
                ViewGraph.write_files( [ (node, (node._lt, node._gt))
                                         for node in nodes ],
                                       program='circo', fname='circle%02d'%ncount )

            if SHOW_VISUAL3D:
                VisualViewGraph( [ (node, (node._lt, node._gt))
                                   for node in nodes ],
                                 program='circo', graph_name='Circle%02d'%ncount )

            # NOTE(review): this runs at the end of every ncount iteration,
            # but the views were created once before the loop.  The ncount=4
            # pass therefore works on views that are no longer registered.
            # Confirm this is intended (it may have been meant to run after
            # the loop).
            self.viewreg.unregister_all()
    def testViewVariables(self):
        """Test: variable usage in views

        Changing a view variable with setVariable() must change that view's
        bucket size and computed attribute values.  It must not change the
        bucket sizes of other views, including a sub-view's parent.
        """

        view_def = '''
        CREATE VIEW view_name
        AS SELECT node.a1 = (node.a3 - node.a1) * var_a,
                  node.a2 = node.a1 + node.a2,
                  node.a3 = 2 * node.a5 + node.a3,
                  node.a5,
                  node.a_new = var_a
        RANKED lowest(var_a*2, var_a*node.a3*(node.a5^2)+1)
        FROM db
        WITH var_a = 5
        '''

        where_clause = '''
        WHERE node.a3
           or (1+2*round(node.a1,2) = .1E5+4.E-4)
           or node.a1 == 1 and node.a2 == "test"
           or 1+node.a1 <> 4
        '''

        create       = view_def + ';'
        create_where = view_def + where_clause + ';'

        attr_names = ['a%s' % n for n in range(1, 10)]

        db = self.db
        for attr in attr_names:
            db.addAttribute(attr, int)

        for a5 in range(-100, 100):
            node = DBNode(db)
            for value, attr in enumerate(attr_names, 1):
                setattr(node, attr, value)
            node.a5 = a5
            db.addNode(node)

        registry = self.viewreg

        def make_view(statement, name):
            # substitute the placeholder name, parse, and register the view
            return NodeView(parse_view_spec(statement.replace('view_name', name)),
                            registry)

        sview0 = make_view(create, 'sview0')
        self.assertEqual(len(sview0.getBucket()), len(sview0))

        # a1 depends on var_a, so it must change once var_a is updated
        first_a1 = sview0.getBucket()[0].a1
        self.assertEqual(2*5, len(sview0))
        sview0.setVariable('var_a', 8)
        self.assertEqual(2*8, len(sview0))
        self.assertNotEqual(first_a1, sview0.getBucket()[0].a1)

        sview1 = make_view(create_where, 'sview1')
        self.assertEqual(2* 5, len(sview1))
        sview1.setVariable('var_a', 16)
        self.assertEqual(2*16, len(sview1))
        self.assertEqual(2* 8, len(sview0))

        # a sub-view cannot hold more nodes than its parent sview0
        sview2 = make_view(create_where.replace('FROM db', 'FROM sview0'), 'sview2')
        sview2.setVariable('var_a', 1000)
        self.assertEqual(2*8,  len(sview0))
        self.assertEqual(2*8,  len(sview2))
        self.assertEqual(2*16, len(sview1))