def test_triples_numeric_literals_factory_split(self):
        """Test splitting a TriplesNumericLiteralsFactory object.

        Both halves returned by ``split()`` must be
        ``TriplesNumericLiteralsFactory`` instances whose ``numeric_literals``
        equal (in shape and values) those of the factory that was split.
        """
        # Slightly larger set of triples so that the split can cover all
        # entities and relations.
        triples_larger = np.array(
            [
                ["peter", "likes", "chocolate_cake"],
                ["chocolate_cake", "isA", "dish"],
                ["susan", "likes", "chocolate_cake"],
                ["susan", "likes", "pizza"],
                ["peter", "likes", "susan"],
                ["peter", "isA", "person"],
                ["susan", "isA", "person"],
            ],
            dtype=str,
        )

        triples_numeric_literal_factory = TriplesNumericLiteralsFactory(
            triples=triples_larger,
            numeric_triples=numeric_triples,
        )

        left, right = triples_numeric_literal_factory.split()

        self.assertIsInstance(left, TriplesNumericLiteralsFactory)
        self.assertIsInstance(right, TriplesNumericLiteralsFactory)

        # assert_array_equal checks shape as well as values (``(a == b).all()``
        # broadcasts), reports a diff on failure, and, unlike a bare ``assert``,
        # is not stripped when Python runs with ``-O``.
        expected_literals = triples_numeric_literal_factory.numeric_literals
        np.testing.assert_array_equal(left.numeric_literals, expected_literals)
        np.testing.assert_array_equal(right.numeric_literals, expected_literals)

    def test_create_lcwa_instances(self):
        """Test creating LCWA instances.

        NOTE(review): this method is defined again, under the same name, later
        in the same class. A class body keeps only the last binding of a name,
        so this definition is shadowed and is never collected or run. It reads
        ``instances.mapped_triples`` / ``instances.labels``, while the later
        definition reads ``instances.pairs`` / ``instances.compressed``, which
        presumably reflects an older LCWA-instances API (TODO confirm).
        Consider deleting this copy, or renaming it if these checks are still
        wanted.
        """
        factory = TriplesNumericLiteralsFactory(triples=triples, numeric_triples=numeric_triples)
        instances = factory.create_lcwa_instances()

        # Literal ids do not depend on the entity; look them up once.
        id_age = instances.literals_to_id['/lit/hasAge']
        id_height = instances.literals_to_id['/lit/hasHeight']
        id_num_children = instances.literals_to_id['/lit/hasChildren']

        # entity label -> expected (age, height, number of children)
        expected_literals = {
            'peter': (30, 185, 2),
            'susan': (28, 170, 0),
            'chocolate_cake': (0, 0, 0),
        }
        literals = instances.numeric_literals
        for entity, (age, height, num_children) in expected_literals.items():
            id_entity = factory.entity_to_id[entity]
            self.assertEqual(literals[id_entity, id_age], age, msg=entity)
            self.assertEqual(literals[id_entity, id_height], height, msg=entity)
            self.assertEqual(literals[id_entity, id_num_children], num_children, msg=entity)

        # Check if multilabels are working correctly. assert_array_equal also
        # checks shapes; ``(a == b).all()`` broadcasts and is vacuously True
        # when the broadcast result is empty.
        np.testing.assert_array_equal(
            instances.mapped_triples.cpu().detach().numpy(), instance_mapped_triples)
        for i, expected_labels in enumerate(instance_labels):
            np.testing.assert_array_equal(instances.labels[i], expected_labels)
    def test_create_lcwa_instances(self):
        """Test creating LCWA instances.

        Checks the numeric-literal matrix entries for several
        (entity, literal) combinations, and that the LCWA pairs and their
        multi-label targets match the module-level ``instance_mapped_triples``
        and ``instance_labels`` fixtures.
        """
        factory = TriplesNumericLiteralsFactory(
            triples=triples, numeric_triples=numeric_triples)
        instances = factory.create_lcwa_instances()

        # Literal ids do not depend on the entity; look them up once.
        id_age = instances.literals_to_id["/lit/hasAge"]
        id_height = instances.literals_to_id["/lit/hasHeight"]
        id_num_children = instances.literals_to_id["/lit/hasChildren"]

        # entity label -> expected (age, height, number of children)
        expected_literals = {
            "peter": (30, 185, 2),
            "susan": (28, 170, 0),
            "chocolate_cake": (0, 0, 0),
        }
        literals = instances.numeric_literals
        for entity, (age, height, num_children) in expected_literals.items():
            id_entity = factory.entity_to_id[entity]
            self.assertEqual(literals[id_entity, id_age], age, msg=entity)
            self.assertEqual(literals[id_entity, id_height], height,
                             msg=entity)
            self.assertEqual(literals[id_entity, id_num_children],
                             num_children, msg=entity)

        # Check if multilabels are working correctly.
        # Use assert_array_equal rather than ``(a == b).all()``: the latter
        # broadcasts, so e.g. an expected ``[1]`` against an empty result
        # gives an empty array whose ``.all()`` is vacuously True, and a pair
        # with missing targets would pass. assert_array_equal checks shapes.
        np.testing.assert_array_equal(instances.pairs, instance_mapped_triples)
        for i, expected_targets in enumerate(instance_labels):
            # ``nonzero()[-1]``: positions of the non-zero entries of row i.
            np.testing.assert_array_equal(
                instances.compressed[i].nonzero()[-1], expected_targets)