def load(self, **opts):
    self.opts = opts
    self._process_datasets()

    self.datasets = []
    self.builders = []
    available_datasets = self._get_available_datasets()

    self.total_length = 0
    self.per_dataset_lengths = []
    self.num_datasets = 0

    for dataset in self.given_datasets:
        if dataset not in available_datasets:
            print(
                "Dataset %s is not a valid dataset for task %s. Skipping"
                % (dataset, self.task_name)
            )
            continue

        builder_class = registry.get_builder_class(dataset)

        if builder_class is None:
            print("No builder class found for %s." % dataset)
            continue

        builder_instance = builder_class()

        if dataset in self.opts["dataset_attributes"]:
            attributes = self.opts["dataset_attributes"][dataset]
        else:
            self.writer.write(
                "Dataset %s is missing from "
                "dataset_attributes in config." % dataset,
                "error",
            )
            sys.exit(1)

        dataset_type = self.opts.get("dataset_type", "train")

        # Build only in the main process; the other processes wait at the
        # barrier until the build is done, then load the result.
        if is_main_process():
            builder_instance.build(dataset_type, attributes)
        synchronize()

        dataset_instance = builder_instance.load(dataset_type, attributes)

        self.builders.append(builder_instance)
        self.datasets.append(dataset_instance)
        self.per_dataset_lengths.append(len(dataset_instance))
        self.total_length += len(dataset_instance)

    self.num_datasets = len(self.datasets)
    # Default to uniform sampling across datasets
    self.dataset_probabilities = [1 for _ in range(self.num_datasets)]
    sampling = self.opts.get("dataset_size_proportional_sampling", None)

    if sampling is True:
        # Weight each dataset by its share of the total number of samples
        self.dataset_probabilities = [
            length / self.total_length for length in self.per_dataset_lengths
        ]

    self.change_dataset()
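# The probabilities computed above drive which dataset the next batch is
# drawn from. A minimal, self-contained sketch of that sampling step
# (`pick_dataset` and `lengths` are hypothetical names, for illustration
# only -- this is not the actual change_dataset() implementation):

import random


def pick_dataset(lengths, proportional=True):
    # Uniform weights by default; weights proportional to dataset size when
    # enabled, mirroring the `dataset_size_proportional_sampling` option.
    if proportional:
        total = sum(lengths)
        weights = [length / total for length in lengths]
    else:
        weights = [1 / len(lengths)] * len(lengths)
    # random.choices draws one index according to the given weights
    return random.choices(range(len(lengths)), weights=weights, k=1)[0]


# Example: a 9000-sample dataset is picked ~9x as often as a 1000-sample one
index = pick_dataset([9000, 1000], proportional=True)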
def _try_download(self):
    is_main_process = self._is_main_process()

    if self._already_downloaded:
        return

    if is_main_process:
        self.writer.write("Fetching fastText model for OCR processing")

    needs_download = False

    if not hasattr(self.config, "model_file"):
        if is_main_process:
            warnings.warn(
                "'model_file' key is required but missing "
                "from FastTextProcessor's config."
            )
        needs_download = True
        model_file = None
    else:
        model_file = os.path.join(get_pythia_root(), self.config.model_file)

        if not os.path.exists(model_file):
            if is_main_process:
                warnings.warn("No model file present at {}.".format(model_file))
            needs_download = True

    if needs_download:
        if is_main_process:
            self.writer.write("Downloading FastText bin", "info")
        # Called on every process so all ranks end up with the final model
        # path; the download itself is expected to run only on the main
        # process.
        model_file = self._download_model()

    # Wait for the main process to finish downloading before loading
    synchronize()

    self._load_fasttext_model(model_file)
    self._already_downloaded = True
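# The rank-0-downloads-then-everyone-loads pattern above generalizes beyond
# fastText. A minimal sketch with raw torch.distributed (`fetch_once` and
# `download_fn` are hypothetical, for illustration only):

import torch.distributed as dist


def fetch_once(download_fn, target_path):
    # Only the main process performs the side effect; the other ranks wait
    # at the barrier and then read the file the main process produced.
    in_distributed_mode = dist.is_available() and dist.is_initialized()
    if not in_distributed_mode or dist.get_rank() == 0:
        download_fn(target_path)
    if in_distributed_mode:
        dist.barrier()
    return target_path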
def save(self, iteration, update_best=False):
    # Sync all processes before we start the save
    synchronize()

    # Only save in the main process
    if not is_main_process():
        return

    ckpt_filepath = os.path.join(
        self.models_foldername, "model_%d.ckpt" % iteration
    )
    best_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
    )

    best_iteration = self.trainer.early_stopping.best_monitored_iteration
    best_metric = self.trainer.early_stopping.best_monitored_value

    ckpt = {
        "model": self.trainer.model.state_dict(),
        "optimizer": self.trainer.optimizer.state_dict(),
        "best_iteration": best_iteration,
        "best_metric_value": best_metric,
        "config": self.config,
    }

    # Record VCS state alongside the checkpoint for reproducibility
    git_metadata_dict = self._get_vcs_fields()
    ckpt.update(git_metadata_dict)

    torch.save(ckpt, ckpt_filepath)

    if update_best:
        torch.save(ckpt, best_ckpt_filepath)
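# A checkpoint written by save() above can be restored with torch.load. A
# minimal sketch (`model` and `optimizer` are assumed to match the
# architecture and optimizer that produced the checkpoint):

import torch


def load_ckpt(path, model, optimizer, device="cpu"):
    # map_location lets a GPU-trained checkpoint be loaded on any device
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    # Extra bookkeeping stored alongside the weights
    return ckpt["best_iteration"], ckpt["best_metric_value"]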
def restore(self):
    synchronize()
    self.trainer.writer.write("Restoring checkpoint")
    best_path = os.path.join(self.ckpt_foldername, self.ckpt_prefix + "best.ckpt")

    if os.path.exists(best_path):
        self._load(best_path, force=True)
def _load(self):
    self.image_path = os.path.join(
        self._data_folder, _CONSTANTS["images_folder"], self._dataset_type
    )

    with open(
        os.path.join(
            self._data_folder,
            _CONSTANTS["questions_folder"],
            _TEMPLATES["question_json_file"].format(self._dataset_type),
        )
    ) as f:
        self.questions = json.load(f)[_CONSTANTS["questions_key"]]

    # The vocab should only be built in the main process, as building it in
    # every process would just repeat the same work
    if is_main_process():
        self._build_vocab(self.questions, _CONSTANTS["question_key"])
        self._build_vocab(self.questions, _CONSTANTS["answer_key"])
    synchronize()
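# A minimal sketch of the kind of per-key vocab building done above (the
# whitespace tokenizer and <unk> handling are assumptions, not the actual
# _build_vocab implementation):

import collections


def build_vocab(questions, key):
    counter = collections.Counter()
    for question in questions:
        counter.update(str(question[key]).lower().split())
    # Reserve index 0 for out-of-vocabulary tokens
    vocab = {"<unk>": 0}
    for index, (word, _) in enumerate(counter.most_common(), start=1):
        vocab[word] = index
    return vocab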
def _init_process_group(self):
    training_parameters = self.config.training_parameters
    self.local_rank = training_parameters.local_rank
    self.device = training_parameters.device

    if self.local_rank is not None and training_parameters.distributed:
        if not torch.distributed.is_nccl_available():
            raise RuntimeError(
                "Unable to initialize process group: NCCL is not available"
            )
        torch.distributed.init_process_group(backend="nccl")
        synchronize()

    if (
        "cuda" in self.device
        and training_parameters.distributed
        and self.local_rank is not None
    ):
        self.device = torch.device("cuda", self.local_rank)

    registry.register("current_device", self.device)
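# The is_main_process()/synchronize() helpers used throughout these snippets
# are thin wrappers over torch.distributed. A minimal sketch of what they are
# assumed to look like (both degrade to no-ops when distributed training is
# not initialized):

import torch.distributed as dist


def is_main_process():
    # Rank 0 counts as the main process; single-process runs always qualify
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True


def synchronize():
    # Barrier across all processes; a no-op outside distributed mode
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        dist.barrier()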
def build(self, dataset_type, config, *args, **kwargs):
    """
    Similar to the load function, used by Pythia to build a dataset for the
    first time when it is not available. This internally calls the ``_build``
    function. Override that function in your child class.

    Args:
        dataset_type (str): Type of dataset, train|val|test
        config (ConfigNode): Configuration of this dataset loaded from config.

    .. warning::

        DO NOT OVERRIDE in child class. Instead override ``_build``.
    """
    # Only build in the main process, so none of the others have to
    if is_main_process():
        self._build(dataset_type, config, *args, **kwargs)
    synchronize()
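# Usage sketch: a child builder overrides _build (and typically _load) while
# leaving build() alone, so the main-process guard above keeps working. The
# import path, builder name, and dataset class are assumptions for
# illustration, not part of the snippet above:

from pythia.tasks.base_dataset_builder import BaseDatasetBuilder  # assumed path


class MyDatasetBuilder(BaseDatasetBuilder):
    def __init__(self):
        super().__init__("my_dataset")

    def _build(self, dataset_type, config, *args, **kwargs):
        # Runs only on the main process: download and preprocess data here
        pass

    def _load(self, dataset_type, config, *args, **kwargs):
        # Runs on every process once the barrier in build() is passed;
        # should return the actual dataset instance (omitted here)
        return None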