def setUpClass(cls):
    """Build one neighbourhood context shared by every test in the class.

    Opens a session on the ``test_schema`` keyspace of the Grakn server at
    ``localhost:48555``, starts a write transaction, and builds the context
    of the first ``company`` named 'Google' returned by the query.  Two
    ordered samplers are used (sizes 4 and 3, each limited to twice its
    size).  The session, transaction and resulting context are stored on
    the class as ``cls.session``, ``cls._tx`` and ``cls.context``.
    """
    # NOTE(review): neither the session nor the transaction is closed in
    # this method; presumably a matching tearDownClass does so -- confirm.
    grakn_client = grakn.client.GraknClient(uri="localhost:48555")
    cls.session = grakn_client.session(keyspace="test_schema")
    cls._tx = cls.session.transaction().write()

    draw = ordered.ordered_sample
    samplers = [samp.Sampler(size, draw, limit=size * 2) for size in (4, 3)]

    # Only the first answer of the query is consumed.
    first_answer = next(
        cls._tx.query("match $x isa company, has name 'Google'; get;"))
    start = neighbour.build_thing(first_answer.get('x'))

    cls.context = builder.ContextBuilder(samplers).build(cls._tx, start)
def test_build_context_for_1_hop(self):
    """Check the context built with a single sampler of size 2.

    The builder is driven by ``mocks.MockNeighbourFinder`` and a mocked
    transaction, so no Grakn server is needed.  The expected context maps
    key 1 to the starting person and key 0 to its two sampled neighbours.
    """
    def person():
        # A fresh Thing per use, so no object is shared between fixtures.
        return neighbour.Thing("0", "person", "entity")

    start = person()
    one_hop_samplers = [samp.Sampler(2, ordered.ordered_sample, limit=2)]
    finder = mocks.MockNeighbourFinder()
    context_builder = builder.ContextBuilder(
        one_hop_samplers, neighbour_finder=finder)

    fake_tx = mock.Mock(grakn.client.Transaction)
    actual = context_builder.build(fake_tx, start)

    name_node = builder.Node(
        (0, ),
        neighbour.Thing("1", "name", "attribute",
                        data_type='string', value='Sundar Pichai'),
        "has",
        neighbour.NEIGHBOUR_PLAYS)
    employment_node = builder.Node(
        (1, ),
        neighbour.Thing("2", "employment", "relation"),
        "employee",
        neighbour.TARGET_PLAYS)

    expected = {
        0: [name_node, employment_node],
        1: [builder.Node((), person())],
    }
    self.assertEqual(expected, actual)
def setUp(self):
    """Build, convert and flatten the contexts of every 'Google' company.

    Connects to the Grakn server at ``localhost:48555``, opens a write
    transaction on the ``test_schema`` keyspace (kept as ``self._tx``), and
    builds one context per query answer using two ordered samplers
    (sizes 4 and 3, each limited to twice its size).  Results are stored as
    ``self._neighbourhood_depths``, ``self._neighbour_roles`` (after
    ``builder.convert_thing_contexts_to_neighbours``) and
    ``self._flattened`` (after ``flatten_tree``).
    """
    # NOTE(review): the session and transaction are not closed in this
    # method; presumably tearDown handles it -- confirm.
    grakn_client = grakn.Grakn(uri="localhost:48555")
    grakn_session = grakn_client.session(keyspace="test_schema")
    self._tx = grakn_session.transaction(grakn.TxType.WRITE)

    pick = ordered.ordered_sample
    samplers = [samp.Sampler(n, pick, limit=n * 2) for n in (4, 3)]

    # Materialise all answers first, then extract the concepts, then
    # convert them -- the same three passes as before.
    answers = list(
        self._tx.query("match $x isa company, has name 'Google'; get;"))
    concepts = [answer.get('x') for answer in answers]
    things = list(map(neighbour.build_thing, concepts))

    context_builder = builder.ContextBuilder(samplers)
    contexts = []
    for thing in things:
        contexts.append(context_builder.build(self._tx, thing))

    self._neighbourhood_depths = contexts
    self._neighbour_roles = builder.convert_thing_contexts_to_neighbours(
        self._neighbourhood_depths)
    self._flattened = flatten_tree(self._neighbour_roles)
def test_input_output_integration(self):
    """Build a context with real samplers over the mock neighbour finder.

    Uses two ordered samplers (size 2 / limit 2 and size 3 / limit 1) with
    ``mocks.MockNeighbourFinder`` and ``self._tx``, and compares the result
    against ``mocks.mock_traversal_output()``.
    """
    draw = ordered.ordered_sample
    first_hop = samp.Sampler(2, draw, limit=2)
    second_hop = samp.Sampler(3, draw, limit=1)

    start = neighbour.Thing("0", "person", "entity")
    context_builder = builder.ContextBuilder(
        [first_hop, second_hop],
        neighbour_finder=mocks.MockNeighbourFinder())

    actual = context_builder.build(self._tx, start)
    self.assertEqual(actual, mocks.mock_traversal_output())
def _neighbourhood_traverser_factory(neighbour_sample_sizes):
    """Return a ``builder.ContextBuilder`` with one ordered sampler per size.

    Each sampler draws ``sample_size`` items with
    ``ordered.ordered_sample`` and has its limit set to twice that size.
    """
    draw = ordered.ordered_sample
    return builder.ContextBuilder([
        samp.Sampler(size, draw, limit=size * 2)
        for size in neighbour_sample_sizes
    ])
def __init__(
        self,
        neighbour_sample_sizes,
        features_size,
        example_concepts_features_size,
        aggregated_size,
        embedding_size,
        schema_encoding_transaction,
        batch_size,
        embedding_normalisation=tf.nn.l2_normalize,
        neighbour_sampling_method=ordered.ordered_sample,
        neighbour_sampling_limit_factor=1,
        formatters=None,
        features_to_exclude=()):
    """Wire up the encoder, samplers, context builder, embedder and dataset.

    Args:
        neighbour_sample_sizes: Sequence of sample sizes; one
            ``sample.Sampler`` is created per entry, and the sequence is
            also passed to the array converter, the embedder and
            ``preprocess.build_dataset``.
        features_size: Feature size used for every entry of
            ``self.feature_sizes`` except the last.
        example_concepts_features_size: Feature size stored as the last
            entry of ``self.feature_sizes``.
        aggregated_size: Forwarded to ``embed.Embedder``.
        embedding_size: Forwarded to ``embed.Embedder``.
        schema_encoding_transaction: Transaction handed to
            ``encode.Encoder``.
        batch_size: Stored as ``self.batch_size``.
        embedding_normalisation: Passed to ``embed.Embedder`` as
            ``normalisation``.
        neighbour_sampling_method: Sampling callable given to every
            ``sample.Sampler``.
        neighbour_sampling_limit_factor: Each sampler's limit is
            ``int(sample_size * neighbour_sampling_limit_factor)``.
        formatters: Mapping stored as ``self._formatters``.  When ``None``
            (the default) a fresh
            ``{'neighbour_value_date': preprocess.datetime_to_unixtime}``
            is created for this instance.
        features_to_exclude: Iterable of feature names; each is passed to
            ``preprocess.build_dataset`` as a keyword argument set to
            ``None``.

    Raises:
        IndexError: If ``neighbour_sample_sizes`` is empty, because the
            last entry of ``self.feature_sizes`` cannot be assigned.
    """
    # The default used to be a dict literal in the signature, i.e. a single
    # object shared by every instance; build a new one per call instead so
    # a mutation through one instance cannot leak into the others.
    if formatters is None:
        formatters = {
            'neighbour_value_date': preprocess.datetime_to_unixtime}

    self._embedding_normalisation = embedding_normalisation
    self.embedding_size = embedding_size
    self.aggregated_size = aggregated_size
    self.neighbour_sample_sizes = neighbour_sample_sizes

    self.feature_sizes = [features_size] * len(self.neighbour_sample_sizes)
    self.feature_sizes[-1] = example_concepts_features_size
    print(f'feature sizes: {self.feature_sizes}')

    self._schema_encoding_transaction = schema_encoding_transaction
    self._encode = encode.Encoder(self._schema_encoding_transaction)
    self.batch_size = batch_size
    self._formatters = formatters
    self._features_to_exclude = features_to_exclude

    traversal_samplers = [
        sample.Sampler(
            sample_size,
            neighbour_sampling_method,
            limit=int(sample_size * neighbour_sampling_limit_factor))
        for sample_size in neighbour_sample_sizes
    ]

    self._array_builder = array.ArrayConverter(neighbour_sample_sizes)
    self._context_builder = builder.ContextBuilder(traversal_samplers)
    self._embed = embed.Embedder(
        self.feature_sizes,
        self.aggregated_size,
        self.embedding_size,
        self.neighbour_sample_sizes,
        normalisation=self._embedding_normalisation)

    # Separate local name so the ``features_to_exclude`` parameter is not
    # rebound to a dict half-way through the method.
    excluded_kwargs = {
        feat_name: None for feat_name in self._features_to_exclude
    }
    self.neighbourhood_dataset, self.array_placeholders = (
        preprocess.build_dataset(
            self.neighbour_sample_sizes, **excluded_kwargs))
def test_build_context_for_3_hop(self):
    """Check the context built with three samplers (sizes 2, 2 and 3).

    The builder is driven by ``mocks.MockNeighbourFinder`` and a mocked
    transaction.  The expected context maps key 3 to the starting person
    and keys 2, 1 and 0 to the nodes listed below.
    """
    # Local factories: every call returns a fresh Thing, so no object is
    # shared between expected nodes.
    def person():
        return neighbour.Thing("0", "person", "entity")

    def sundar_name():
        return neighbour.Thing("1", "name", "attribute",
                               data_type='string', value='Sundar Pichai')

    def employment():
        return neighbour.Thing("2", "employment", "relation")

    def company():
        return neighbour.Thing("3", "company", "entity")

    def google_name():
        return neighbour.Thing("4", "name", "attribute",
                               data_type='string', value='Google')

    start = person()
    draw = ordered.ordered_sample
    three_hop_samplers = [
        samp.Sampler(2, draw, limit=2),
        samp.Sampler(2, draw, limit=2),
        samp.Sampler(3, draw, limit=3),
    ]
    context_builder = builder.ContextBuilder(
        three_hop_samplers, neighbour_finder=mocks.MockNeighbourFinder())
    actual = context_builder.build(
        mock.Mock(grakn.client.Transaction), start)

    make_node = builder.Node
    neighbour_plays = neighbour.NEIGHBOUR_PLAYS
    target_plays = neighbour.TARGET_PLAYS

    expected = {
        3: [make_node((), person())],
        2: [
            make_node((0,), sundar_name(), "has", neighbour_plays),
            make_node((1,), employment(), "employee", target_plays),
        ],
        1: [
            make_node((0, 0), person(), "has", target_plays),
            # Note that (0, 1) is reversed compared to the natural
            # expectation
            make_node((0, 1), company(), "employer", neighbour_plays),
            make_node((1, 1), person(), "employee", neighbour_plays),
        ],
        0: [
            make_node((0, 0, 0), sundar_name(), "has", neighbour_plays),
            make_node((1, 0, 0), employment(), "employee", target_plays),
            make_node((0, 0, 1), google_name(), "has", neighbour_plays),
            make_node((1, 0, 1), google_name(), "has", neighbour_plays),
            make_node((2, 0, 1), google_name(), "has", neighbour_plays),
            make_node((0, 1, 1), sundar_name(), "has", neighbour_plays),
            make_node((1, 1, 1), employment(), "employee", target_plays),
        ],
    }
    self.assertEqual(expected, actual)