def _help_test_regularizer(self, regularizer: Regularizer, n_tensors: int = 3):
    """Exercise the update/reset life cycle of a regularizer.

    The regularizer must start without a term, obtain a non-zero term on its
    first ``update``, keep that exact term on a second ``update``, and return to
    its initial state after ``reset``.

    :param regularizer: The regularizer under test; it is moved to ``self.device`` first.
    :param n_tensors: How many random ``10 x 10`` tensors are passed per ``update`` call.
    """

    def draw_batch():
        # A fresh list of random tensors, drawn from the shared test generator.
        return [rand(10, 10, generator=self.generator, device=self.device) for _ in range(n_tensors)]

    def check_pristine(reg):
        # Not yet updated (or freshly reset): flag cleared and term equal to zero.
        self.assertFalse(reg.updated)
        self.assertEqual(0.0, reg.regularization_term.item())

    # Make sure the regularizer lives on the same device as the test tensors.
    reg = regularizer.to(self.device)
    check_pristine(reg)

    # The first update has to populate the term.
    reg.update(*draw_batch())
    self.assertTrue(reg.updated)
    self.assertNotEqual(0.0, reg.regularization_term.item())
    snapshot = reg.regularization_term.clone()

    # A second update must leave the term untouched.
    reg.update(*draw_batch())
    self.assertTrue(reg.updated)
    self.assertEqual(snapshot, reg.regularization_term)

    # Resetting brings the regularizer back to its initial state.
    reg.reset()
    check_pristine(reg)
def setUp(self) -> None:
    """Set up the test case.

    Seeds torch via ``torch.random.manual_seed`` and keeps the returned generator.
    Builds the regularizer instance from ``self.kwargs`` on the resolved device.
    Draws random entity, relation and normal-vector weight matrices with that
    generator.
    """
    self.generator = torch.random.manual_seed(seed=42)
    self.device = resolve_device()
    self.kwargs = {"weight": 0.5, "epsilon": 1e-5}
    self.instance = self.cls(
        **(self.kwargs or {}),
    ).to(self.device)
    self.num_entities = 10
    self.num_relations = 5
    # ``torch.random.manual_seed`` hands back the default CPU generator, and torch
    # refuses to sample directly onto a different device (e.g. CUDA) with it.
    # So sample on CPU with the generator and move the result afterwards; on a
    # CPU device this yields exactly the same values as before.
    self.entities_weight = rand(self.num_entities, 10, generator=self.generator).to(self.device)
    self.relations_weight = rand(self.num_relations, 20, generator=self.generator).to(self.device)
    self.normal_vector_weight = rand(self.num_relations, 20, generator=self.generator).to(self.device)
def test_update(self):
    """Test update function of TransHRegularizer.

    Checks that ``update`` rejects both more and fewer than three tensors with a
    ``KeyError`` mentioning "Expects exactly three tensors". It then checks that
    a valid update produces ``weight * expected_penalty`` as the term.
    """
    # More than three tensors must be rejected. Each case needs its own
    # ``assertRaises`` block: statements after the raising call inside the same
    # ``with`` body would never execute.
    with self.assertRaises(KeyError) as context:
        self.instance.update(
            self.entities_weight,
            self.normal_vector_weight,
            self.relations_weight,
            # Sample on CPU (the generator from ``setUp`` is the default CPU
            # generator), then move to the test device.
            rand(self.num_entities, 10, generator=self.generator).to(self.device),
        )
    # An exception object is not a container; check its string representation.
    self.assertIn('Expects exactly three tensors', str(context.exception))

    # Fewer than three tensors must be rejected as well.
    with self.assertRaises(KeyError) as context:
        self.instance.update(
            self.entities_weight,
            self.normal_vector_weight,
        )
    self.assertIn('Expects exactly three tensors', str(context.exception))

    # Test that regularization term is computed correctly
    self.instance.update(self.entities_weight, self.normal_vector_weight, self.relations_weight)
    expected_term = self._expected_penalty()
    weight = self.kwargs.get('weight')
    self.assertAlmostEqual(self.instance.term.item(), weight * expected_term.item())
def test_update(self) -> None:
    """Check that `update` stores the weighted sum of the per-tensor penalties."""
    # Two random inputs that differ in their trailing dimension.
    first = rand(self.batch_size, 10, generator=self.generator, device=self.device)
    second = rand(self.batch_size, 20, generator=self.generator, device=self.device)
    inputs = (first, second)

    self.instance.update(*inputs)

    # The accumulated term has to be a one-element vector.
    self.assertEqual((1,), self.instance.term.shape)

    # Reference: sum of the individual expected penalties, scaled by the weight.
    penalties = [self._expected_penalty(t) for t in inputs]
    reference = torch.stack(penalties).sum().view(1) * self.instance.weight
    assert reference.shape == (1,)
    self.assertAlmostEqual(self.instance.term.item(), reference.item())
def test_forward(self) -> None:
    """Test the regularizer's `forward` method.

    Checks that ``forward`` returns a single-element penalty. When the test class
    overrides ``_expected_penalty`` (i.e. it does not return ``None``), it also
    checks that the value matches the reference up to floating-point tolerance.
    """
    # Generate random tensor
    x = rand(self.batch_size, 10, generator=self.generator, device=self.device)

    # calculate penalty
    penalty = self.instance.forward(x=x)

    # check shape (unittest assertion, so it is not stripped under ``python -O``)
    self.assertEqual(1, penalty.numel())

    # check value
    expected_penalty = self._expected_penalty(x=x)
    if expected_penalty is None:
        logging.warning('%s did not override `_expected_penalty`.', self.__class__.__name__)
        return
    # The reference is computed by a different code path than the implementation,
    # so exact float equality is brittle; compare with a tolerance instead.
    # ``.to(penalty)`` aligns dtype and device; shapes broadcast (e.g. a 0-d
    # reference against a one-element penalty).
    self.assertTrue(torch.allclose(expected_penalty.to(penalty), penalty))