def test_metadata(self):
    """Test metadata passing for triples factories.

    Checks the ``path`` metadata entry and the ``repr`` of:

    - the Nations training factory itself,
    - a factory restricted to a subset of entities,
    - a factory restricted to a subset of relations,
    - a clone with exchanged triples created with ``keep_metadata=False`` (which must not carry a path),
    - both parts of a split.
    """
    t = Nations().training
    self.assertEqual(NATIONS_TRAIN_PATH, t.metadata['path'])
    self.assertEqual(
        (
            f'TriplesFactory(num_entities=14, num_relations=55, num_triples=1592,'
            f' inverse_triples=False, path="{NATIONS_TRAIN_PATH}")'
        ),
        repr(t),
    )

    # Entity restriction: metadata is kept and the restriction (as IDs) appears in the repr.
    entities = ['poland', 'ussr']
    x = t.new_with_restriction(entities=entities)
    entities_ids = t.entities_to_ids(entities=entities)
    self.assertEqual(NATIONS_TRAIN_PATH, x.metadata['path'])
    self.assertEqual(
        (
            f'TriplesFactory(num_entities=14, num_relations=55, num_triples=37,'
            f' inverse_triples=False, entity_restriction={repr(entities_ids)}, path="{NATIONS_TRAIN_PATH}")'
        ),
        repr(x),
    )

    # Relation restriction: check the metadata of the relation-restricted factory ``v`` itself.
    relations = ['negativebehavior']
    v = t.new_with_restriction(relations=relations)
    relations_ids = t.relations_to_ids(relations=relations)
    self.assertEqual(NATIONS_TRAIN_PATH, v.metadata['path'])
    self.assertEqual(
        (
            f'TriplesFactory(num_entities=14, num_relations=55, num_triples=29,'
            f' inverse_triples=False, path="{NATIONS_TRAIN_PATH}", relation_restriction={repr(relations_ids)})'
        ),
        repr(v),
    )

    # With keep_metadata=False the clone must not inherit the path.
    w = t.clone_and_exchange_triples(t.triples[0:5], keep_metadata=False)
    self.assertIsInstance(w, TriplesFactory)
    self.assertNotIn('path', w.metadata)
    self.assertEqual(
        'TriplesFactory(num_entities=14, num_relations=55, num_triples=5, inverse_triples=False)',
        repr(w),
    )

    # Splitting keeps the metadata on every resulting factory.
    y, z = t.split()
    self.assertEqual(NATIONS_TRAIN_PATH, y.metadata['path'])
    self.assertEqual(NATIONS_TRAIN_PATH, z.metadata['path'])
def test_new_with_restriction(self):
    """Test new_with_restriction().

    For the Nations training factory (with and without inverse triples) and every combination
    of an optional entity restriction and an optional relation restriction, check that:

    - the original factory object is returned if and only if no restriction is given,
    - ``create_inverse_triples`` is carried over,
    - the entity and relation label-to-ID mappings are unchanged, and
    - the remaining triples only use restricted entities / relations; when inverse triples are
      created, the inverses of the restricted relations (via ``relation_to_inverse``) are allowed too.
    """
    example_relation_restriction = {
        'economicaid',
        'dependent',
    }
    example_entity_restriction = {
        'brazil',
        'burma',
        'china',
    }
    for inverse_triples in (True, False):
        original_triples_factory = Nations(
            create_inverse_triples=inverse_triples,
        ).training
        for entity_restriction in (None, example_entity_restriction):
            for relation_restriction in (None, example_relation_restriction):
                # subTest makes a failure report which parameter combination broke.
                with self.subTest(
                    inverse_triples=inverse_triples,
                    entity_restriction=entity_restriction,
                    relation_restriction=relation_restriction,
                ):
                    restricted_triples_factory = original_triples_factory.new_with_restriction(
                        entities=entity_restriction,
                        relations=relation_restriction,
                    )

                    # The factory is returned as-is (same object) if and only if there is nothing to restrict.
                    no_restriction_to_apply = entity_restriction is None and relation_restriction is None
                    self.assertEqual(
                        no_restriction_to_apply,
                        restricted_triples_factory is original_triples_factory,
                    )

                    # The inverse-triples setting is carried over.
                    self.assertEqual(
                        original_triples_factory.create_inverse_triples,
                        restricted_triples_factory.create_inverse_triples,
                    )

                    # Restriction filters triples only; the label-to-ID mappings must stay untouched.
                    self.assertEqual(
                        original_triples_factory.entity_to_id,
                        restricted_triples_factory.entity_to_id,
                    )
                    self.assertEqual(
                        original_triples_factory.relation_to_id,
                        restricted_triples_factory.relation_to_id,
                    )

                    if entity_restriction is not None:
                        # Columns 0 and 2 hold the entity labels; none may lie outside the restriction.
                        present_entities = set(restricted_triples_factory.triples[:, 0]).union(
                            restricted_triples_factory.triples[:, 2],
                        )
                        self.assertEqual(set(), present_entities.difference(entity_restriction))

                    if relation_restriction is not None:
                        # Column 1 holds the relation labels.
                        present_relations = set(restricted_triples_factory.triples[:, 1])
                        expected_relations = set(relation_restriction)
                        if original_triples_factory.create_inverse_triples:
                            # Inverses of the restricted relations are also allowed.
                            expected_relations.update(
                                map(original_triples_factory.relation_to_inverse.get, relation_restriction),
                            )
                        self.assertEqual(set(), present_relations.difference(expected_relations))