Example #1
0
 def setUp(self) -> None:
     """Set up the test case.

     Seeds the random generator, resolves the device, instantiates the regularizer under test and
     draws random entity, relation and normal-vector weights for it.
     """
     self.generator = torch.random.manual_seed(seed=42)
     self.device = resolve_device()
     self.regularizer_kwargs = {'weight': .5, 'epsilon': 1e-5}
     self.regularizer = TransHRegularizer(
         device=self.device,
         **self.regularizer_kwargs,
     )
     self.num_entities = 10
     self.num_relations = 5
     # ``manual_seed`` hands back the default CPU generator, and ``torch.rand`` refuses a generator
     # whose device differs from the requested one. Sample on CPU (reproducible on every machine)
     # and move the result to the resolved device afterwards.
     self.entities_weight = torch.rand(self.num_entities, 10, generator=self.generator).to(self.device)
     self.relations_weight = torch.rand(self.num_relations, 20, generator=self.generator).to(self.device)
     self.normal_vector_weight = torch.rand(self.num_relations, 20, generator=self.generator).to(self.device)
Example #2
0
 def test_transh_regularizer(self):
     """Test the TransH regularizer only updates once.

     The model's default regularizer keyword arguments must not contain ``apply_only_once``; a
     regularizer built from exactly those defaults is then handed to ``_help_test_regularizer``.
     """
     default_kwargs = TransH.regularizer_default_kwargs
     self.assertNotIn("apply_only_once", default_kwargs)
     self._help_test_regularizer(TransHRegularizer(**default_kwargs))
Example #3
0
class TransHRegularizerTest(unittest.TestCase):
    """Test the TransH regularizer."""

    generator: torch.Generator
    device: torch.device
    regularizer_kwargs: Dict
    num_entities: int
    num_relations: int
    entities_weight: torch.Tensor
    relations_weight: torch.Tensor
    normal_vector_weight: torch.Tensor

    def setUp(self) -> None:
        """Set up the test case."""
        self.generator = torch.random.manual_seed(seed=42)
        self.device = resolve_device()
        self.regularizer_kwargs = {'weight': .5, 'epsilon': 1e-5}
        self.regularizer = TransHRegularizer(
            device=self.device,
            **(self.regularizer_kwargs or {}),
        )
        self.num_entities = 10
        self.num_relations = 5
        self.entities_weight = torch.rand(self.num_entities, 10, device=self.device, generator=self.generator)
        self.relations_weight = torch.rand(self.num_relations, 20, device=self.device, generator=self.generator)
        self.normal_vector_weight = torch.rand(self.num_relations, 20, device=self.device, generator=self.generator)

    def test_update(self):
        """Test update function of TransHRegularizer."""
        # Tests that exception will be thrown when more than or less than three tensors are passed
        with self.assertRaises(KeyError) as context:
            self.regularizer.update(
                self.entities_weight,
                self.normal_vector_weight,
                self.relations_weight,
                torch.rand(self.num_entities, 10, device=self.device, generator=self.generator),
            )
            self.assertTrue('Expects exactly three tensors' in context.exception)

            self.regularizer.update(
                self.entities_weight,
                self.normal_vector_weight,
            )
            self.assertTrue('Expects exactly three tensors' in context.exception)

        # Test that regularization term is computed correctly
        self.regularizer.update(self.entities_weight, self.normal_vector_weight, self.relations_weight)
        expected_term = self._expected_penalty()
        weight = self.regularizer_kwargs.get('weight')
        self.assertAlmostEqual(self.regularizer.term.item(), weight * expected_term.item())

    def _expected_penalty(self) -> torch.FloatTensor:  # noqa: D102
        # Entity soft constraint
        regularization_term = torch.sum(functional.relu(torch.norm(self.entities_weight, dim=-1) ** 2 - 1.0))
        epsilon = self.regularizer_kwargs.get('epsilon')  #

        # Orthogonality soft constraint
        d_r_n = functional.normalize(self.relations_weight, dim=-1)
        regularization_term += torch.sum(
            functional.relu(torch.sum((self.normal_vector_weight * d_r_n) ** 2, dim=-1) - epsilon),
        )

        return regularization_term