コード例 #1
0
ファイル: test_regularizers.py プロジェクト: sunny1401/pykeen
    def _help_test_regularizer(self,
                               regularizer: Regularizer,
                               n_tensors: int = 3):
        """Exercise the update/reset life cycle of a regularizer.

        Checks that a fresh regularizer is not marked as updated and has a zero
        term, that the first ``update`` call makes the term non-zero, that a
        second ``update`` call leaves the term unchanged, and that ``reset``
        restores the initial (not updated, zero term) state.

        :param regularizer: The regularizer under test; it is moved to ``self.device``.
        :param n_tensors: Number of random ``(10, 10)`` tensors passed to each ``update`` call.
        """
        # ensure regularizer is on correct device
        regularizer = regularizer.to(self.device)

        # Fresh regularizer: not updated, zero term.
        self.assertFalse(regularizer.updated)
        self.assertEqual(0.0, regularizer.regularization_term.item())

        # Sample on the CPU with the seeded generator, then move to the target
        # device: ``rand(..., generator=<cpu generator>, device=<cuda>)`` raises.
        # NOTE(review): assumes self.generator is a CPU generator (e.g. the one
        # returned by torch.manual_seed) -- confirm against this class's setUp.
        first_tensors = [
            rand(10, 10, generator=self.generator).to(self.device)
            for _ in range(n_tensors)
        ]

        # After first update, should change the term
        regularizer.update(*first_tensors)
        self.assertTrue(regularizer.updated)
        self.assertNotEqual(0.0, regularizer.regularization_term.item())
        term = regularizer.regularization_term.clone()

        # After second update, no change should happen
        second_tensors = [
            rand(10, 10, generator=self.generator).to(self.device)
            for _ in range(n_tensors)
        ]
        regularizer.update(*second_tensors)
        self.assertTrue(regularizer.updated)
        # Compare as Python scalars (like the other checks here) rather than as
        # tensors, so a failure reports the two numbers directly.
        self.assertEqual(term.item(), regularizer.regularization_term.item())

        # Reset restores the initial state.
        regularizer.reset()
        self.assertFalse(regularizer.updated)
        self.assertEqual(0.0, regularizer.regularization_term.item())
コード例 #2
0
 def setUp(self) -> None:
     """Set up the test case.

     Seeds torch's global RNG, resolves the device, builds ``self.cls`` with
     ``weight=0.5`` and ``epsilon=1e-5`` on that device, and draws random
     entity (10 x 10), relation (5 x 20) and normal-vector (5 x 20) weights.
     """
     # torch.random.manual_seed seeds the global RNG and returns the default
     # CPU generator, which is reused for every random draw below.
     self.generator = torch.random.manual_seed(seed=42)
     self.device = resolve_device()
     self.kwargs = {"weight": 0.5, "epsilon": 1e-5}
     self.instance = self.cls(**self.kwargs).to(self.device)
     self.num_entities = 10
     self.num_relations = 5
     # Sample on the CPU, where the generator lives, and only then move the
     # tensors: ``rand(..., generator=<cpu generator>, device=<cuda>)`` raises.
     self.entities_weight = rand(self.num_entities, 10, generator=self.generator).to(self.device)
     self.relations_weight = rand(self.num_relations, 20, generator=self.generator).to(self.device)
     self.normal_vector_weight = rand(self.num_relations, 20, generator=self.generator).to(self.device)
コード例 #3
0
ファイル: test_regularizers.py プロジェクト: sunny1401/pykeen
    def test_update(self):
        """Test update function of TransHRegularizer.

        ``update`` must raise ``KeyError`` with the message
        ``'Expects exactly three tensors'`` when given four or two tensors. When
        given exactly three, it must set ``term`` to ``weight`` times the
        expected penalty.
        """
        # Too many tensors: pass an extra, entity-shaped fourth tensor.
        # Sampled on the CPU with the seeded generator, then moved to the device.
        # NOTE(review): assumes self.generator is a CPU generator -- confirm in setUp.
        extra_tensor = rand(self.num_entities, 10, generator=self.generator).to(self.device)
        with self.assertRaises(KeyError) as context:
            self.instance.update(
                self.entities_weight,
                self.normal_vector_weight,
                self.relations_weight,
                extra_tensor,
            )
        # Check the message *after* the ``with`` block. Statements placed after
        # the raising call inside it never run.
        self.assertIn('Expects exactly three tensors', str(context.exception))

        # Too few tensors.
        with self.assertRaises(KeyError) as context:
            self.instance.update(
                self.entities_weight,
                self.normal_vector_weight,
            )
        self.assertIn('Expects exactly three tensors', str(context.exception))

        # Test that regularization term is computed correctly
        self.instance.update(self.entities_weight, self.normal_vector_weight,
                             self.relations_weight)
        expected_term = self._expected_penalty()
        weight = self.kwargs.get('weight')
        self.assertAlmostEqual(self.instance.term.item(),
                               weight * expected_term.item())
コード例 #4
0
ファイル: cases.py プロジェクト: pzq7025/pykeen
    def test_update(self) -> None:
        """Test method `update`.

        Passes two random tensors (widths 10 and 20) to ``update`` and checks
        that ``term`` has shape ``(1,)``. If ``_expected_penalty`` yields values
        for both tensors, it also checks that ``term`` equals the sum of those
        penalties times ``self.instance.weight``.
        """
        import logging

        # Generate random tensors on the CPU with the seeded generator, then move
        # them to the device. ``rand(..., device=<non-CPU>)`` rejects a CPU generator.
        # NOTE(review): assumes self.generator is a CPU generator -- confirm in setUp.
        a = rand(self.batch_size, 10, generator=self.generator).to(self.device)
        b = rand(self.batch_size, 20, generator=self.generator).to(self.device)

        # Call update
        self.instance.update(a, b)

        # check shape
        self.assertEqual((1,), self.instance.term.shape)

        # compute expected term
        exp_penalties = [self._expected_penalty(x) for x in (a, b)]
        if any(penalty is None for penalty in exp_penalties):
            # A None penalty means `_expected_penalty` was not overridden. Skip the
            # value check with a warning instead of failing inside torch.stack.
            logging.warning('%s did not override `_expected_penalty`.', self.__class__.__name__)
            return
        expected_term = torch.sum(torch.stack(exp_penalties)).view(1) * self.instance.weight
        self.assertEqual((1,), expected_term.shape)

        self.assertAlmostEqual(self.instance.term.item(), expected_term.item())
コード例 #5
0
ファイル: cases.py プロジェクト: pzq7025/pykeen
    def test_forward(self) -> None:
        """Test the regularizer's `forward` method.

        Checks that ``forward`` returns a single-element penalty. If
        ``_expected_penalty`` returns a value (not ``None``), it also checks
        that the penalty matches that value exactly. Otherwise it logs a warning.
        """
        # Generate random tensor on the CPU with the seeded generator, then move
        # it to the device. ``rand(..., device=<non-CPU>)`` rejects a CPU generator.
        # NOTE(review): assumes self.generator is a CPU generator -- confirm in setUp.
        x = rand(self.batch_size, 10, generator=self.generator).to(self.device)

        # calculate penalty
        penalty = self.instance.forward(x=x)

        # check shape (unittest assertions are not stripped under ``python -O``)
        self.assertEqual(1, penalty.numel())

        # check value
        expected_penalty = self._expected_penalty(x=x)
        if expected_penalty is None:
            logging.warning('%s did not override `_expected_penalty`.', self.__class__.__name__)
        else:
            self.assertTrue(bool((expected_penalty == penalty).all()))