Beispiel #1
0
    def setUpClass(cls):
        """Open a write transaction on the ``test_schema`` keyspace and build
        the sampled context of the company named 'Google' once for the class.

        Sets ``cls.session``, ``cls._tx`` and ``cls.context``.
        """
        grakn_client = grakn.client.GraknClient(uri="localhost:48555")
        cls.session = grakn_client.session(keyspace="test_schema")
        cls._tx = cls.session.transaction().write()

        # Two hops: sample 4 then 3 neighbours, each with a limit of twice
        # the sample size.
        samplers = [
            samp.Sampler(size, ordered.ordered_sample, limit=size * 2)
            for size in (4, 3)
        ]

        answers = cls._tx.query("match $x isa company, has name 'Google'; get;")
        google = neighbour.build_thing(next(answers).get('x'))

        cls.context = builder.ContextBuilder(samplers).build(cls._tx, google)
Beispiel #2
0
    def test_neighbour_finder_called_with_root_and_neighbour_ids(self):
        """The neighbour finder is queried for the root and for the neighbours
        produced by the first sampler.

        The first sampler yields neighbours with ids "1" and "3"; the second
        sampler returns nothing. The finder must be called (in this order)
        with "0", "1" and "3", each time with the same transaction.
        """
        tx_mock = mock.Mock(grakn.client.Transaction)
        sampler = mock.Mock(samp.Sampler)
        sampler.return_value = mocks.gen([
            mock.MagicMock(neighbour.Connection,
                           role_label="employee",
                           role_direction=1,
                           neighbour_thing=mock.MagicMock(neighbour.Thing,
                                                          id="1")),
            mock.MagicMock(neighbour.Connection,
                           role_label="@has-name-owner",
                           role_direction=1,
                           neighbour_thing=mock.MagicMock(neighbour.Thing,
                                                          id="3")),
        ])
        sampler2 = mock.Mock(samp.Sampler)
        sampler2.return_value = []

        starting_thing = mock.MagicMock(neighbour.Thing, id="0")
        mock_neighbour_finder = mock.MagicMock(neighbour.NeighbourFinder)

        context_builder = builder.ContextBuilder(
            [sampler, sampler2], neighbour_finder=mock_neighbour_finder)

        # The call to assess
        context_builder.build(tx_mock, starting_thing)

        # assert_has_calls (any_order=False) checks these calls appear
        # consecutively in this order among all recorded calls.
        mock_neighbour_finder.find.assert_has_calls([
            mock.call("0", tx_mock),
            mock.call("1", tx_mock),
            mock.call("3", tx_mock)
        ])
Beispiel #3
0
    def test_build_context_for_1_hop(self):
        """A single ordered sampler of size 2 puts the root at depth 1 and its
        two sampled neighbours at depth 0."""
        one_hop_samplers = [samp.Sampler(2, ordered.ordered_sample, limit=2)]
        ctx_builder = builder.ContextBuilder(
            one_hop_samplers, neighbour_finder=mocks.MockNeighbourFinder())
        root = neighbour.Thing("0", "person", "entity")

        actual = ctx_builder.build(mock.Mock(grakn.client.Transaction), root)

        name_node = builder.Node(
            (0, ),
            neighbour.Thing("1", "name", "attribute",
                            data_type='string', value='Sundar Pichai'),
            "has",
            neighbour.NEIGHBOUR_PLAYS)
        employment_node = builder.Node(
            (1, ),
            neighbour.Thing("2", "employment", "relation"),
            "employee",
            neighbour.TARGET_PLAYS)
        root_node = builder.Node((), neighbour.Thing("0", "person", "entity"))

        self.assertEqual({0: [name_node, employment_node], 1: [root_node]},
                         actual)
Beispiel #4
0
    def setUp(self):
        """Build and flatten the sampled neighbourhoods of every company named
        'Google' in the ``test_schema`` keyspace.

        Sets ``self._tx``, ``self._neighbourhood_depths``,
        ``self._neighbour_roles`` and ``self._flattened``.
        """
        grakn_client = grakn.Grakn(uri="localhost:48555")
        session = grakn_client.session(keyspace="test_schema")
        self._tx = session.transaction(grakn.TxType.WRITE)

        # Two hops: sample 4 then 3 neighbours, each with a limit of twice
        # the sample size.
        samplers = [
            samp.Sampler(size, ordered.ordered_sample, limit=size * 2)
            for size in (4, 3)
        ]

        # All answers are read from the query before any concept is used.
        answer_maps = list(self._tx.query("match $x isa company, has name 'Google'; get;"))
        concepts = [answer_map.get('x') for answer_map in answer_maps]
        roots = [neighbour.build_thing(concept) for concept in concepts]

        ctx_builder = builder.ContextBuilder(samplers)
        self._neighbourhood_depths = [ctx_builder.build(self._tx, root) for root in roots]
        self._neighbour_roles = builder.convert_thing_contexts_to_neighbours(self._neighbourhood_depths)
        self._flattened = flatten_tree(self._neighbour_roles)
Beispiel #5
0
def _neighbourhood_traverser_factory(neighbour_sample_sizes,
                                     sampling_method=None,
                                     limit_factor=2):
    """Build a ContextBuilder with one Sampler per hop.

    :param neighbour_sample_sizes: number of neighbours to sample at each hop,
        one entry per hop
    :param sampling_method: sampling function given to every Sampler;
        ``None`` (the default) means ``ordered.ordered_sample``
    :param limit_factor: each Sampler's ``limit`` is
        ``int(sample_size * limit_factor)``; the default of 2 reproduces the
        previous hard-coded ``sample_size * 2``
    :return: a ``builder.ContextBuilder`` over the created samplers
    """
    # Resolved at call time so the module attribute is looked up lazily.
    if sampling_method is None:
        sampling_method = ordered.ordered_sample

    samplers = [
        samp.Sampler(sample_size, sampling_method,
                     limit=int(sample_size * limit_factor))
        for sample_size in neighbour_sample_sizes
    ]
    return builder.ContextBuilder(samplers)
Beispiel #6
0
    def __init__(
        self,
        neighbour_sample_sizes,
        features_size,
        example_concepts_features_size,
        aggregated_size,
        embedding_size,
        schema_encoding_transaction,
        batch_size,
        embedding_normalisation=tf.nn.l2_normalize,
        neighbour_sampling_method=ordered.ordered_sample,
        neighbour_sampling_limit_factor=1,
        formatters=None,
        features_to_exclude=()):
        """Configure the sampling, encoding, embedding and dataset components.

        :param neighbour_sample_sizes: number of neighbours to sample at each hop
        :param features_size: feature size used for every hop except the last
        :param example_concepts_features_size: feature size used for the last
            entry of ``self.feature_sizes``
        :param aggregated_size: passed to ``embed.Embedder``
        :param embedding_size: passed to ``embed.Embedder``
        :param schema_encoding_transaction: transaction given to ``encode.Encoder``
        :param batch_size: stored as ``self.batch_size``
        :param embedding_normalisation: normalisation passed to ``embed.Embedder``
        :param neighbour_sampling_method: sampling function given to every
            ``sample.Sampler``
        :param neighbour_sampling_limit_factor: each sampler's limit is
            ``int(sample_size * neighbour_sampling_limit_factor)``
        :param formatters: mapping stored as ``self._formatters``; ``None``
            (the default) means
            ``{'neighbour_value_date': preprocess.datetime_to_unixtime}``
        :param features_to_exclude: feature names passed to
            ``preprocess.build_dataset`` as keyword arguments set to ``None``
        """
        if formatters is None:
            # A fresh dict per instance: a dict literal as the default value
            # would be one shared, mutable object across every instance.
            formatters = {'neighbour_value_date': preprocess.datetime_to_unixtime}

        self._embedding_normalisation = embedding_normalisation
        self.embedding_size = embedding_size
        self.aggregated_size = aggregated_size
        self.neighbour_sample_sizes = neighbour_sample_sizes

        # One feature size per hop; the last entry is overridden with the
        # example-concepts feature size. Raises IndexError if
        # neighbour_sample_sizes is empty.
        self.feature_sizes = [features_size] * len(self.neighbour_sample_sizes)
        self.feature_sizes[-1] = example_concepts_features_size
        print(f'feature sizes: {self.feature_sizes}')

        self._schema_encoding_transaction = schema_encoding_transaction
        self._encode = encode.Encoder(self._schema_encoding_transaction)

        self.batch_size = batch_size
        self._formatters = formatters
        self._features_to_exclude = features_to_exclude

        traversal_samplers = [
            sample.Sampler(sample_size,
                           neighbour_sampling_method,
                           limit=int(sample_size * neighbour_sampling_limit_factor))
            for sample_size in neighbour_sample_sizes
        ]

        self._array_builder = array.ArrayConverter(neighbour_sample_sizes)

        self._context_builder = builder.ContextBuilder(traversal_samplers)

        self._embed = embed.Embedder(
            self.feature_sizes,
            self.aggregated_size,
            self.embedding_size,
            self.neighbour_sample_sizes,
            normalisation=self._embedding_normalisation)

        # Each excluded feature becomes a keyword argument with value None.
        # Kept under its own name so the `features_to_exclude` argument is
        # not shadowed by a dict of a different type.
        excluded_feature_kwargs = dict.fromkeys(self._features_to_exclude)
        self.neighbourhood_dataset, self.array_placeholders = preprocess.build_dataset(
            self.neighbour_sample_sizes, **excluded_feature_kwargs)
Beispiel #7
0
    def test_build_context_for_0_hop(self):
        """With no samplers, the context holds only the root node at depth 0."""
        ctx_builder = builder.ContextBuilder([], neighbour_finder=mocks.MockNeighbourFinder())
        root = neighbour.Thing("0", "person", "entity")

        actual = ctx_builder.build(mock.Mock(grakn.client.Transaction), root)

        root_only = {0: [builder.Node((), neighbour.Thing("0", "person", "entity"))]}
        self.assertEqual(root_only, actual)
Beispiel #8
0
    def test_input_output(self):
        """Two pass-through samplers reproduce the mock traversal output."""
        hop_sizes = (2, 3)
        # One identity sampler per hop; the sizes only fix the number of hops.
        identity_samplers = [(lambda x: x) for _ in hop_sizes]

        ctx_builder = builder.ContextBuilder(identity_samplers,
                                             neighbour_finder=mocks.MockNeighbourFinder())
        root = neighbour.Thing("0", "person", "entity")

        self.assertEqual(ctx_builder.build(self._tx, root), mocks.mock_traversal_output())
Beispiel #9
0
    def test_build_context_batch(self):
        """With no samplers, build_batch returns a root-only context for each
        thing in the batch."""
        batch = [neighbour.Thing("0", "person", "entity") for _ in range(2)]
        ctx_builder = builder.ContextBuilder([], neighbour_finder=mocks.MockNeighbourFinder())

        contexts = ctx_builder.build_batch(mock.Mock(grakn.client.Session), batch)

        expected = [
            {0: [builder.Node((), neighbour.Thing("0", "person", "entity"))]}
            for _ in range(2)
        ]
        self.assertEqual(expected, contexts)
Beispiel #10
0
    def test_neighbour_finder_called_with_root_node_id(self):
        """With a single sampler that returns nothing, the finder is called
        exactly once: for the root id, with the given transaction."""
        finder = mock.MagicMock(neighbour.NeighbourFinder)
        empty_sampler = mock.Mock(samp.Sampler)
        empty_sampler.return_value = []
        transaction = mock.Mock(grakn.client.Transaction)
        root = mock.MagicMock(neighbour.Thing, id="0")

        ctx_builder = builder.ContextBuilder([empty_sampler], neighbour_finder=finder)

        # The call under test
        ctx_builder.build(transaction, root)

        finder.find.assert_called_once_with("0", transaction)
Beispiel #11
0
    def setUp(self):
        """Build and flatten the neighbour roles of one mocked root thing.

        Sets ``self._tx``, ``self._neighbourhood_depths``,
        ``self._neighbour_roles`` and ``self._flattened``.
        """
        self._tx = self.session.transaction(grakn.TxType.WRITE)

        # Pass-through samplers, one per hop (two hops).
        identity_samplers = [(lambda x: x) for _ in (2, 3)]
        ctx_builder = builder.ContextBuilder(identity_samplers,
                                             neighbour_finder=mocks.MockNeighbourFinder())

        roots = [neighbour.Thing("0", "person", "entity")]
        self._neighbourhood_depths = [ctx_builder.build(self._tx, root) for root in roots]
        self._neighbour_roles = builder.convert_thing_contexts_to_neighbours(self._neighbourhood_depths)
        self._flattened = flatten_tree(self._neighbour_roles)
Beispiel #12
0
    def test_input_output_integration(self):
        """Real ordered samplers (sizes 2 and 3) over the mock neighbour
        finder reproduce the mock traversal output."""
        samplers = [
            samp.Sampler(2, ordered.ordered_sample, limit=2),
            samp.Sampler(3, ordered.ordered_sample, limit=1),
        ]
        ctx_builder = builder.ContextBuilder(samplers,
                                             neighbour_finder=mocks.MockNeighbourFinder())
        root = neighbour.Thing("0", "person", "entity")

        actual = ctx_builder.build(self._tx, root)

        self.assertEqual(actual, mocks.mock_traversal_output())
Beispiel #13
0
    def test_build_context_for_3_hop(self):
        """Three ordered samplers (sizes 2, 2, 3) over the mock graph: the root
        sits at depth 3 and sampled neighbours fill depths 2, 1 and 0."""

        # Factories return a fresh, equal Thing on every call.
        def person():
            return neighbour.Thing("0", "person", "entity")

        def sundar_name():
            return neighbour.Thing("1", "name", "attribute", data_type='string', value='Sundar Pichai')

        def employment():
            return neighbour.Thing("2", "employment", "relation")

        def company():
            return neighbour.Thing("3", "company", "entity")

        def google_name():
            return neighbour.Thing("4", "name", "attribute", data_type='string', value='Google')

        samplers = [
            samp.Sampler(2, ordered.ordered_sample, limit=2),
            samp.Sampler(2, ordered.ordered_sample, limit=2),
            samp.Sampler(3, ordered.ordered_sample, limit=3),
        ]
        ctx_builder = builder.ContextBuilder(samplers, neighbour_finder=mocks.MockNeighbourFinder())

        actual = ctx_builder.build(mock.Mock(grakn.client.Transaction), person())

        Node = builder.Node
        depth_3 = [Node((), person())]
        depth_2 = [
            Node((0,), sundar_name(), "has", neighbour.NEIGHBOUR_PLAYS),
            Node((1,), employment(), "employee", neighbour.TARGET_PLAYS),
        ]
        depth_1 = [
            Node((0, 0), person(), "has", neighbour.TARGET_PLAYS),
            # Note that (0, 1) is reversed compared to the natural expectation
            Node((0, 1), company(), "employer", neighbour.NEIGHBOUR_PLAYS),
            Node((1, 1), person(), "employee", neighbour.NEIGHBOUR_PLAYS),
        ]
        depth_0 = [
            Node((0, 0, 0), sundar_name(), "has", neighbour.NEIGHBOUR_PLAYS),
            Node((1, 0, 0), employment(), "employee", neighbour.TARGET_PLAYS),
            Node((0, 0, 1), google_name(), "has", neighbour.NEIGHBOUR_PLAYS),
            Node((1, 0, 1), google_name(), "has", neighbour.NEIGHBOUR_PLAYS),
            Node((2, 0, 1), google_name(), "has", neighbour.NEIGHBOUR_PLAYS),
            Node((0, 1, 1), sundar_name(), "has", neighbour.NEIGHBOUR_PLAYS),
            Node((1, 1, 1), employment(), "employee", neighbour.TARGET_PLAYS),
        ]

        self.assertEqual({3: depth_3, 2: depth_2, 1: depth_1, 0: depth_0}, actual)