def test_combined_iterator_reporting(self, mnist_factory):
    """Check the report generated for a combined MNIST train+test iterator.

    Builds informed iterators for the MNIST ``train`` and ``test`` splits,
    combines them under a meta derived from the train meta but tagged
    ``full``, and checks the combined length and the per-split sub-reports.
    """
    iterator_train, iterator_train_meta = mnist_factory.get_dataset_iterator(
        split="train")
    iterator_test, iterator_test_meta = mnist_factory.get_dataset_iterator(
        split="test")
    meta_train = MetaFactory.get_dataset_meta(
        identifier="id x", dataset_name="MNIST", dataset_tag="train",
        iterator_meta=iterator_train_meta)
    # Tag the test split "test" so its meta matches the split it wraps.
    meta_test = MetaFactory.get_dataset_meta(
        identifier="id x", dataset_name="MNIST", dataset_tag="test",
        iterator_meta=iterator_test_meta)
    informed_iterator_train = InformedDatasetFactory.get_dataset_iterator(
        iterator_train, meta_train)
    informed_iterator_test = InformedDatasetFactory.get_dataset_iterator(
        iterator_test, meta_test)
    meta_combined = MetaFactory.get_dataset_meta_from_existing(
        informed_iterator_train.dataset_meta, dataset_tag="full")
    iterator = InformedDatasetFactory.get_combined_dataset_iterator(
        [informed_iterator_train, informed_iterator_test], meta_combined)
    report = DatasetIteratorReportGenerator.generate_report(iterator)
    # Separate asserts so a failure pinpoints which length is wrong.
    assert report.length == 70000
    assert report.sub_reports[0].length == 60000
    assert report.sub_reports[1].length == 10000
    # The per-split iterators are not combined themselves, so their reports
    # are expected to carry no nested sub-reports.
    assert not report.sub_reports[0].sub_reports
    assert not report.sub_reports[1].sub_reports
def dataset_meta(self) -> DatasetMeta:
    """Return a ``train``-tagged dataset meta for a dummy test dataset.

    The underlying iterator meta places samples at position 0, targets at
    position 1 and tags at position 2.
    """
    iterator_meta = MetaFactory.get_iterator_meta(
        sample_pos=0,
        target_pos=1,
        tag_pos=2,
    )
    meta = MetaFactory.get_dataset_meta(
        identifier="identifier_1",
        dataset_name="TEST DATASET",
        dataset_tag="train",
        iterator_meta=iterator_meta,
    )
    return meta
def iterator(self) -> "InformedDatasetIterator":
    """Build an informed iterator over a small synthetic dataset.

    The dataset has 600 entries: targets are 100 ones, 200 twos and
    300 threes, and every sample is a constant 1.0 (``ones_like`` of the
    targets). It is wrapped with a ``full``-tagged dataset meta.

    Returns:
        The informed iterator produced by
        ``InformedDatasetFactory.get_dataset_iterator``.
    """
    targets = [1] * 100 + [2] * 200 + [3] * 300
    sequence_targets = torch.Tensor(targets)
    sequence_samples = torch.ones_like(sequence_targets)
    iterator = SequenceDatasetIterator([sequence_samples, sequence_targets])
    # tag_pos points at the same column as target_pos, so the targets double
    # as tags. NOTE(review): presumably deliberate (lets tests group by
    # target value) — confirm.
    iterator_meta = MetaFactory.get_iterator_meta(
        sample_pos=0, target_pos=1, tag_pos=1)
    meta = MetaFactory.get_dataset_meta(
        identifier="dataset id", dataset_name="dataset", dataset_tag="full",
        iterator_meta=iterator_meta)
    return InformedDatasetFactory.get_dataset_iterator(iterator, meta)
def test_plain_iterator_reporting(self, mnist_factory):
    """Check the report generated for a single MNIST train iterator.

    Wraps the MNIST ``train`` split in an ``InformedDatasetIterator`` and
    checks that the report covers 60000 samples and has no sub-reports.
    """
    iterator, iterator_meta = mnist_factory.get_dataset_iterator(
        split="train")
    dataset_meta = MetaFactory.get_dataset_meta(
        identifier="id x", dataset_name="MNIST", dataset_tag="train",
        iterator_meta=iterator_meta)
    informed_iterator = InformedDatasetIterator(iterator, dataset_meta)
    report = DatasetIteratorReportGenerator.generate_report(
        informed_iterator)
    # Separate asserts so a failure pinpoints which expectation broke.
    assert report.length == 60000
    # A single, non-combined iterator is expected to yield a leaf report.
    assert not report.sub_reports