Пример #1
0
    def test_metadata(self):
        """Test metadata passing for triples factories.

        Starting from the Nations training factory, check that its ``path``
        metadata is kept by ``new_with_restriction`` (entity and relation
        restrictions) and by ``split``. Check that it is dropped by
        ``clone_and_exchange_triples(..., keep_metadata=False)``. Check that each
        factory's ``repr`` reports its sizes, restriction and metadata.
        """
        t = Nations().training
        self.assertEqual(NATIONS_TRAIN_PATH, t.metadata['path'])
        self.assertEqual(
            (
                f'TriplesFactory(num_entities=14, num_relations=55, num_triples=1592,'
                f' inverse_triples=False, path="{NATIONS_TRAIN_PATH}")'
            ),
            repr(t),
        )

        # Entity restriction: metadata is kept and the entity IDs appear in the repr.
        entities = ['poland', 'ussr']
        x = t.new_with_restriction(entities=entities)
        entities_ids = t.entities_to_ids(entities=entities)
        self.assertEqual(NATIONS_TRAIN_PATH, x.metadata['path'])
        self.assertEqual(
            (
                f'TriplesFactory(num_entities=14, num_relations=55, num_triples=37,'
                f' inverse_triples=False, entity_restriction={repr(entities_ids)}, path="{NATIONS_TRAIN_PATH}")'
            ),
            repr(x),
        )

        # Relation restriction: the metadata check must target the
        # relation-restricted factory ``v``, not the entity-restricted ``x``.
        relations = ['negativebehavior']
        v = t.new_with_restriction(relations=relations)
        relations_ids = t.relations_to_ids(relations=relations)
        self.assertEqual(NATIONS_TRAIN_PATH, v.metadata['path'])
        self.assertEqual(
            (
                f'TriplesFactory(num_entities=14, num_relations=55, num_triples=29,'
                f' inverse_triples=False, path="{NATIONS_TRAIN_PATH}", relation_restriction={repr(relations_ids)})'
            ),
            repr(v),
        )

        # Exchanging triples without keeping metadata drops ``path`` entirely.
        w = t.clone_and_exchange_triples(t.triples[0:5], keep_metadata=False)
        self.assertIsInstance(w, TriplesFactory)
        self.assertNotIn('path', w.metadata)
        self.assertEqual(
            'TriplesFactory(num_entities=14, num_relations=55, num_triples=5, inverse_triples=False)',
            repr(w),
        )

        # Both parts of a split inherit the original ``path`` metadata.
        y, z = t.split()
        self.assertEqual(NATIONS_TRAIN_PATH, y.metadata['path'])
        self.assertEqual(NATIONS_TRAIN_PATH, z.metadata['path'])
Пример #2
0
    def test_new_with_restriction(self):
        """Test new_with_restriction().

        Every combination of the following is checked:

        - inverse triples on or off
        - entity restriction: none, or an example set
        - relation restriction: none, or an example set

        For each combination, verify that:

        - the original factory object is returned if and only if no restriction
          is given;
        - ``create_inverse_triples`` is carried over;
        - the entity and relation label-to-ID mappings are unchanged;
        - the remaining triples only use allowed entities and relations. When
          inverse triples are created, the inverses of the allowed relations are
          also accepted.
        """
        example_relation_restriction = {
            'economicaid',
            'dependent',
        }
        example_entity_restriction = {
            'brazil',
            'burma',
            'china',
        }
        for inverse_triples in (True, False):
            original_triples_factory = Nations(
                create_inverse_triples=inverse_triples,
            ).training
            for entity_restriction in (None, example_entity_restriction):
                for relation_restriction in (None, example_relation_restriction):
                    # apply restriction
                    restricted_triples_factory = original_triples_factory.new_with_restriction(
                        entities=entity_restriction,
                        relations=relation_restriction,
                    )
                    # check that the triples factory is returned as is, if and only if no restriction is to apply
                    no_restriction_to_apply = entity_restriction is None and relation_restriction is None
                    same_factory_object = restricted_triples_factory is original_triples_factory
                    assert no_restriction_to_apply == same_factory_object

                    # check that inverse_triples is correctly carried over
                    assert (
                        original_triples_factory.create_inverse_triples
                        == restricted_triples_factory.create_inverse_triples
                    )

                    # verify that the label-to-ID mapping has not been changed
                    assert original_triples_factory.entity_to_id == restricted_triples_factory.entity_to_id
                    assert original_triples_factory.relation_to_id == restricted_triples_factory.relation_to_id

                    # verify that triples have been filtered; columns 0 and 2 are
                    # compared against the entity label set, column 1 against the
                    # relation label set
                    if entity_restriction is not None:
                        present_entities = set(restricted_triples_factory.triples[:, 0]).union(
                            restricted_triples_factory.triples[:, 2],
                        )
                        assert set(entity_restriction).issuperset(present_entities)

                    if relation_restriction is not None:
                        present_relations = set(restricted_triples_factory.triples[:, 1])
                        exp_relations = set(relation_restriction)
                        if original_triples_factory.create_inverse_triples:
                            # with inverse triples, the inverses of the allowed
                            # relations may legitimately be present as well
                            exp_relations = exp_relations.union(
                                map(original_triples_factory.relation_to_inverse.get, exp_relations),
                            )
                        assert exp_relations.issuperset(present_relations)