def test_load_dump(self):
    """A builder survives a dill round trip with its name and version intact."""
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
        original = testing.DummyMnist(data_dir=tmp_dir)
        payload = dill.dumps(original)
        restored = dill.loads(payload)
        self.assertEqual(original.name, restored.name)
        self.assertEqual(original.version, restored.version)
def test_stats_not_restored_gcs_overwritten(self):
    """Check the train-split example count of a freshly constructed builder."""
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
        # If the splits differ from the restored ones, stats should be
        # recomputed.
        # NOTE(review): the body below asserts only the same count that
        # test_stats_restored_from_gcs asserts; no differing-split /
        # recompute scenario is exercised here. Confirm whether setup
        # was lost.
        mnist = testing.DummyMnist(data_dir=tmp_dir)
        train_split = mnist.info.splits["train"]
        self.assertEqual(train_split.num_examples, 20)
def test_show_examples(self, mock_fig):
    """Prepare the dummy dataset and run show_examples on its train split."""
    # mock_fig is not used in the body; presumably it is injected by a
    # patch decorator outside this view. Verify before removing it.
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
        mnist = testing.DummyMnist(data_dir=tmp_dir)
        mnist.download_and_prepare()
        train_ds = mnist.as_dataset(split="train")
        visualization.show_examples(mnist.info, train_ds)
def test_stats_restored_from_gcs(self):
    """A new builder reports 20 examples in its train split without preparing."""
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
        mnist = testing.DummyMnist(data_dir=tmp_dir)
        train_split = mnist.info.splits["train"]
        self.assertEqual(train_split.num_examples, 20)