class TestIndexAndBulk(object):
    """Check that entities bulk-added after index creation end up indexed."""

    def setup(self):
        """Start every test from a newly constructed in-memory graph."""
        from grapheekdb.backends.data.localmem import LocalMemoryGraph
        self.graph = LocalMemoryGraph()

    def test_bulk_added_nodes_after_index_creation_are_also_indexed(self):
        """Nodes bulk-added once a 'foo' index exists are known to that index."""
        self.graph.add_node_index('foo')
        self.graph.bulk_add_node([dict(foo=value) for value in range(1000)])
        # Hackish: reach into the private list of node indexes and take the
        # first entry (the 'foo' index is the only one this test creates).
        foo_index = self.graph._node_indexes[0]
        # Exactly one of the 1000 nodes carries foo == 42.
        assert foo_index.estimate(None, dict(foo=42)) == 1

    def test_bulk_added_edges_after_index_creation_are_also_indexed(self):
        """Edges bulk-added once a 'foo' index exists are known to that index."""
        self.graph.add_edge_index('foo')
        first, second, third, fourth = [self.graph.add_node() for _ in range(4)]
        # A simple chain first -> second -> third -> fourth, each edge with
        # a distinct 'foo' value.
        chain = [
            (first, second, dict(foo=1)),
            (second, third, dict(foo=2)),
            (third, fourth, dict(foo=3)),
        ]
        self.graph.bulk_add_edge(chain)
        # Hackish: reach into the private list of edge indexes and take the
        # first entry (the 'foo' index is the only one this test creates).
        foo_index = self.graph._edge_indexes[0]
        assert foo_index.estimate(None, dict(foo=1)) == 1
class TLocalMemoryGraph(FillMethod, CommonMethods):  # Not using "Test" in names so that I can import it elsewhere without running tests 2 times

    def setup(self):
        """Create a new in-memory graph and populate it through self.fill()."""
        from grapheekdb.backends.data.localmem import LocalMemoryGraph
        fresh_graph = LocalMemoryGraph()
        self.graph = fresh_graph
        # fill() is not defined in this class body; presumably inherited from
        # FillMethod - TODO confirm.
        self.fill()

    # Type checking :

    def test_self_graph_is_a_base_graph(self):
        """The graph built by setup() is an instance of BaseGraph."""
        graph = self.graph
        assert isinstance(graph, BaseGraph)

    def test_n1_is_a_node(self):
        """The n1 fixture is an instance of Node."""
        fixture = self.n1
        assert isinstance(fixture, Node)

    def test_n2_is_a_node(self):
        """The n2 fixture is an instance of Node."""
        fixture = self.n2
        assert isinstance(fixture, Node)

    def test_e1_is_a_edge(self):
        """The e1 fixture is an instance of Edge."""
        fixture = self.e1
        assert isinstance(fixture, Edge)

    # Type checking after a lookup on iterator :

    def test_node_lookup_result_are_node(self):
        """Every item yielded by a filtered V() lookup is a Node."""
        matches = self.graph.V(foo__gte=1)
        # Guard against a vacuous pass: with zero matches the loop below
        # would never execute its assertion.
        assert matches.count()
        for found in matches:
            assert isinstance(found, Node)

    def test_edge_lookup_result_are_edge(self):
        """Every item yielded by a filtered E() lookup is an Edge."""
        matches = self.graph.E(common__gte=1)
        # Guard against a vacuous pass: with zero matches the loop below
        # would never execute its assertion.
        assert matches.count()
        for found in matches:
            assert isinstance(found, Edge)

    # Test invalid traversals :

    def test_edge_oe_is_forbidden(self):
        """Calling outE() on an edge lookup raises GrapheekNoSuchTraversalException."""
        try:
            self.graph.E(label='knows').outE().count()
        except GrapheekNoSuchTraversalException:
            refused = True
        else:
            refused = False
        assert refused

    def test_edge_ie_is_forbidden(self):
        """Calling inE() on an edge lookup raises GrapheekNoSuchTraversalException."""
        try:
            self.graph.E(label='knows').inE().count()
        except GrapheekNoSuchTraversalException:
            refused = True
        else:
            refused = False
        assert refused

    def test_edge_be_is_forbidden(self):
        """Calling bothE() on an edge lookup raises GrapheekNoSuchTraversalException."""
        try:
            self.graph.E(label='knows').bothE().count()
        except GrapheekNoSuchTraversalException:
            refused = True
        else:
            refused = False
        assert refused

    def test_node_oeoe(self):
        """outE() followed by outE() raises GrapheekNoSuchTraversalException."""
        try:
            self.n2.outE().outE()
        except GrapheekNoSuchTraversalException:
            refused = True
        else:
            refused = False
        assert refused

    def test_node_oeie(self):
        """outE() followed by inE() raises GrapheekNoSuchTraversalException."""
        try:
            self.n3.outE().inE()
        except GrapheekNoSuchTraversalException:
            refused = True
        else:
            refused = False
        assert refused

    def test_node_ieoe(self):
        """inE() followed by outE() raises GrapheekNoSuchTraversalException."""
        try:
            self.n2.inE().outE()
        except GrapheekNoSuchTraversalException:
            refused = True
        else:
            refused = False
        assert refused

    def test_node_ieie(self):
        """inE() followed by inE() raises GrapheekNoSuchTraversalException."""
        try:
            # The second hop must be inE(): this test previously called
            # outE(), which made it an exact duplicate of test_node_ieoe and
            # left the inE().inE() chain untested.
            self.n2.inE().inE()
        except GrapheekNoSuchTraversalException:
            exception_raised = True
        else:
            exception_raised = False
        assert exception_raised

    def test_node_beoe(self):
        """bothE() followed by outE() raises GrapheekNoSuchTraversalException."""
        try:
            self.n2.bothE().outE()
        except GrapheekNoSuchTraversalException:
            refused = True
        else:
            refused = False
        assert refused

    def test_node_beie(self):
        """bothE() followed by inE() raises GrapheekNoSuchTraversalException."""
        try:
            # The second hop must be inE(): this test previously called
            # outE(), which made it an exact duplicate of test_node_beoe and
            # left the bothE().inE() chain untested.
            self.n2.bothE().inE()
        except GrapheekNoSuchTraversalException:
            exception_raised = True
        else:
            exception_raised = False
        assert exception_raised

    # test invalid lookups

    def test_filter_node_invalid_no_such_clause_1(self):
        """An unknown lookup suffix on graph.V() raises GrapheekInvalidLookupException."""
        try:
            self.graph.V(foo__xxxxxxx=1).count()
        except GrapheekInvalidLookupException:
            rejected = True
        else:
            rejected = False
        assert rejected

    def test_filter_node_invalid_no_such_clause_2(self):
        """An unknown lookup suffix on node.outV() raises GrapheekInvalidLookupException."""
        try:
            self.n1.outV(foo__xxxxxxx=1).count()
        except GrapheekInvalidLookupException:
            rejected = True
        else:
            rejected = False
        assert rejected

    def test_filter_node_invalid_no_subfield_lookup_1(self):
        """A sub-field lookup on graph.V() raises GrapheekSubLookupNotImplementedException."""
        # May change in the future: supporting lookups such as
        # field__attr__gt=1 should not be very complicated.
        try:
            self.graph.V(foo__subfoo__gt=1).count()
        except GrapheekSubLookupNotImplementedException:
            rejected = True
        else:
            rejected = False
        assert rejected

    def test_filter_node_invalid_no_subfield_lookup_2(self):
        """A sub-field lookup on bothV() raises GrapheekSubLookupNotImplementedException."""
        # May change in the future: supporting lookups such as
        # field__attr__gt=1 should not be very complicated.
        try:
            self.graph.V(name='Raf').bothV(foo__xxx__gt=1).count()
        except GrapheekSubLookupNotImplementedException:
            rejected = True
        else:
            rejected = False
        assert rejected

    def test_index_on_lookup(self):
        """Smoke test: a traversal filtered on an indexed field runs without raising."""
        total = CHUNK_SIZE + 100
        documents = self.graph.bulk_add_node(
            [{'document_id': doc_id} for doc_id in range(total)]
        )
        self.graph.add_node_index('document_id')
        # Link self.n1 to every freshly added node.
        self.graph.bulk_add_edge([(self.n1, document, {}) for document in documents])
        # The query itself is the test: there is no assertion on its result.
        self.graph.V(name='Raf').outV(document_id=500).count()

    def test_node_index_is_used_to_reduce_calls(self):
        """An exact node lookup makes far fewer profiled calls once indexed."""
        import cProfile
        import pstats
        # Dunno how to do this in a clean way, so this test is a workaround:
        # it only checks that there are far fewer calls when an index exists.
        # TODO : This test must be recoded or removed

        def profiled_lookup_calls():
            # Run the exact lookup under cProfile and report its total calls.
            profiler = cProfile.Profile()
            profiler.enable()
            self.graph.V(document_id=500).count()
            profiler.disable()
            return pstats.Stats(profiler).total_calls

        total = CHUNK_SIZE + 100
        self.graph.bulk_add_node([{'document_id': doc_id} for doc_id in range(total)])
        # The hack part: measure once WITHOUT the index, then once WITH it.
        total_calls_without_index = profiled_lookup_calls()
        self.graph.add_node_index('document_id')
        total_calls_with_index = profiled_lookup_calls()
        assert total_calls_with_index < total_calls_without_index / 20  # 20 is totally subjective

    def test_node_index_is_used_to_reduce_calls_in_lookup(self):
        """An exact node lookup makes far fewer profiled calls once indexed."""
        import cProfile
        import pstats
        # Dunno how to do this in a clean way, so this test is a workaround:
        # it only checks that there are far fewer calls when an index exists.
        # TODO : This test must be recoded or removed

        def profiled_lookup_calls():
            # Run the exact lookup under cProfile and report its total calls.
            profiler = cProfile.Profile()
            profiler.enable()
            self.graph.V(document_id=500).count()
            profiler.disable()
            return pstats.Stats(profiler).total_calls

        total = CHUNK_SIZE + 100
        self.graph.bulk_add_node([dict(document_id=doc_id) for doc_id in range(total)])
        # The hack part: measure once WITHOUT the index, then once WITH it.
        total_calls_without_index = profiled_lookup_calls()
        self.graph.add_node_index('document_id')
        total_calls_with_index = profiled_lookup_calls()
        assert total_calls_with_index < total_calls_without_index / 20  # 20 is totally subjective

    def test_edge_index_is_used(self):
        """An exact edge lookup makes far fewer profiled calls once indexed."""
        import cProfile
        import pstats
        # Dunno how to do this in a clean way, so this test is a workaround:
        # it only checks that there are far fewer calls when an index exists.
        # TODO : This test must be recoded or removed

        def profiled_lookup_calls():
            # Run the exact edge lookup under cProfile and report its total calls.
            profiler = cProfile.Profile()
            profiler.enable()
            self.graph.E(counter=500).count()
            profiler.disable()
            return pstats.Stats(profiler).total_calls

        total = CHUNK_SIZE + 100
        nodes = self.graph.bulk_add_node(
            [dict(document_id=doc_id) for doc_id in range(total)]
        )
        # Chain each node to its successor; every edge stores its position in
        # the chain under 'counter'.
        chain = [
            (start, end, dict(counter=position))
            for position, (start, end) in enumerate(zip(nodes, nodes[1:]))
        ]
        self.graph.bulk_add_edge(chain)
        # The hack part: measure once WITHOUT the index, then once WITH it.
        total_calls_without_index = profiled_lookup_calls()
        self.graph.add_edge_index('counter')
        total_calls_with_index = profiled_lookup_calls()
        assert total_calls_with_index < total_calls_without_index / 20  # 20 is totally subjective

    def test_index_multi_chunk(self):
        """An index still returns every id when one value has more than CHUNK_SIZE ids."""
        count = CHUNK_SIZE + 100
        # Give every node the same document_id so that a single index entry
        # holds more than <CHUNK_SIZE> entity ids.
        nodes = self.graph.bulk_add_node([dict(document_id=1) for _ in range(count)])
        node_ids = [node.get_id() for node in nodes]
        self.graph.add_node_index('document_id')
        # Forcing index use: pick the index covering exactly 'document_id'.
        good_index = None
        for index in self.graph._node_indexes:
            if index._fields == ['document_id']:
                good_index = index
        assert good_index is not None
        # Query the index that was just located. The previous code used the
        # leaked loop variable `index`, i.e. whichever index happened to be
        # last in the list, not necessarily the 'document_id' one.
        indexed_ids = list(good_index.ids(None, dict(document_id=1)))
        assert set(node_ids) == set(indexed_ids)
        assert len(indexed_ids) > CHUNK_SIZE  # maybe overkill (?)

    def test_index_tells_that_it_is_incompetent(self):
        """estimate() is -1 on every node index for a filter on an unindexed key."""
        missing_key = 'qklnfmqvnmkljnsfklsmlkdfjqslkdfjlkj'
        # Make sure the loop below has at least one index to question.
        self.graph.add_node_index('foobar')
        for node_index in self.graph._node_indexes:
            estimation = node_index.estimate(None, {missing_key: 1})
            assert estimation == -1

    def test_index_returns_none_when_it_is_incompetent(self):
        """ids() raises GrapheekIncompetentIndexException on every node index for an unindexed key.

        Despite the test's name, what is asserted is that an exception is
        raised, not that None is returned.
        """
        from grapheekdb.lib.exceptions import GrapheekIncompetentIndexException
        missing_key = 'qklnfmqvnmkljnsfklsmlkdfjqslkdfjlkj'
        # Make sure the loop below has at least one index to question.
        self.graph.add_node_index('foobar')
        outcomes = []
        for node_index in self.graph._node_indexes:
            try:
                # list() forces evaluation of whatever ids() returns.
                list(node_index.ids(None, {missing_key: 1}))
            except GrapheekIncompetentIndexException:
                refused = True
            else:
                refused = False
            outcomes.append(refused)
        assert all(outcomes)

    def test_index_returns_empty_list_when_no_element_match_criteria(self):
        """ids() on the 'foobar' index yields nothing when no element matches."""
        # Make sure the index under test exists.
        self.graph.add_node_index('foobar')
        # Forcing index use: select the index covering exactly 'foobar'.
        candidates = [
            node_index
            for node_index in self.graph._node_indexes
            if node_index._fields == ['foobar']
        ]
        good_index = candidates[-1] if candidates else None
        assert good_index is not None
        matching_ids = list(good_index.ids(None, dict(foobar=1)))
        assert matching_ids == []

    def test_index_estimate_returns_zero_when_no_element_match_criteria(self):
        """estimate() on the 'foobar' index is 0 when no element matches."""
        # Make sure the index under test exists.
        self.graph.add_node_index('foobar')
        # Forcing index use: select the index covering exactly 'foobar'.
        candidates = [
            node_index
            for node_index in self.graph._node_indexes
            if node_index._fields == ['foobar']
        ]
        good_index = candidates[-1] if candidates else None
        assert good_index is not None
        estimation = good_index.estimate(None, dict(foobar=1))
        assert estimation == 0

    # test index removal

    def test_index_removal(self):
        """After adding then removing an index, a lookup costs about as many calls as before."""
        import cProfile
        import pstats

        def profiled_lookup_calls():
            # Run the exact lookup under cProfile and report its total calls.
            profiler = cProfile.Profile()
            profiler.enable()
            self.graph.V(document_id=500).count()
            profiler.disable()
            return pstats.Stats(profiler).total_calls

        total = CHUNK_SIZE + 1000
        self.graph.bulk_add_node([dict(document_id=doc_id) for doc_id in range(total)])
        # The hack part: baseline measurement, no index has been created yet.
        total_calls_without_index1 = profiled_lookup_calls()
        # Add the index and remove it straight away, then measure again.
        self.graph.add_node_index('document_id')
        self.graph.remove_node_index('document_id')
        total_calls_without_index2 = profiled_lookup_calls()
        # Hum... bad test :( - the two call counts must agree within 10%.
        ratio = float(total_calls_without_index1) / float(total_calls_without_index2)
        assert 0.90 < ratio < 1.10

    def test_unexisting_index_removal(self):
        """Removing a node index that does not exist raises GrapheekIndexRemovalFailedException."""
        try:
            # NOTE(review): called with no field names at all - presumably
            # deliberate, but confirm that naming a missing field was not
            # the intent.
            self.graph.remove_node_index()
        except GrapheekIndexRemovalFailedException:
            removal_failed = True
        else:
            removal_failed = False
        assert removal_failed

    # test index multiple addition

    def test_node_index_multiple_addition(self):
        """Adding the same node index twice raises GrapheekIndexAlreadyExistsException."""
        # The first addition is outside the try block: it must succeed.
        self.graph.add_node_index('foo')
        try:
            self.graph.add_node_index('foo')
        except GrapheekIndexAlreadyExistsException:
            duplicate_rejected = True
        else:
            duplicate_rejected = False
        assert duplicate_rejected

    def test_edge_index_multiple_addition(self):
        """Adding the same edge index twice raises GrapheekIndexAlreadyExistsException."""
        # The first addition is outside the try block: it must succeed.
        self.graph.add_edge_index('foo')
        try:
            self.graph.add_edge_index('foo')
        except GrapheekIndexAlreadyExistsException:
            duplicate_rejected = True
        else:
            duplicate_rejected = False
        assert duplicate_rejected

    # Test edge addition by node ids :

    def test_add_edge_by_node_ids(self):
        """add_edge_by_ids() links the two nodes by id and stores the given data."""
        payload = dict(foo='test_add_edge_by_node_ids')
        first_id = self.n1.get_id()
        second_id = self.n2.get_id()
        edge = self.graph.add_edge_by_ids(first_id, second_id, **payload)
        # Check it had the same effect as the usual add_edge:
        assert self.n1 in edge.inV()
        assert self.n2 in edge.outV()
        assert payload == edge.data()

    # Testing bulk_add_edge_by_id

    def test_bulk_add_edge_by_ids(self):
        """bulk_add_edge_by_ids() with two definitions adds exactly two edges."""
        # Only checks that nothing raises and that the edge count grows.
        count_before = self.graph.E().count()
        edge_defns = [
            (self.n3.get_id(), self.n1.get_id(), {}),
            (self.n2.get_id(), self.n2.get_id(), {}),  # n2 linked to itself
        ]
        self.graph.bulk_add_edge_by_ids(edge_defns)
        added = self.graph.E().count() - count_before
        assert added == 2

    # Test update_data method :

    def test_data_write_existing_field_using_update_data(self):
        # update_data is useful for client
        from grapheekdb.backends.data.keys import KIND_VERTEX
        self.graph.update_data(KIND_VERTEX, self.n1.get_id(), 'foo', 10)