def load(self, **opts):
    """Build and load every dataset configured for this task.

    Populates ``self.builders``, ``self.datasets``, per-dataset lengths and
    the sampling probabilities used to pick a dataset at iteration time,
    then switches to an initial dataset via ``self.change_dataset()``.

    Keyword Args:
        opts: Task options; must contain ``dataset_attributes`` (a mapping
            from dataset name to its config) and may contain
            ``dataset_type`` (defaults to "train") and
            ``dataset_size_proportional_sampling``.
    """
    self.opts = opts
    # NOTE(review): presumably parses self.opts into self.given_datasets —
    # confirm against the class definition (not visible in this chunk).
    self._process_datasets()

    self.datasets = []
    self.builders = []
    available_datasets = self._get_available_datasets()

    self.total_length = 0
    self.per_dataset_lengths = []
    self.num_datasets = 0

    for dataset in self.given_datasets:
        if dataset in available_datasets:
            builder_class = registry.get_builder_class(dataset)

            # A registered dataset without a builder is skipped, not fatal.
            if builder_class is None:
                print("No builder class found for %s." % dataset)
                continue

            builder_instance = builder_class()

            if dataset in self.opts["dataset_attributes"]:
                attributes = self.opts["dataset_attributes"][dataset]
            else:
                # Missing per-dataset config is fatal: log through the
                # task's writer and abort the whole process.
                self.writer.write(
                    "Dataset %s is missing from "
                    "dataset_attributes in config." % dataset,
                    "error",
                )
                sys.exit(1)

            dataset_type = self.opts.get("dataset_type", "train")

            # Only the main process materializes the dataset on disk; all
            # ranks then wait at the barrier so load() below sees the
            # finished artifacts. (Assumes torch.distributed-style
            # is_main_process/synchronize helpers — confirm.)
            if is_main_process():
                builder_instance.build(dataset_type, attributes)
            synchronize()

            dataset_instance = builder_instance.load(dataset_type, attributes)

            self.builders.append(builder_instance)
            self.datasets.append(dataset_instance)
            self.per_dataset_lengths.append(len(dataset_instance))
            self.total_length += len(dataset_instance)
        else:
            print(
                "Dataset %s is not a valid dataset for task %s. Skipping"
                % (dataset, self.task_name)
            )

    self.num_datasets = len(self.datasets)
    # Default: uniform (unnormalized) weights; "probablities" spelling is
    # kept as-is since other code may read this attribute name.
    self.dataset_probablities = [1 for _ in range(self.num_datasets)]
    sampling = self.opts.get("dataset_size_proportional_sampling", None)

    if sampling is True:
        # Weight each dataset by its share of the total number of samples.
        self.dataset_probablities = self.per_dataset_lengths[:]
        self.dataset_probablities = [
            prob / self.total_length for prob in self.dataset_probablities
        ]

    self.change_dataset()
def _load(self, dataset_type, config, *args, **kwargs):
    """Load the dataset for the given split.

    Training uses ImageNet directly. For any other split, inference is
    performed on COCO, whose configuration is expected to be embedded
    under the ``coco`` key of this builder's config.

    Args:
        dataset_type: Split name, e.g. "train", "val" or "test".
        config: This builder's configuration; must contain a ``coco``
            sub-config for non-train splits.

    Returns:
        The loaded dataset instance (also stored on ``self.dataset``).
    """
    if dataset_type == 'train':
        self.dataset = ImageNetDataset(dataset_type, config)
        return self.dataset

    # Non-train split: delegate entirely to the registered COCO builder,
    # driving it with the embedded COCO configuration.
    coco_attributes = config['coco']
    coco_builder_cls = registry.get_builder_class('coco')
    coco_builder = coco_builder_cls()
    coco_builder.build(dataset_type, coco_attributes)
    self.dataset = coco_builder.load(dataset_type, coco_attributes)
    return self.dataset
def load(self, **opts):
    """Build and load all configured datasets plus their dataloaders.

    In addition to dataset/builder bookkeeping, this variant also creates
    one dataloader and sampler per dataset, computes normalized sampling
    probabilities, and selects the first dataset/loader pair as current.

    Keyword Args:
        opts: Options mapping; must contain ``dataset_attributes`` keyed
            by dataset name.
    """
    self.opts = opts
    # NOTE(review): presumably fills self._given_datasets from opts —
    # confirm against the class definition (not visible in this chunk).
    self._process_datasets()

    self._datasets = []
    self._builders = []
    self._loaders = []
    self._samplers = []
    self._iterators = []

    self._total_length = 0
    self._per_dataset_lengths = []
    self._num_datasets = 0
    # Per-dataset bookkeeping used elsewhere to drive round-robin style
    # iteration once a dataset is exhausted (keys are dataset indices —
    # TODO confirm against the iteration code).
    self._finished_iterators = {}
    self._used_once = {}

    for dataset in self._given_datasets:
        builder_class = registry.get_builder_class(dataset)

        # A dataset without a registered builder is skipped, not fatal.
        if builder_class is None:
            print("No builder class found for %s." % dataset)
            continue

        builder_instance = builder_class()

        if dataset in self.opts["dataset_attributes"]:
            attributes = self.opts["dataset_attributes"][dataset]
        else:
            # Missing per-dataset config is fatal: log and abort.
            self.writer.write(
                "Dataset %s is missing from "
                "dataset_attributes in config." % dataset,
                "error",
            )
            sys.exit(1)

        # NOTE(review): unlike the sibling load() above, build() runs on
        # every process here with no main-process guard or barrier —
        # confirm this is intentional for this code path.
        builder_instance.build(self._dataset_type, attributes)
        dataset_instance = builder_instance.load(self._dataset_type, attributes)

        # A builder may legitimately yield no dataset for this split.
        if dataset_instance is None:
            continue

        loader_instance, sampler_instance = self.build_dataloader(
            dataset_instance, self.opts)

        self._builders.append(builder_instance)
        self._datasets.append(dataset_instance)
        self._loaders.append(loader_instance)
        self._samplers.append(sampler_instance)

        self._per_dataset_lengths.append(len(dataset_instance))
        self._total_length += len(dataset_instance)

    self._num_datasets = len(self._datasets)
    # Default: uniform probabilities (already normalized, unlike the
    # sibling load() above). "probablities" spelling kept as-is since
    # other code may read this attribute name.
    self._dataset_probablities = [
        1 / self._num_datasets for _ in range(self._num_datasets)
    ]

    training_parameters = self._global_config.training_parameters
    self._proportional_sampling = training_parameters.dataset_size_proportional_sampling

    if self._dataset_type != "train":
        # If it is val or test, it needs to be all datasets need to be fully iterated
        # as metrics will be calculated in eval mode over complete datasets
        self._proportional_sampling = True

    if self._proportional_sampling is True:
        # Weight each dataset by its share of the total sample count.
        self._dataset_probablities = self._per_dataset_lengths[:]
        self._dataset_probablities = [
            prob / self._total_length for prob in self._dataset_probablities
        ]

    # Start iteration from the first dataset/loader pair.
    self._loader_index = 0
    self._chosen_dataset = self._datasets[self._loader_index]
    self._chosen_loader = self._loaders[self._loader_index]