Example #1
    def load(self, **opts):
        self.opts = opts
        self._process_datasets()

        self.datasets = []
        self.builders = []
        available_datasets = self._get_available_datasets()

        self.total_length = 0
        self.per_dataset_lengths = []
        self.num_datasets = 0

        for dataset in self.given_datasets:
            if dataset in available_datasets:
                builder_class = registry.get_builder_class(dataset)

                if builder_class is None:
                    self.writer.write(
                        "No builder class found for %s." % dataset, "warning"
                    )
                    continue
                builder_instance = builder_class()

                if dataset in self.opts["dataset_attributes"]:
                    attributes = self.opts["dataset_attributes"][dataset]
                else:
                    self.writer.write(
                        "Dataset %s is missing from "
                        "dataset_attributes in config." % dataset,
                        "error",
                    )
                    sys.exit(1)

                dataset_type = self.opts.get("dataset_type", "train")

                # Build the dataset only on the main process; every other
                # process waits at the barrier until the build is done.
                if is_main_process():
                    builder_instance.build(dataset_type, attributes)
                synchronize()

                dataset_instance = builder_instance.load(dataset_type, attributes)

                self.builders.append(builder_instance)
                self.datasets.append(dataset_instance)
                self.per_dataset_lengths.append(len(dataset_instance))
                self.total_length += len(dataset_instance)
            else:
                self.writer.write(
                    "Dataset %s is not a valid dataset for task %s. Skipping."
                    % (dataset, self.task_name),
                    "warning",
                )

        self.num_datasets = len(self.datasets)
        self.dataset_probabilities = [1 for _ in range(self.num_datasets)]
        sampling = self.opts.get("dataset_size_proportional_sampling", None)

        if sampling is True:
            # Weight each dataset by its share of the total number of samples.
            self.dataset_probabilities = [
                length / self.total_length
                for length in self.per_dataset_lengths
            ]

        self.change_dataset()
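
The probabilities computed above are typically consumed when picking which dataset the next batch comes from. A minimal sketch of what `change_dataset` might do with them, assuming numpy is available; the body shown is an illustration, not the snippet's actual implementation:

    import numpy as np

    def change_dataset(self):
        # np.random.choice requires p to sum to 1, so normalize the
        # (possibly uniform, possibly size-proportional) weights first.
        probs = np.asarray(self.dataset_probabilities, dtype=np.float64)
        probs = probs / probs.sum()
        self.dataset_chosen_index = np.random.choice(
            self.num_datasets, 1, p=probs
        )[0]
        self.chosen_dataset = self.datasets[self.dataset_chosen_index]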
Example #2
    def _try_download(self):
        is_main = self._is_main_process()

        if self._already_downloaded:
            return

        if is_main:
            self.writer.write("Fetching fastText model for OCR processing")

        needs_download = False
        model_file = None

        if not hasattr(self.config, "model_file"):
            if is_main:
                warnings.warn("'model_file' key is required but missing "
                              "from FastTextProcessor's config.")
            needs_download = True
        else:
            model_file = os.path.join(get_pythia_root(),
                                      self.config.model_file)

            if not os.path.exists(model_file):
                if is_main:
                    warnings.warn(
                        "No model file present at {}.".format(model_file))
                needs_download = True

        if needs_download:
            if is_main:
                self.writer.write("Downloading FastText bin", "info")
            model_file = self._download_model()

        # Every process waits here until any download above has finished.
        synchronize()

        self._load_fasttext_model(model_file)
        self._already_downloaded = True
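
These snippets all lean on the same two primitives: `is_main_process` (here wrapped as a method) and `synchronize`. A minimal sketch of what they typically look like on top of `torch.distributed`; the exact bodies are an assumption, not this project's verbatim code:

    import torch.distributed as dist

    def is_main_process():
        # Treat rank 0 as the main process; in a non-distributed run
        # every process counts as main.
        if not dist.is_available() or not dist.is_initialized():
            return True
        return dist.get_rank() == 0

    def synchronize():
        # Barrier: each process blocks until all processes arrive here.
        if not dist.is_available() or not dist.is_initialized():
            return
        if dist.get_world_size() == 1:
            return
        dist.barrier()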
Example #3
    def save(self, iteration, update_best=False):
        # Sync all models before we start the save process
        synchronize()

        # Only save in main process
        if not is_main_process():
            return

        ckpt_filepath = os.path.join(self.models_foldername,
                                     "model_%d.ckpt" % iteration)
        best_ckpt_filepath = os.path.join(self.ckpt_foldername,
                                          self.ckpt_prefix + "best.ckpt")

        best_iteration = self.trainer.early_stopping.best_monitored_iteration
        best_metric = self.trainer.early_stopping.best_monitored_value

        ckpt = {
            "model": self.trainer.model.state_dict(),
            "optimizer": self.trainer.optimizer.state_dict(),
            "best_iteration": best_iteration,
            "best_metric_value": best_metric,
            "config": self.config,
        }

        git_metadata_dict = self._get_vcs_fields()
        ckpt.update(git_metadata_dict)

        torch.save(ckpt, ckpt_filepath)

        if update_best:
            torch.save(ckpt, best_ckpt_filepath)
Example #4
    def restore(self):
        synchronize()
        self.trainer.writer.write("Restoring checkpoint")
        best_path = os.path.join(self.ckpt_foldername,
                                 self.ckpt_prefix + "best.ckpt")

        if os.path.exists(best_path):
            self._load(best_path, force=True)
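
`restore` delegates to a `_load` helper that is not shown here. A minimal sketch of what such a helper might do, assuming a plain `torch.load`/`load_state_dict` flow mirroring the checkpoint written by `save` above; the body and the meaning of `force` are assumptions, not the original implementation:

    def _load(self, file, force=False):
        ckpt = torch.load(file, map_location="cpu")
        self.trainer.model.load_state_dict(ckpt["model"])
        if not force:
            # Only resume optimizer state when continuing a run, not when
            # force-loading the best weights for evaluation.
            self.trainer.optimizer.load_state_dict(ckpt["optimizer"])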
Example #5
    def _load(self):
        self.image_path = os.path.join(
            self._data_folder, _CONSTANTS["images_folder"], self._dataset_type
        )

        with open(
            os.path.join(
                self._data_folder,
                _CONSTANTS["questions_folder"],
                _TEMPLATES["question_json_file"].format(self._dataset_type),
            )
        ) as f:
            self.questions = json.load(f)[_CONSTANTS["questions_key"]]

            # Build the vocab only in the main process; building it in every
            # process would just repeat the same work.
            if is_main_process():
                self._build_vocab(self.questions, _CONSTANTS["question_key"])
                self._build_vocab(self.questions, _CONSTANTS["answer_key"])
            synchronize()
Example #6
    def _init_process_group(self):
        training_parameters = self.config.training_parameters
        self.local_rank = training_parameters.local_rank
        self.device = training_parameters.device

        if self.local_rank is not None and training_parameters.distributed:
            if not torch.distributed.is_nccl_available():
                raise RuntimeError(
                    "Unable to initialize process group: NCCL is not available"
                )
            # With no init_method given, init_process_group uses the env://
            # rendezvous and reads MASTER_ADDR, MASTER_PORT, RANK and
            # WORLD_SIZE from the environment.
            torch.distributed.init_process_group(backend="nccl")
            synchronize()

        if ("cuda" in self.device and training_parameters.distributed
                and self.local_rank is not None):
            self.device = torch.device("cuda", self.local_rank)

        registry.register("current_device", self.device)
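
For `training_parameters.local_rank` to be set in the first place, each worker has to receive its rank from the launcher. A minimal sketch of that plumbing, assuming the classic `python -m torch.distributed.launch` entry point, which passes `--local_rank` to every worker; the argparse wiring is an illustration:

    import argparse

    parser = argparse.ArgumentParser()
    # torch.distributed.launch appends --local_rank=<n> to each worker's argv.
    parser.add_argument("--local_rank", type=int, default=None)
    args = parser.parse_args()
    # The value would then be copied into
    # config.training_parameters.local_rank before _init_process_group runs.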
Example #7
    def build(self, dataset_type, config, *args, **kwargs):
        """
        Similar to the ``load`` function, this is used by Pythia to build a
        dataset the first time it is not available locally. It internally
        calls the ``_build`` function; override that function in your child
        class.

        Args:
            dataset_type (str): Type of dataset, train|val|test
            config (ConfigNode): Configuration for this dataset, loaded from
                                 the config file.

        .. warning::

            DO NOT OVERRIDE in child class. Instead override ``_build``.
        """
        # Only build in main process, so none of the others have to build
        if is_main_process():
            self._build(dataset_type, config, *args, **kwargs)
        synchronize()
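
A minimal sketch of a child builder written against this contract, assuming Pythia-style imports; the class name, registration key, and dataset class are hypothetical:

    from pythia.common.registry import registry
    from pythia.tasks.base_dataset_builder import BaseDatasetBuilder

    @registry.register_builder("my_dataset")  # hypothetical key
    class MyDatasetBuilder(BaseDatasetBuilder):
        def _build(self, dataset_type, config, *args, **kwargs):
            # Runs on the main process only, guarded by build() above;
            # download or preprocess the raw data here.
            ...

        def _load(self, dataset_type, config, *args, **kwargs):
            # Runs on every process once the barrier in build() is passed.
            return MyDataset(dataset_type, config)  # hypothetical dataset class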