def setUp(self) -> None:
    """Prepare a seeded TransH regularizer together with random weight tensors."""
    # manual_seed seeds torch's RNG and hands back the default generator, reused below.
    gen = self.generator = torch.random.manual_seed(seed=42)
    dev = self.device = resolve_device()
    kwargs = self.regularizer_kwargs = dict(weight=.5, epsilon=1e-5)
    self.regularizer = TransHRegularizer(device=dev, **kwargs)
    self.num_entities, self.num_relations = 10, 5

    def _draw(rows: int, cols: int) -> torch.Tensor:
        # Every tensor comes from the same seeded generator, so the draw order below matters.
        return torch.rand(rows, cols, device=dev, generator=gen)

    self.entities_weight = _draw(self.num_entities, 10)
    self.relations_weight = _draw(self.num_relations, 20)
    self.normal_vector_weight = _draw(self.num_relations, 20)
def test_transh_regularizer(self):
    """Check a TransH regularizer built from the TransH model's default kwargs.

    The actual update behaviour is verified by ``self._help_test_regularizer``.
    """
    defaults = TransH.regularizer_default_kwargs
    # The model defaults must not set apply_only_once themselves.
    # NOTE(review): presumably so the helper controls that flag — confirm against the helper.
    self.assertNotIn("apply_only_once", defaults)
    self._help_test_regularizer(TransHRegularizer(**defaults))
class TransHRegularizerTest(unittest.TestCase):
    """Test the TransH regularizer against an independently recomputed penalty."""

    # Generator returned by torch.random.manual_seed(seed=42) in setUp.
    generator: torch.Generator
    device: torch.device
    # Keyword arguments ('weight', 'epsilon') forwarded to TransHRegularizer.
    regularizer_kwargs: Dict
    regularizer: 'TransHRegularizer'
    num_entities: int
    num_relations: int
    # Random weights: entities are (num_entities, 10); relations and normal vectors are (num_relations, 20).
    entities_weight: torch.Tensor
    relations_weight: torch.Tensor
    normal_vector_weight: torch.Tensor

    def setUp(self) -> None:
        """Set up the regularizer under test and seeded random weights."""
        self.generator = torch.random.manual_seed(seed=42)
        # NOTE(review): manual_seed returns the default (CPU) generator; if resolve_device()
        # yields a CUDA device, torch.rand(..., generator=...) below may reject it — confirm.
        self.device = resolve_device()
        self.regularizer_kwargs = {'weight': .5, 'epsilon': 1e-5}
        self.regularizer = TransHRegularizer(
            device=self.device,
            **self.regularizer_kwargs,
        )
        self.num_entities = 10
        self.num_relations = 5
        self.entities_weight = torch.rand(self.num_entities, 10, device=self.device, generator=self.generator)
        self.relations_weight = torch.rand(self.num_relations, 20, device=self.device, generator=self.generator)
        self.normal_vector_weight = torch.rand(self.num_relations, 20, device=self.device, generator=self.generator)

    def test_update(self):
        """Test update function of TransHRegularizer."""
        # Passing more than three tensors must be rejected.
        with self.assertRaises(KeyError) as context:
            self.regularizer.update(
                self.entities_weight,
                self.normal_vector_weight,
                self.relations_weight,
                torch.rand(self.num_entities, 10, device=self.device, generator=self.generator),
            )
        # Check outside the ``with`` block: anything after the raising call inside it never runs.
        # str() of a KeyError wraps its message in quotes, hence a substring check.
        self.assertIn('Expects exactly three tensors', str(context.exception))

        # Passing fewer than three tensors must be rejected as well (needs its own assertRaises).
        with self.assertRaises(KeyError) as context:
            self.regularizer.update(
                self.entities_weight,
                self.normal_vector_weight,
            )
        self.assertIn('Expects exactly three tensors', str(context.exception))

        # Test that regularization term is computed correctly
        self.regularizer.update(self.entities_weight, self.normal_vector_weight, self.relations_weight)
        expected_term = self._expected_penalty()
        weight = self.regularizer_kwargs['weight']
        self.assertAlmostEqual(self.regularizer.term.item(), weight * expected_term.item())

    def _expected_penalty(self) -> torch.FloatTensor:
        """Recompute the unweighted TransH penalty directly from the raw weights.

        The result is the sum of two soft constraints:

        - entity norm: ``sum_e relu(||e||^2 - 1)`` over the rows of ``entities_weight``;
        - orthogonality: ``sum_r relu(sum((w_r * d_r / ||d_r||) ** 2) - epsilon)``. Here
          ``w_r`` is a row of ``normal_vector_weight`` and ``d_r`` is the matching row of
          ``relations_weight``.

        :return: A scalar tensor. ``test_update`` scales it by the regularizer weight.
        """
        # Entity soft constraint: penalise squared norms above 1.
        regularization_term = torch.sum(functional.relu(torch.norm(self.entities_weight, dim=-1) ** 2 - 1.0))
        epsilon = self.regularizer_kwargs['epsilon']
        # Orthogonality soft constraint.
        # NOTE(review): this squares elementwise products before summing, which is not the
        # squared dot product (w_r . d_r)^2 from the TransH paper; it must mirror the
        # regularizer implementation for this test to pass — confirm that is intended.
        d_r_n = functional.normalize(self.relations_weight, dim=-1)
        regularization_term += torch.sum(
            functional.relu(torch.sum((self.normal_vector_weight * d_r_n) ** 2, dim=-1) - epsilon),
        )
        return regularization_term