def setUp(self) -> None:
    """Set up deterioration tests.

    Creates a seeded random state and the reference Nations dataset, and
    caches the number of training triples and the number of triples over
    all three splits (training, testing, validation).
    """
    # random state from the fixed seed 42 so runs are repeatable
    # (presumably a torch generator, per the helper's name -- confirm)
    self.generator = ensure_torch_random_state(42)
    self.reference = Nations()
    training_split = self.reference.training
    self.num_training_triples = training_split.num_triples
    self.num_triples = sum(
        split.num_triples
        for split in (
            training_split,
            self.reference.testing,
            self.reference.validation,
        )
    )
def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
    """Inject a triples factory that contains inverse triples.

    The parent hook runs first. Its result then gets ``"triples_factory"`` set
    to the training split of ``Nations(create_inverse_triples=True)``. The same
    factory is also kept on ``self.factory``.

    :param kwargs: keyword arguments destined for instantiation
    :return: the (updated) keyword arguments returned by the parent hook
    """
    kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
    # build the training factory with inverse relations and remember it on the test case
    inverse_factory = Nations(create_inverse_triples=True).training
    kwargs["triples_factory"] = inverse_factory
    self.factory = inverse_factory
    return kwargs
def setUp(self) -> None:
    """Set up the test case with a triples factory and model.

    Seeds randomness with 42 and keeps only the second value returned by the
    seeding helper as ``self.generator``. Then it builds the training factory
    of Nations (inverse triples per ``self.create_inverse_triples``) and
    instantiates ``self.model_cls`` on it, moved via ``to_device_()``.
    """
    _, self.generator, _ = set_random_seed(42)
    self.factory = Nations(create_inverse_triples=self.create_inverse_triples).training
    # optional extra constructor arguments; None is treated as "no extras"
    extra_kwargs = self.model_kwargs or {}
    model = self.model_cls(
        self.factory,
        embedding_dim=self.embedding_dim,
        **extra_kwargs,
    )
    self.model = model.to_device_()
def test_custom_dataset_instance(self):
    """Test passing a pre-instantiated dataset to HPO.

    The dataset is handed over as an object rather than by name. The study's
    user attributes must therefore not record any dataset or split
    information.
    """
    # mock a "custom" dataset by reusing one that is already available
    custom_dataset = Nations()
    hpo_pipeline_result = self._help_test_hpo(
        study_name="HPO with custom dataset instance",
        dataset=custom_dataset,
    )
    # Since custom data was passed, none of these can be stored on the study
    for attribute_key in ("dataset", "training", "testing", "validation"):
        self.assertNotIn(attribute_key, hpo_pipeline_result.study.user_attrs)
class TestDeterioration(unittest.TestCase):
    """Tests for deterioration workflow."""

    def setUp(self) -> None:
        """Set up deterioration tests.

        Creates a seeded random state and the reference Nations dataset, and
        caches its training-triple count and total triple count.
        """
        self.generator = ensure_torch_random_state(42)
        self.reference = Nations()
        self.num_training_triples = self.reference.training.num_triples
        self.num_triples = (
            self.reference.training.num_triples
            + self.reference.testing.num_triples
            + self.reference.validation.num_triples
        )

    def test_deteriorate(self):
        """Test deterioration on integer values for ``n``."""
        for n in [1, 2, 5, 10, 50, 100, 500, 1000]:
            with self.subTest(n=n):
                derived = self.reference.deteriorate(n=n, random_state=self.generator)
                self._help_check(derived)
                self._help_check_distance(n=n, derived=derived)

    def test_deteriorate_frac(self):
        """Test deterioration on fractional values for ``n``."""
        for n_frac in [
            1 / self.num_training_triples,
            2 / self.num_training_triples,
            5 / self.num_training_triples,
            0.1,
            0.2,
            0.3,
        ]:
            # expected absolute number of moved triples for this fraction of the training split
            n = int(n_frac * self.num_training_triples)
            with self.subTest(n=n, n_frac=n_frac):
                # Pass the fraction itself (not the integer ``n``); otherwise this
                # test never exercises fractional ``n`` and merely duplicates
                # ``test_deteriorate``.
                derived = self.reference.deteriorate(n=n_frac, random_state=self.generator)
                self._help_check(derived)
                self._help_check_distance(n=n, derived=derived)

    def _help_check_distance(self, n: int, derived: Dataset) -> None:
        """Check that ``derived`` differs from the reference by exactly ``n`` steps.

        Also checks that the similarity is ``1 - n / num_triples`` in both
        directions. Floats are compared approximately so the check does not
        depend on the exact arithmetic used inside ``similarity``.
        """
        self.assertEqual(
            n,
            splits_steps(self.reference._tup(), derived._tup()),
            msg="steps",
        )
        expected_similarity = 1 - n / self.num_triples
        self.assertAlmostEqual(
            expected_similarity,
            self.reference.similarity(derived),
            msg="similarity (reference to derived)",
        )
        self.assertAlmostEqual(
            expected_similarity,
            derived.similarity(self.reference),
            msg="similarity (derived to reference)",
        )

    def _help_check(self, derived: Dataset):
        """Check structural invariants of a deteriorated dataset.

        The derived dataset must have a validation split. The reference's
        training split must be unchanged in size, the total number of triples
        must be preserved, and the derived training split must be strictly
        smaller than the reference's.
        """
        self.assertIsNotNone(derived.validation)
        # the reference dataset must not have been modified by deteriorate()
        self.assertEqual(self.num_training_triples, self.reference.training.num_triples)
        self.assertEqual(
            self.num_triples,
            sum((
                derived.training.num_triples,
                derived.testing.num_triples,
                derived.validation.num_triples,
            )),
            msg="different number of total triples",
        )
        self.assertLess(derived.training.num_triples, self.reference.training.num_triples)