Example #1
    def setUp(self):
        torch.manual_seed(1234)
        np.random.seed(1234)
        self.multi_dataset = MultiDatasetLoader()
        self.multi_dataset._num_datasets = 3
        self.multi_dataset.current_index = 0
        numbers_dataset_a = NumbersDataset(4, "a")
        numbers_dataset_b = NumbersDataset(40, "b")
        numbers_dataset_c = NumbersDataset(4000, "c")
        self.multi_dataset._datasets = [
            numbers_dataset_a,
            numbers_dataset_b,
            numbers_dataset_c,
        ]
        self.multi_dataset._loaders = [
            self._get_dataloader(numbers_dataset_a),
            self._get_dataloader(numbers_dataset_b),
            self._get_dataloader(numbers_dataset_c),
        ]
        self.multi_dataset.current_loader = self.multi_dataset.loaders[0]
        self.multi_dataset.config = {
            "training": {
                "dataset_size_proportional_sampling": True,
                "max_epochs": None
            }
        }
        self.multi_dataset._per_dataset_lengths = [4, 40, 4000]
        self.multi_dataset._total_length = sum(
            self.multi_dataset._per_dataset_lengths)
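
This setUp (and the registry-based variant that follows) relies on a NumbersDataset test fixture and a _get_dataloader helper. The helper appears later in Example #7, but NumbersDataset itself is never shown; below is a minimal hypothetical sketch consistent with how it is used here, assuming each item is a dict keyed by the dataset's name so the sampling tests can tell which dataset a batch came from.

import torch
from torch.utils.data import Dataset


class NumbersDataset(Dataset):
    """Hypothetical stand-in for the fixture used above."""

    def __init__(self, num_examples, name):
        # `name` becomes the key of every sample, so a later check such as
        # `"c" in batch` identifies which dataset a batch came from.
        self.num_examples = num_examples
        self.name = name

    def __len__(self):
        return self.num_examples

    def __getitem__(self, idx):
        # One scalar tensor per example, keyed by the dataset's name.
        return {self.name: torch.tensor(idx, dtype=torch.float)}
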
    def setUp(self):
        torch.manual_seed(1234)
        np.random.seed(1234)
        self.multi_dataset = MultiDatasetLoader()
        self.multi_dataset._num_datasets = 3
        self.multi_dataset.current_index = 0
        numbers_dataset_a = NumbersDataset(4, "a")
        numbers_dataset_b = NumbersDataset(40, "b")
        numbers_dataset_c = NumbersDataset(4000, "c")
        self.multi_dataset.dataset_list = ["a", "b", "c"]
        self.multi_dataset._loaders = {
            "a": self._get_dataloader(numbers_dataset_a),
            "b": self._get_dataloader(numbers_dataset_b),
            "c": self._get_dataloader(numbers_dataset_c),
        }
        self.original_config = registry.get("config")
        registry.register(
            "config",
            OmegaConf.create({
                "training": {
                    "dataset_size_proportional_sampling": True,
                    "max_epochs": None,
                }
            }),
        )
        self.multi_dataset._per_dataset_lengths = [4, 40, 4000]
        # The total length (4044) is simply the sum of the per-dataset lengths.
        self.multi_dataset._total_length = sum(
            self.multi_dataset._per_dataset_lengths)
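
This variant overwrites the globally registered config and stashes the previous value in self.original_config, so a matching tearDown that restores it is implied. A minimal sketch, assuming registry.register simply overwrites an existing key:

    def tearDown(self):
        # Restore whatever config was registered before the test ran, so that
        # later tests see the original global state again.
        registry.register("config", self.original_config)
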
class DatasetLoader:
    def __init__(self, config):
        self.config = config

    def load_datasets(self):
        self.train_dataset = MultiDatasetLoader("train")
        self.val_dataset = MultiDatasetLoader("val")
        self.test_dataset = MultiDatasetLoader("test")

        self.train_dataset.load(self.config)
        self.val_dataset.load(self.config)
        self.test_dataset.load(self.config)

        # If the number of datasets is one, this will return the first loader
        self.train_loader = self.train_dataset
        self.val_loader = self.val_dataset
        self.test_loader = self.test_dataset

        self.mapping = {
            "train": self.train_dataset,
            "val": self.val_dataset,
            "test": self.test_dataset,
        }

        self.test_reporter = None
        self.should_not_log = self.config.training.should_not_log

    @property
    def dataset_config(self):
        return self._dataset_config

    @dataset_config.setter
    def dataset_config(self, config):
        self._dataset_config = config

    def get_config(self):
        return self._dataset_config

    def get_test_reporter(self, dataset_type):
        dataset = getattr(self, f"{dataset_type}_dataset")
        return TestReporter(dataset)

    def prepare_batch(self, batch, *args, **kwargs):
        batch = SampleList(batch)
        return self.mapping[batch.dataset_type].prepare_batch(batch)

    def verbose_dump(self, report, *args, **kwargs):
        if self.config.training.verbose_dump:
            dataset_type = report.dataset_type
            self.mapping[dataset_type].verbose_dump(report, *args, **kwargs)

    def seed_sampler(self, dataset_type, seed):
        dataset = getattr(self, f"{dataset_type}_dataset")
        dataset.seed_sampler(seed)
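
For context, a hypothetical driver for this class could look as follows; config is assumed to be an OmegaConf node carrying the training keys the methods above read (should_not_log, verbose_dump).

# Hypothetical usage; `config` stands in for the usual MMF configuration object.
dataset_loader = DatasetLoader(config)
dataset_loader.load_datasets()

for batch in dataset_loader.train_loader:
    # prepare_batch wraps the raw batch in a SampleList and dispatches it to
    # the train/val/test loader indicated by batch.dataset_type.
    batch = dataset_loader.prepare_batch(batch)
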
Example #4
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = get_batch_size()

        self.train_loader = MultiDatasetLoader("train")
        self.val_loader = MultiDatasetLoader("val")
        self.test_loader = MultiDatasetLoader("test")

        self.train_loader.load(self.config)
        self.val_loader.load(self.config)
        self.test_loader.load(self.config)
    def load_datasets(self):
        self.train_dataset = MultiDatasetLoader("train")
        self.val_dataset = MultiDatasetLoader("val")
        self.test_dataset = MultiDatasetLoader("test")

        self.train_dataset.load(self.config)
        self.val_dataset.load(self.config)
        self.test_dataset.load(self.config)

        # If the number of datasets is one, this will return the first loader
        self.train_loader = self.train_dataset
        self.val_loader = self.val_dataset
        self.test_loader = self.test_dataset

        self.mapping = {
            "train": self.train_dataset,
            "val": self.val_dataset,
            "test": self.test_dataset,
        }

        self.test_reporter = None
        self.should_not_log = self.config.training.should_not_log
Example #6
class LightningDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = get_batch_size()

        self.train_loader = MultiDatasetLoader("train")
        self.val_loader = MultiDatasetLoader("val")
        self.test_loader = MultiDatasetLoader("test")

        self.train_loader.load(self.config)
        self.val_loader.load(self.config)
        self.test_loader.load(self.config)

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader
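
A hypothetical way to wire this data module into a trainer; SomeLightningModel stands in for any pl.LightningModule and config for the usual MMF configuration, both placeholders for illustration.

import pytorch_lightning as pl

# Hypothetical wiring; SomeLightningModel and config are placeholders.
datamodule = LightningDataModule(config)
model = SomeLightningModel()

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=datamodule)
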
Example #7
class TestMultiDatasetLoader(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(1234)
        np.random.seed(1234)
        self.multi_dataset = MultiDatasetLoader()
        self.multi_dataset._num_datasets = 3
        self.multi_dataset.current_index = 0
        numbers_dataset_a = NumbersDataset(4, "a")
        numbers_dataset_b = NumbersDataset(40, "b")
        numbers_dataset_c = NumbersDataset(4000, "c")
        self.multi_dataset._datasets = [
            numbers_dataset_a,
            numbers_dataset_b,
            numbers_dataset_c,
        ]
        self.multi_dataset._loaders = [
            self._get_dataloader(numbers_dataset_a),
            self._get_dataloader(numbers_dataset_b),
            self._get_dataloader(numbers_dataset_c),
        ]
        self.multi_dataset.current_loader = self.multi_dataset.loaders[0]
        self.multi_dataset.config = {
            "training": {
                "dataset_size_proportional_sampling": True,
                "max_epochs": None
            }
        }
        self.multi_dataset._per_dataset_lengths = [4, 40, 4000]
        self.multi_dataset._total_length = sum(
            self.multi_dataset._per_dataset_lengths)

    def _get_dataloader(self, dataset):
        return DataLoader(dataset=dataset, batch_size=4, num_workers=0)

    def test_proportional_sampling(self):
        self.multi_dataset._infer_dataset_probabilities()

        count = 0
        count_c = 0
        for batch in self.multi_dataset:
            batch = self.multi_dataset.prepare_batch(batch)
            if "c" in batch:
                count_c += 1
            count += 1
            if count == 100:
                break

        # Expect "c" in nearly every batch, since dataset "c" dominates the total length
        self.assertTrue(count_c >= 98)

        count = 0
        count_epoch = 0
        counter = Counter()
        for _ in range(1):
            for batch in self.multi_dataset:
                batch = self.multi_dataset.prepare_batch(batch)
                counter[list(batch.keys())[0]] += 1
                count += 1
            count_epoch += 1
        # Expect epoch to be completed
        self.assertEqual(count_epoch, 1)
        # Expect each dataset to be fully iterated
        self.assertEqual(count, self.multi_dataset._total_length // 4)
        self.assertEqual(counter, Counter({"a": 1, "b": 10, "c": 1000}))

    def test_equal_sampling(self):
        self.multi_dataset.config["training"][
            "dataset_size_proportional_sampling"] = False
        self.multi_dataset._infer_dataset_probabilities()

        count = 0
        count_c = 0
        for batch in self.multi_dataset:
            batch = self.multi_dataset.prepare_batch(batch)
            if "c" in batch:
                count_c += 1
            count += 1
            if count == 100:
                break

        self.assertTrue(count_c <= 34)

        # The epoch never finishes in this case, so iterate up to the
        # proportional-sampling epoch length plus some extra
        for batch in self.multi_dataset:
            batch = self.multi_dataset.prepare_batch(batch)
            count += 1
            if count > self.multi_dataset._total_length // 4 + 100:
                break

        # The loop should reach this point, i.e. iteration should not stop at
        # the proportional-sampling epoch length
        self.assertTrue(count > self.multi_dataset._total_length // 4 + 100)
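
The thresholds in these tests follow from the dataset lengths. Assuming sampling probabilities proportional to dataset length, as the config flag name suggests, a quick check of the arithmetic:

lengths = [4, 40, 4000]
total = sum(lengths)                  # 4044
probs = [n / total for n in lengths]  # ~[0.001, 0.010, 0.989]

# With ~98.9% of batches drawn from "c", seeing "c" in at least 98 of the
# first 100 batches is expected, and one full proportional epoch covers
# 4044 // 4 = 1011 batches at batch_size=4 (Counter a=1, b=10, c=1000).
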
Example #8
class DatasetLoader:
    def __init__(self, config):
        # TODO: Remove in next version
        warnings.warn(
            "DatasetLoader has been deprecated and will be removed in future versions. "
            "Please use mmf.datasets.multi_datamodule.MultiDataModule instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.config = config

    def load_datasets(self):
        self.train_dataset = MultiDatasetLoader("train")
        self.val_dataset = MultiDatasetLoader("val")
        self.test_dataset = MultiDatasetLoader("test")

        self.train_dataset.load(self.config)
        self.val_dataset.load(self.config)
        self.test_dataset.load(self.config)

        # If the number of datasets is one, this will return the first loader
        self.train_loader = self.train_dataset
        self.val_loader = self.val_dataset
        self.test_loader = self.test_dataset

        self.mapping = {
            "train": self.train_dataset,
            "val": self.val_dataset,
            "test": self.test_dataset,
        }

        self.test_reporter = None
        self.should_not_log = self.config.training.should_not_log

    @property
    def dataset_config(self):
        return self._dataset_config

    @dataset_config.setter
    def dataset_config(self, config):
        self._dataset_config = config

    def get_config(self):
        return self._dataset_config

    def get_test_reporter(self, dataset_type):
        dataset = getattr(self, f"{dataset_type}_dataset")
        datamodules = build_multiple_datamodules(dataset.dataset_list,
                                                 self.config.dataset_config)
        test_reporter_config = self._get_test_reporter_config()
        return build_test_reporter(datamodules, test_reporter_config,
                                   dataset_type)

    def _get_test_reporter_config(self):
        dataset_name = list(self.config.dataset_config.keys())[0]
        dataset_config = self.config.dataset_config.get(dataset_name)
        if hasattr(dataset_config, "get"):
            return dataset_config.get("test_reporter_config", None)
        else:
            return None

    def prepare_batch(self, batch, *args, **kwargs):
        batch = SampleList(batch)
        return self.mapping[batch.dataset_type].prepare_batch(batch)

    def verbose_dump(self, report, *args, **kwargs):
        if self.config.training.verbose_dump:
            dataset_type = report.dataset_type
            self.mapping[dataset_type].verbose_dump(report, *args, **kwargs)

    def seed_sampler(self, dataset_type, seed):
        dataset = getattr(self, f"{dataset_type}_dataset")
        dataset.seed_sampler(seed)
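
The deprecation message points at mmf.datasets.multi_datamodule.MultiDataModule as the replacement. A hedged migration sketch, assuming MultiDataModule takes the config in its constructor and follows the LightningDataModule hooks shown in Example #6:

from mmf.datasets.multi_datamodule import MultiDataModule

# Hypothetical migration; the constructor signature and hooks are assumptions
# based on the LightningDataModule pattern from Example #6.
datamodule = MultiDataModule(config)
train_loader = datamodule.train_dataloader()
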