def setUp(self) -> None:
    """Prepare a seeded generator and the Nations reference dataset.

    Stores the number of training triples and the total number of triples
    summed over the training, testing, and validation splits.
    """
    self.generator = ensure_torch_random_state(42)
    reference = Nations()
    self.reference = reference
    self.num_training_triples = reference.training.num_triples
    splits = (reference.training, reference.testing, reference.validation)
    self.num_triples = sum(split.num_triples for split in splits)
Beispiel #2
0
 def _pre_instantiation_hook(
         self, kwargs: MutableMapping[str, Any]
 ) -> MutableMapping[str, Any]:
     """Extend the parent hook by injecting a Nations training factory.

     A Nations dataset is built with ``create_inverse_triples=True``; its
     training factory is stored on ``self.factory`` and placed under the
     ``"triples_factory"`` key of the mapping returned by the parent hook.
     That mapping is returned.
     """
     hooked_kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
     inverse_factory = Nations(create_inverse_triples=True).training
     hooked_kwargs["triples_factory"] = inverse_factory
     self.factory = inverse_factory
     return hooked_kwargs
    def setUp(self) -> None:
        """Seed the RNGs, load Nations, and build the model under test.

        The triples factory is the Nations training split, created with or
        without inverse triples according to ``self.create_inverse_triples``.
        The model is instantiated from ``self.model_cls`` on that factory and
        then passed through ``to_device_()``.
        """
        _, generator, _ = set_random_seed(42)
        self.generator = generator

        self.factory = Nations(
            create_inverse_triples=self.create_inverse_triples,
        ).training
        # model_kwargs may be None/empty; fall back to no extra arguments
        extra_model_kwargs = self.model_kwargs or {}
        model = self.model_cls(
            self.factory,
            embedding_dim=self.embedding_dim,
            **extra_model_kwargs,
        )
        self.model = model.to_device_()
Beispiel #4
0
 def test_custom_dataset_instance(self):
     """Test passing a pre-instantiated dataset to HPO.

     No dataset name or split information should be recorded in the
     study's user attributes when a dataset instance is passed in.
     """
     # Nations stands in for a user-supplied ("custom") dataset instance.
     custom_dataset = Nations()
     result = self._help_test_hpo(
         study_name="HPO with custom dataset instance",
         dataset=custom_dataset,
     )
     user_attrs = result.study.user_attrs
     # Since custom data was passed, none of these can be stored.
     for attr_key in ("dataset", "training", "testing", "validation"):
         self.assertNotIn(attr_key, user_attrs)
class TestDeterioration(unittest.TestCase):
    """Tests for deterioration workflow."""

    def setUp(self) -> None:
        """Set up deterioration tests.

        Creates a seeded generator, loads the Nations reference dataset, and
        records its training-split size and total triple count across the
        training, testing, and validation splits.
        """
        self.generator = ensure_torch_random_state(42)
        self.reference = Nations()
        self.num_training_triples = self.reference.training.num_triples
        self.num_triples = (self.reference.training.num_triples +
                            self.reference.testing.num_triples +
                            self.reference.validation.num_triples)

    def test_deteriorate(self):
        """Test deterioration on integer values for ``n``."""
        for n in [1, 2, 5, 10, 50, 100, 500, 1000]:
            with self.subTest(n=n):
                self._help_check_deterioration(n)

    def test_deteriorate_frac(self):
        """Test deterioration on fractional values for ``n``."""
        for n_frac in [
                1 / self.num_training_triples,
                2 / self.num_training_triples,
                5 / self.num_training_triples,
                0.1,
                0.2,
                0.3,
        ]:
            # Use round() rather than int(): a product such as (1 / N) * N
            # can evaluate to 0.999..., which int() would truncate to 0.
            n = round(n_frac * self.num_training_triples)
            with self.subTest(n=n, n_frac=n_frac):
                self._help_check_deterioration(n)

    def _help_check_deterioration(self, n: int) -> None:
        """Deteriorate the reference by ``n`` and verify the result.

        Checks the generic invariants via ``_help_check``, that exactly ``n``
        steps separate the reference and derived splits, and that the
        similarity in both directions equals ``1 - n / num_triples``.
        """
        derived = self.reference.deteriorate(n=n, random_state=self.generator)
        self._help_check(derived)
        self.assertEqual(
            n,
            splits_steps(self.reference._tup(), derived._tup()),
            msg="steps",
        )
        expected_similarity = 1 - n / self.num_triples
        # Compare floats with a tolerance rather than exact equality.
        self.assertAlmostEqual(expected_similarity,
                               self.reference.similarity(derived),
                               msg="similarity")
        self.assertAlmostEqual(expected_similarity,
                               derived.similarity(self.reference),
                               msg="similarity")

    def _help_check(self, derived: Dataset):
        """Check invariants every deteriorated dataset must satisfy.

        The derived dataset must have a validation split, contain the same
        total number of triples as the reference, and have strictly fewer
        training triples than the reference. The reference's training-triple
        count must be unchanged.
        """
        self.assertIsNotNone(derived.validation)
        # deteriorating must not alter the reference's training split size
        self.assertEqual(self.num_training_triples,
                         self.reference.training.num_triples)
        self.assertEqual(
            self.num_triples,
            sum((
                derived.training.num_triples,
                derived.testing.num_triples,
                derived.validation.num_triples,
            )),
            msg="different number of total triples",
        )
        self.assertLess(derived.training.num_triples,
                        self.reference.training.num_triples)