def _infer_dataset_probabilities(self):
    from mmf.utils.configuration import get_global_config

    training = get_global_config("training")

    self._dataset_probabilities = [
        1 / self._num_datasets for _ in range(self._num_datasets)
    ]

    self._proportional_sampling = training.get(
        "dataset_size_proportional_sampling", True
    )

    multitasking = get_global_config("multitasking")
    if multitasking is None:
        multitasking = {}
    multitasking_enabled = multitasking.get("enabled", False)

    assert (
        self._proportional_sampling is True
        or training.get("max_epochs", None) is None
    ), "Epoch based training can only be used with size proportional sampling"

    assert not (self._proportional_sampling and multitasking_enabled), (
        "Multitasking (manually-specified) per-dataset ratios cannot be used "
        "with size proportional sampling"
    )

    if (
        len(self.loaders) > 0
        and self.current_loader.dataset.dataset_type != "train"
    ):
        # If it is val or test, all datasets need to be fully iterated,
        # as metrics will be calculated in eval mode over complete datasets
        self._proportional_sampling = True

    if self._proportional_sampling is True and len(self._per_dataset_lengths) > 0:
        self._dataset_probabilities = self._per_dataset_lengths[:]
        self._dataset_probabilities = [
            prob / self._total_length for prob in self._dataset_probabilities
        ]

    if multitasking_enabled and self._dataset_type == "train":
        sampling_ratios = multitasking.get("sampling_ratios", {})
        probabilities = []
        for dataset in self.dataset_list:
            assert (
                dataset in sampling_ratios
            ), f"{dataset} must be specified in multitasking.sampling_ratios"
            probabilities.append(sampling_ratios[dataset])
        # normalize the sampling ratios so they sum to 1
        prob_sum = sum(probabilities)
        assert all(prob >= 0 for prob in probabilities) and prob_sum > 0, (
            "multitasking.sampling_ratios must be all non-negative and at least "
            "one of them needs to be positive."
        )
        self._dataset_probabilities = [prob / prob_sum for prob in probabilities]
        logger.info("Using per-dataset sampling probabilities:")
        for dataset, prob in zip(self.dataset_list, self._dataset_probabilities):
            logger.info(f"\t{dataset}: {prob}")
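# A minimal, standalone sketch of the two probability schemes above (the names
# and sizes below are hypothetical, not MMF datasets). Size-proportional
# sampling weights each dataset by its share of the total length; explicit
# multitasking ratios are simply normalized to sum to 1.
def _example_infer_probabilities(lengths=None, sampling_ratios=None):
    if sampling_ratios is not None:
        prob_sum = sum(sampling_ratios.values())
        assert all(r >= 0 for r in sampling_ratios.values()) and prob_sum > 0
        return {name: r / prob_sum for name, r in sampling_ratios.items()}
    total = sum(lengths.values())
    return {name: length / total for name, length in lengths.items()}

assert _example_infer_probabilities(
    lengths={"dataset_a": 400_000, "dataset_b": 100_000}
) == {"dataset_a": 0.8, "dataset_b": 0.2}
assert _example_infer_probabilities(
    sampling_ratios={"dataset_a": 3, "dataset_b": 1}
) == {"dataset_a": 0.75, "dataset_b": 0.25}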
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }

    # IterableDataset returns batches directly, so no need to add a Sampler
    # or batch size, as the user is expected to control those. This is a fine
    # assumption for now: not supporting single-item-based IterableDataset
    # avoids unnecessary complexity and config parameters in the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
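# Hypothetical usage sketch for the builder above; ToyDataset and the config
# values are illustrative, not part of MMF. BatchCollator reads dataset_name
# and dataset_type from the dataset instance, so anything passed in needs
# those two attributes. Map-style datasets get sampler/batch-size args added
# via _add_extra_args_for_dataloader, while IterableDataset instances skip
# that path and have "shuffle" dropped, since they yield batches themselves.
import torch
from omegaconf import OmegaConf


class ToyDataset(torch.utils.data.Dataset):
    dataset_name = "toy"
    dataset_type = "train"

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"x": torch.tensor(idx)}


# loader, sampler = build_dataloader_and_sampler(
#     ToyDataset(), OmegaConf.create({"batch_size": 4, "num_workers": 0})
# )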
def _check_not_epoch_training(self):
    """
    Having this allows easy override of the strategy in non-MMF use cases
    """
    training = get_global_config("training")
    assert (
        training.get("max_epochs", None) is None
    ), f"{self.__class__.__name__} doesn't make sense with epoch based training"
def _infer_dataset_probabilities(self):
    from mmf.utils.configuration import get_global_config

    training = get_global_config("training")
    proportional_sampling = training.get(
        "dataset_size_proportional_sampling", True
    )

    multitasking = get_global_config("multitasking")
    multitasking_enabled = multitasking.get("enabled", False)

    assert (
        proportional_sampling is True
        or training.get("max_epochs", None) is None
    ), "Epoch based training can only be used with size proportional sampling"

    assert not (proportional_sampling and multitasking_enabled), (
        "Multitasking (manually-specified) per-dataset ratios cannot be used "
        "with size proportional sampling"
    )

    if multitasking_enabled and "sampling_ratios" in multitasking:
        self._iteration_strategy = iteration_strategies.RatiosIterationStrategy(
            OmegaConf.create(
                {
                    "sampling_ratios": multitasking.sampling_ratios,
                    "datasets": self._given_datasets,
                }
            ),
            self._loaders,
        )
    elif proportional_sampling is True:
        strategy = iteration_strategies.SizeProportionalIterationStrategy
        self._iteration_strategy = strategy(OmegaConf.create(), self.loaders)
    else:
        self._iteration_strategy = iteration_strategies.RandomIterationStrategy(
            OmegaConf.create(), self.loaders
        )
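# Condensed sketch of the strategy selection above (assuming MMF's iteration
# strategy interface, where the strategy instance is called to pick which
# dataset supplies the next batch). Explicit multitasking ratios win, then
# size-proportional sampling, then uniform random choice.
def _example_choose_strategy(proportional_sampling, multitasking):
    if multitasking.get("enabled", False) and "sampling_ratios" in multitasking:
        return "ratios"
    if proportional_sampling:
        return "size_proportional"
    return "random"

assert _example_choose_strategy(True, {}) == "size_proportional"
assert _example_choose_strategy(
    False, {"enabled": True, "sampling_ratios": {"a": 1}}
) == "ratios"
# Inside the loader loop, something like:
# next_dataset_index = self._iteration_strategy()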
def get_batch_size():
    from mmf.utils.configuration import get_global_config

    batch_size = get_global_config("training.batch_size")

    world_size = get_world_size()

    if batch_size % world_size != 0:
        raise RuntimeError(
            "Batch size {} must be divisible by number "
            "of GPUs {} used.".format(batch_size, world_size)
        )

    return batch_size // world_size
def get_batch_size():
    from mmf.utils.configuration import get_global_config

    batch_size = get_global_config("training.batch_size")

    world_size = get_world_size()

    batch_size_per_device = get_global_config("training.batch_size_per_device")

    if batch_size_per_device is not None:
        logger.info(
            f"training.batch_size_per_device has been used as {batch_size_per_device}. "
            "This will override training.batch_size and set the global batch size to "
            f"{batch_size_per_device} x {world_size} = "
            f"{batch_size_per_device * world_size}"
        )
        batch_size = batch_size_per_device * world_size

    if batch_size % world_size != 0:
        raise RuntimeError(
            "Batch size {} must be divisible by number "
            "of GPUs {} used.".format(batch_size, world_size)
        )

    return batch_size // world_size
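# Worked example of the sizing arithmetic above; the helper is a hypothetical
# standalone version (MMF itself only exposes get_batch_size(), which reads
# these values from the global config).
def _example_per_device_batch_size(global_bs, world_size, per_device_bs=None):
    if per_device_bs is not None:
        # the per-device setting overrides the global one
        global_bs = per_device_bs * world_size
    if global_bs % world_size != 0:
        raise RuntimeError(
            f"Batch size {global_bs} must be divisible by {world_size}"
        )
    return global_bs // world_size

assert _example_per_device_batch_size(128, 8) == 16
assert _example_per_device_batch_size(64, 8, per_device_bs=16) == 16  # 16 x 8 = 128 global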
def zoo_config_path(self):
    if self.ZOO_CONFIG_PATH is None:
        self.ZOO_CONFIG_PATH = get_global_config("env.dataset_zoo")
    return self.ZOO_CONFIG_PATH
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }

    if version.parse(torch.__version__) >= version.parse("1.8"):
        # only use persistent workers in PyTorch 1.8 or higher
        # (PyTorch 1.7 also has this option but doesn't support it correctly due to
        # https://github.com/pytorch/pytorch/issues/48370)
        other_args["persistent_workers"] = datamodule_config.get(
            "persistent_workers", training_config.get("persistent_workers", True)
        )
        if other_args["persistent_workers"] and other_args["num_workers"] == 0:
            logger.warning(
                "persistent_workers cannot be used together with num_workers == 0; "
                "setting persistent_workers to False"
            )
            other_args["persistent_workers"] = False

    # IterableDataset returns batches directly, so no need to add a Sampler
    # or batch size, as the user is expected to control those. This is a fine
    # assumption for now: not supporting single-item-based IterableDataset
    # avoids unnecessary complexity and config parameters in the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    # Set drop_last=True when using XLA to have constant batch size.
    # In this case we also need to set drop_last=True in DistributedSampler.
    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
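# Standalone sketch of the persistent_workers guard above: the flag only
# matters when worker processes exist, so num_workers == 0 forces it off
# (torch.utils.data.DataLoader itself raises a ValueError for that
# combination). The version gate mirrors the check in the builder.
from packaging import version

import torch


def _example_resolve_persistent_workers(requested: bool, num_workers: int) -> bool:
    if version.parse(torch.__version__) < version.parse("1.8"):
        return False  # option unsupported / broken before PyTorch 1.8
    if requested and num_workers == 0:
        return False  # no workers to keep alive between epochs
    return requested


# _example_resolve_persistent_workers(True, 4) -> True on torch >= 1.8
# _example_resolve_persistent_workers(True, 0) -> False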
def get_class_weight():
    from mmf.utils.configuration import get_global_config

    class_weight = get_global_config("training.class_weight")
    return class_weight
def _get_test_reporter_config(self):
    from mmf.utils.configuration import get_global_config

    return get_global_config("evaluation.reporter")