Example #1
    def pretty_print(self):
        if not self.config.training.log_detailed_config:
            return

        self.writer = registry.get("writer")

        self.writer.write("=====  Training Parameters    =====", "info")
        self.writer.write(self._convert_node_to_json(self.config.training),
                          "info")

        self.writer.write("======  Dataset Attributes  ======", "info")
        datasets = self.config.datasets.split(",")

        for dataset in datasets:
            if dataset in self.config.dataset_config:
                self.writer.write("======== {} =======".format(dataset),
                                  "info")
                dataset_config = self.config.dataset_config[dataset]
                self.writer.write(self._convert_node_to_json(dataset_config),
                                  "info")
            else:
                self.writer.write(
                    "No dataset named '{}' in config. Skipping".format(
                        dataset),
                    "warning",
                )

        self.writer.write("======  Optimizer Attributes  ======", "info")
        self.writer.write(self._convert_node_to_json(self.config.optimizer),
                          "info")

        if self.config.model not in self.config.model_config:
            raise ValueError("{} not present in model attributes".format(
                self.config.model))

        self.writer.write(
            "======  Model ({}) Attributes  ======".format(self.config.model),
            "info")
        self.writer.write(
            self._convert_node_to_json(
                self.config.model_config[self.config.model]),
            "info",
        )
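
The snippet relies on a _convert_node_to_json helper to serialize OmegaConf nodes for logging. A minimal sketch of such a helper, assuming it only needs to pretty-print the node (the exact MMF implementation may differ):

import json

from omegaconf import OmegaConf


def _convert_node_to_json(node):
    # Resolve interpolations, convert the OmegaConf node to plain Python
    # containers, and pretty-print the result as indented JSON.
    container = OmegaConf.to_container(node, resolve=True)
    return json.dumps(container, indent=4, sort_keys=True)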
Example #2
    def _init_extras(self, config, *args, **kwargs):
        self.writer = registry.get("writer")
        self.preprocessor = None

        if hasattr(config, "max_length"):
            self.max_length = config.max_length
        else:
            warnings.warn("No 'max_length' parameter in Processor's "
                          "configuration. Setting to {}.".format(
                              self.MAX_LENGTH_DEFAULT))
            self.max_length = self.MAX_LENGTH_DEFAULT

        if "preprocessor" in config:
            self.preprocessor = Processor(config.preprocessor, *args, **kwargs)

            if self.preprocessor is None:
                raise ValueError(
                    f"No text processor named {config.preprocessor} is defined."
                )
Example #3
    def __init__(self,
                 multi_task_instance,
                 test_reporter_config: TestReporterConfigType = None):
        if not isinstance(test_reporter_config,
                          TestReporter.TestReporterConfigType):
            # Guard against the default None, which cannot be **-unpacked
            test_reporter_config = TestReporter.TestReporterConfigType(
                **(test_reporter_config or {}))
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name
        self.test_reporter_config = test_reporter_config

        self.datasets = []

        for dataset in self.test_task.get_datasets():
            self.datasets.append(dataset)

        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]

        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg

        self.candidate_fields = DEFAULT_CANDIDATE_FIELDS

        if test_reporter_config.candidate_fields != MISSING:
            self.candidate_fields = test_reporter_config.candidate_fields

        PathManager.mkdirs(self.report_folder)
Example #4
    def __init__(
        self,
        datamodules: Dict[str, pl.LightningDataModule],
        config: Config = None,
        dataset_type: str = "train",
    ):
        self.test_reporter_config = OmegaConf.merge(
            OmegaConf.structured(self.Config), config
        )
        self.datamodules = datamodules
        self.dataset_type = dataset_type
        self.config = registry.get("config")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name

        self.current_datamodule_idx = -1
        self.dataset_names = list(self.datamodules.keys())
        self.current_datamodule = self.datamodules[
            self.dataset_names[self.current_datamodule_idx]
        ]
        self.current_dataloader = None

        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg

        self.candidate_fields = self.test_reporter_config.candidate_fields

        PathManager.mkdirs(self.report_folder)

        log_class_usage("TestReporter", self.__class__)
Example #5
    def _load_state_dict_mapping(self, ckpt_model):
        model = self.trainer.model
        attr_mapping = {
            "image_feature_encoders": "img_feat_encoders",
            "image_feature_embeddings_list": "img_embeddings_list",
            "image_text_multi_modal_combine_layer":
            "multi_modal_combine_layer",
            "text_embeddings": "text_embeddings",
            "classifier": "classifier",
        }

        data_parallel = registry.get("data_parallel")

        if not data_parallel:
            # Iterate over a snapshot of the keys: mutating a dict while
            # iterating over it directly raises a RuntimeError in Python 3.
            for key in list(attr_mapping):
                attr_mapping[key.replace("module.", "")] = attr_mapping.pop(key)

        for key in attr_mapping:
            getattr(model, key).load_state_dict(ckpt_model[attr_mapping[key]])
Example #6
    def forward(self, image_feat, embedding):
        image_feat_mean = image_feat.mean(1)

        # Get LSTM state
        state = registry.get(f"{image_feat.device}_lstm_state")
        h1, c1 = state["td_hidden"]
        h2, c2 = state["lm_hidden"]

        h1, c1 = self.top_down_lstm(
            torch.cat([h2, image_feat_mean, embedding], dim=1), (h1, c1))

        state["td_hidden"] = (h1, c1)

        image_fa = self.fa_image(image_feat)
        hidden_fa = self.fa_hidden(h1)

        joint_feature = self.relu(image_fa + hidden_fa.unsqueeze(1))
        joint_feature = self.dropout(joint_feature)

        return joint_feature
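
The registry lookup above only works if matching state was registered earlier under a device-specific key. A minimal sketch of that setup; the function name and tensor shapes are assumptions based on this snippet, not confirmed MMF internals:

import torch

from mmf.common.registry import registry


def init_lstm_states(device, batch_size, hidden_size):
    # One (h, c) pair for the top-down LSTM and one for the language-model
    # LSTM, keyed by device so each device keeps its own recurrent state.
    def zeros():
        return torch.zeros(batch_size, hidden_size, device=device)

    state = {
        "td_hidden": (zeros(), zeros()),
        "lm_hidden": (zeros(), zeros()),
    }
    registry.register(f"{device}_lstm_state", state)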
Example #7
    def __init__(self, trainer):
        """
        Generates a path for saving model which can also be used for resuming
        from a checkpoint.
        """
        self.trainer = trainer

        self.config = self.trainer.config
        self.save_dir = get_mmf_env(key="save_dir")
        self.model_name = self.config.model

        self.ckpt_foldername = self.save_dir

        self.device = registry.get("current_device")

        self.ckpt_prefix = ""

        if hasattr(self.trainer.model, "get_ckpt_name"):
            self.ckpt_prefix = self.trainer.model.get_ckpt_name() + "_"

        self.pth_filepath = os.path.join(
            self.ckpt_foldername,
            self.ckpt_prefix + self.model_name + "_final.pth")

        self.models_foldername = os.path.join(self.ckpt_foldername, "models")
        if not PathManager.exists(self.models_foldername):
            PathManager.mkdirs(self.models_foldername)

        self.save_config()

        self.repo_path = updir(os.path.abspath(__file__), n=3)
        self.git_repo = None
        if git and self.config.checkpoint.save_git_details:
            try:
                self.git_repo = git.Repo(self.repo_path)
            except git.exc.InvalidGitRepositoryError:
                # Not a git repo, don't do anything
                pass

        self.max_to_keep = self.config.checkpoint.max_to_keep
        self.saved_iterations = []
Example #8
def ready_trainer(trainer):
    from mmf.common.registry import registry
    from mmf.utils.logger import Logger, TensorboardLogger
    trainer.run_type = trainer.config.get("run_type", "train")
    writer = registry.get("writer", no_warning=True)
    if writer:
        trainer.writer = writer
    else:
        trainer.writer = Logger(trainer.config)
        registry.register("writer", trainer.writer)

    trainer.configure_device()
    trainer.configure_seed()
    trainer.load_model()
    from mmf.trainers.callbacks.checkpoint import CheckpointCallback
    from mmf.trainers.callbacks.early_stopping import EarlyStoppingCallback
    trainer.checkpoint_callback = CheckpointCallback(trainer.config, trainer)
    trainer.early_stop_callback = EarlyStoppingCallback(
        trainer.config, trainer)
    trainer.callbacks.append(trainer.checkpoint_callback)
    trainer.on_init_start()
Example #9
    def load(self):
        # Set run type
        self.run_type = self.config.get("run_type", "train")

        # Print configuration
        configuration = registry.get("configuration", no_warning=True)
        if configuration:
            configuration.pretty_print()

        # Configure device and cudnn deterministic
        self.configure_device()
        self.configure_seed()

        # Load dataset, model, optimizer and metrics
        self.load_datasets()
        self.load_model()
        self.load_optimizer()
        self.load_metrics()

        # Initialize Callbacks
        self.configure_callbacks()
Example #10
def build_processors(processors_config: mmf_typings.DictConfig,
                     registry_key: str = None,
                     *args,
                     **kwargs) -> ProcessorDict:
    """Given a processor config, builds the processors present and returns back
    a dict containing processors mapped to keys as per the config

    Args:
        processors_config (mmf_typings.DictConfig): OmegaConf DictConfig describing
            the parameters and type of each processor passed here

        registry_key (str, optional): If passed, function would look into registry for
            this particular key and return it back. .format with processor_key will
            be called on this string. Defaults to None.

    Returns:
        ProcessorDict: Dictionary containing key to
            processor mapping
    """
    from mmf.datasets.processors.processors import Processor

    processor_dict = {}

    for processor_key, processor_params in processors_config.items():
        if not processor_params:
            continue

        processor_instance = None
        if registry_key is not None:
            full_key = registry_key.format(processor_key)
            processor_instance = registry.get(full_key, no_warning=True)

        if processor_instance is None:
            processor_instance = Processor(processor_params, *args, **kwargs)
            # We don't register back here as in case of hub interface, we
            # want the processors to be instantiated every time. BaseDataset
            # can register at its own end
        processor_dict[processor_key] = processor_instance

    return processor_dict
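
A short usage sketch for build_processors, assuming a processor type named "simple_word" is registered (stock MMF registers one under that name):

from omegaconf import OmegaConf

processors_config = OmegaConf.create(
    {"text_processor": {"type": "simple_word", "params": {}}}
)

# Maps "text_processor" to a built Processor instance; passing
# registry_key="coco_{}" instead would first try
# registry.get("coco_text_processor") before building a new one.
processors = build_processors(processors_config)
text_processor = processors["text_processor"]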
Example #11
    def load(self):
        self._set_device()

        self.run_type = self.config.get("run_type", "train")
        self.dataset_loader = DatasetLoader(self.config)
        self._datasets = self.config.datasets

        # Check if loader is already defined, else init it
        writer = registry.get("writer", no_warning=True)
        if writer:
            self.writer = writer
        else:
            self.writer = Logger(self.config)
            registry.register("writer", self.writer)

        self.configuration.pretty_print()

        self.config_based_setup()

        self.load_datasets()
        self.load_model_and_optimizer()
        self.load_metrics()
Example #12
    def __init__(self, config, *args, **kwargs):
        self.writer = registry.get("writer")

        if not hasattr(config, "type"):
            raise AttributeError(
                "Config must have 'type' attribute to specify type of processor"
            )

        processor_class = registry.get_processor_class(config.type)

        params = {}
        if not hasattr(config, "params"):
            warnings.warn("Config doesn't have 'params' attribute to "
                          "specify parameters of the processor "
                          "of type {}. Setting to default {{}}".format(
                              config.type))
        else:
            params = config.params

        self.processor = processor_class(params, *args, **kwargs)

        self._dir_representation = dir(self)
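
A usage sketch for this wrapper. The "simple_word" type assumes a processor registered under that name, and the trailing _dir_representation bookkeeping suggests the wrapper forwards attribute access to the inner processor:

from omegaconf import OmegaConf

# 'type' picks the registered processor class, 'params' is forwarded to it.
config = OmegaConf.create({"type": "simple_word", "params": {}})
processor = Processor(config)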
Example #13
    def change_dataloader(self):
        if self.num_datasets <= 1:
            return
        choice = 0

        if self._is_master:
            choice = np.random.choice(self.num_datasets,
                                      1,
                                      p=self._dataset_probabilities)[0]

            while choice in self._finished_iterators:
                choice = np.random.choice(self.num_datasets,
                                          1,
                                          p=self._dataset_probabilities)[0]

        choice = broadcast_scalar(choice,
                                  0,
                                  device=registry.get("current_device"))
        self.current_index = choice
        self.current_dataset = self.datasets[self.current_index]
        self.current_loader = self.loaders[self.current_index]
        self._chosen_iterator = self.iterators[self.current_index]
Example #14
    def calculate(self, sample_list, model_output, *args, **kwargs):
        answer_processor = registry.get(sample_list.dataset_name +
                                        "_answer_processor")

        batch_size = sample_list.context_tokens.size(0)
        pred_answers = model_output["scores"].argmax(dim=-1)
        context_tokens = sample_list.context_tokens.cpu().numpy()
        answers = sample_list.get(self.gt_key).cpu().numpy()
        answer_space_size = answer_processor.get_true_vocab_size()

        predictions = []
        from mmf.utils.distributed import byte_tensor_to_object
        from mmf.utils.text import word_tokenize

        for idx in range(batch_size):
            tokens = byte_tensor_to_object(context_tokens[idx])
            answer_words = []
            for answer_id in pred_answers[idx].tolist():
                if answer_id >= answer_space_size:
                    answer_id -= answer_space_size
                    answer_words.append(word_tokenize(tokens[answer_id]))
                else:
                    if answer_id == answer_processor.EOS_IDX:
                        break
                    answer_words.append(
                        answer_processor.answer_vocab.idx2word(answer_id))

            pred_answer = " ".join(answer_words).replace(" 's", "'s")
            gt_answers = byte_tensor_to_object(answers[idx])
            predictions.append({
                "pred_answer": pred_answer,
                "gt_answers": gt_answers
            })

        accuracy = self.evaluator.eval_pred_list(predictions)
        accuracy = torch.tensor(accuracy).to(sample_list.context_tokens.device)

        return accuracy
Example #15
    def __init__(self, params=None):
        super().__init__()
        if params is None:
            params = {}
        self.writer = registry.get("writer")

        is_mapping = isinstance(params, collections.abc.MutableMapping)

        if is_mapping:
            if "type" not in params:
                raise ValueError("Parameters to loss must have 'type' field to"
                                 "specify type of loss to instantiate")
            else:
                loss_name = params["type"]
        else:
            assert isinstance(
                params,
                str), "loss must be a string or dictionary with 'type' key"
            loss_name = params

        self.name = loss_name

        loss_class = registry.get_loss_class(loss_name)

        if loss_class is None:
            raise ValueError(
                "No loss named {} is registered to registry".format(loss_name))
        # Special case of multi as it requires an array
        if loss_name == "multi":
            assert is_mapping
            self.loss_criterion = loss_class(params)
        else:
            if is_mapping:
                loss_params = params.get("params", {})
            else:
                loss_params = {}
            self.loss_criterion = loss_class(**loss_params)
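
The constructor accepts either a bare loss name or a mapping; a sketch of both forms, assuming a loss registered as "cross_entropy" (stock MMF has one):

# String form: just the registered loss name, no extra parameters.
loss_fn = MMFLoss("cross_entropy")

# Mapping form: "type" selects the loss, "params" is forwarded to it.
loss_fn = MMFLoss({"type": "cross_entropy", "params": {}})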
Example #16
    def change_dataloader(self):
        if self.num_datasets <= 1:
            return
        choice = 0

        if self._is_master:
            choice = np.random.choice(self.num_datasets,
                                      1,
                                      p=self._dataset_probabilities)[0]

            # self._finished_iterators will always be empty in case of
            # non-proportional (equal) sampling
            while choice in self._finished_iterators:
                choice = np.random.choice(self.num_datasets,
                                          1,
                                          p=self._dataset_probabilities)[0]

        choice = broadcast_scalar(choice,
                                  0,
                                  device=registry.get("current_device"))
        self.current_index = choice
        self.current_dataset = self.datasets[self.current_index]
        self.current_loader = self.loaders[self.current_index]
        self._chosen_iterator = self.iterators[self.current_index]
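
Only the master rank samples a dataset index, but every rank still has to reach broadcast_scalar so the collective completes and all ranks agree on the choice. A minimal sketch of what such a helper does, inferred from its usage here rather than copied from MMF:

import torch
import torch.distributed as dist


def broadcast_scalar(scalar, src=0, device="cpu"):
    # Without an initialized process group there is nothing to synchronize.
    if not (dist.is_available() and dist.is_initialized()):
        return scalar
    tensor = torch.tensor(scalar, dtype=torch.long, device=device)
    dist.broadcast(tensor, src=src)  # every rank receives src's value
    return tensor.item()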
Example #17
    def __init__(self, config, *args, **kwargs):
        self.writer = registry.get("writer")
        if not hasattr(config, "vocab_file"):
            raise AttributeError(
                "'vocab_file' argument required, but not "
                "present in AnswerProcessor's config"
            )

        self.answer_vocab = VocabDict(config.vocab_file, *args, **kwargs)
        self.PAD_IDX = self.answer_vocab.word2idx("<pad>")
        self.BOS_IDX = self.answer_vocab.word2idx("<s>")
        self.EOS_IDX = self.answer_vocab.word2idx("</s>")
        self.UNK_IDX = self.answer_vocab.UNK_INDEX

        # Set EOS to something not achievable if it is not there
        if self.EOS_IDX == self.UNK_IDX:
            self.EOS_IDX = len(self.answer_vocab)

        self.preprocessor = None

        if hasattr(config, "preprocessor"):
            self.preprocessor = Processor(config.preprocessor)

            if self.preprocessor is None:
                raise ValueError(
                    f"No processor named {config.preprocessor} is defined."
                )

        if hasattr(config, "num_answers"):
            self.num_answers = config.num_answers
        else:
            self.num_answers = self.DEFAULT_NUM_ANSWERS
            warnings.warn(
                "'num_answers' not defined in the config. "
                "Setting to default of {}".format(self.DEFAULT_NUM_ANSWERS)
            )
Example #18
    def __init__(self, config, *args, **kwargs):
        self.writer = registry.get("writer")
        if not hasattr(config, "vocab_file"):
            raise AttributeError("'vocab_file' argument required, but not "
                                 "present in AnswerProcessor's config")

        self.answer_vocab = VocabDict(config.vocab_file, *args, **kwargs)

        self.preprocessor = None

        if hasattr(config, "preprocessor"):
            self.preprocessor = Processor(config.preprocessor)

            if self.preprocessor is None:
                raise ValueError(
                    f"No processor named {config.preprocessor} is defined.")

        if hasattr(config, "num_answers"):
            self.num_answers = config.num_answers
        else:
            self.num_answers = self.DEFAULT_NUM_ANSWERS
            warnings.warn("'num_answers' not defined in the config. "
                          "Setting to default of {}".format(
                              self.DEFAULT_NUM_ANSWERS))
Example #19
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self._global_config = registry.get("config")
        self._datasets = self._global_config.datasets.split(",")
Example #20
    def save(self, update, iteration=None, update_best=False):
        # Only save in main process
        if not is_master():
            return

        if not iteration:
            iteration = update

        ckpt_filepath = os.path.join(self.models_foldername,
                                     "model_%d.ckpt" % update)
        best_ckpt_filepath = os.path.join(self.ckpt_foldername,
                                          self.ckpt_prefix + "best.ckpt")
        current_ckpt_filepath = os.path.join(self.ckpt_foldername,
                                             self.ckpt_prefix + "current.ckpt")

        best_iteration = (self.trainer.early_stop_callback.early_stopping.
                          best_monitored_iteration)
        best_update = (self.trainer.early_stop_callback.early_stopping.
                       best_monitored_update)
        best_metric = (self.trainer.early_stop_callback.early_stopping.
                       best_monitored_value)
        model = self.trainer.model
        data_parallel = registry.get("data_parallel") or registry.get(
            "distributed")

        if data_parallel is True:
            model = model.module

        ckpt = {
            "model": model.state_dict(),
            "optimizer": self.trainer.optimizer.state_dict(),
            "best_iteration": best_iteration,
            "current_iteration": iteration,
            "current_epoch": self.trainer.current_epoch,
            "num_updates": update,
            "best_update": best_update,
            "best_metric_value": best_metric,
            # Convert to container to avoid any dependencies
            "config": OmegaConf.to_container(self.config, resolve=True),
        }

        lr_scheduler = self.trainer.lr_scheduler_callback._scheduler
        if lr_scheduler is not None:
            ckpt["lr_scheduler"] = lr_scheduler.state_dict()

        if self.git_repo:
            git_metadata_dict = self._get_vcs_fields()
            ckpt.update(git_metadata_dict)

        with PathManager.open(ckpt_filepath, "wb") as f:
            torch.save(ckpt, f)

        if update_best:
            with PathManager.open(best_ckpt_filepath, "wb") as f:
                torch.save(ckpt, f)

        # Save current always
        with PathManager.open(current_ckpt_filepath, "wb") as f:
            torch.save(ckpt, f)

        # Remove old checkpoints if max_to_keep is set
        if self.max_to_keep > 0:
            if len(self.saved_iterations) == self.max_to_keep:
                self.remove(self.saved_iterations.pop(0))
            self.saved_iterations.append(update)
Example #21
    def _init_classifier(self):
        num_hidden = self.config.text_embedding.num_hidden
        num_choices = registry.get(self._datasets[0] + "_num_final_outputs")
        dropout = self.config.classifier.dropout
        self.classifier = WeightNormClassifier(num_hidden, num_choices,
                                               num_hidden * 2, dropout)
Example #22
    def __init__(self, loss_list):
        super().__init__()
        self.losses = []
        self._evaluation_predict = registry.get("config").evaluation.predict
        for loss in loss_list:
            self.losses.append(MMFLoss(loss))
Example #23
def is_xla():
    return registry.get("is_xla", no_warning=True)
Example #24
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the MMF logger and set its verbosity level to "INFO".
    Outside libraries shouldn't call this if they have set up their
    own logging handlers. If they do call it and don't want existing
    handlers cleared, pass clear_handlers=False.

    The initial version of this function was taken from D2 and adapted
    for MMF.

    Args:
        output (str): a file name or a directory in which to save logs.
            If it ends with ".txt" or ".log", it is assumed to be a file name.
            Default: saved to file <save_dir/logs/log_[timestamp].txt>
        color (bool): If False, won't log colored logs. Default: True
        name (str): the root module name of this logger. Defaults to "mmf".
        clear_handlers (bool): If false, won't clear existing handlers.

    Returns:
        logging.Logger: a logger
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    logger.propagate = False

    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    logging_level = registry.get("config").training.logger_level.upper()
    if distributed_rank == 0:
        logger.setLevel(logging_level)
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging_level)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # file logging: all workers
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        PathManager.mkdirs(os.path.dirname(filename))

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging_level)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process
        if "train.log" not in filename and distributed_rank == 0:
            save_dir = get_mmf_env(key="save_dir")
            filename = os.path.join(save_dir, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(filename))
            sh.setLevel(logging_level)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add MMF specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers.
    logging.basicConfig(level=logging_level, handlers=handlers)

    registry.register("writer", logger)

    return logger
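
Typical invocations, sketched. Note the function reads training.logger_level from the registered config, so a config must be registered before calling it; the output path below is illustrative:

# Log to an explicit file and disable ANSI colors on stdout.
logger = setup_logger(output="/tmp/mmf_run.log", color=False)

# Or let MMF derive <save_dir>/logs/... and keep handlers already installed.
logger = setup_logger(clear_handlers=False)
logger.info("logger ready")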
Example #25
def import_user_module(user_dir: str):
    """Given a user dir, this function imports it as a module.

    This user_module is expected to have an __init__.py at its root.
    You can use import_files to easily import your Python files in
    __init__.py.

    Args:
        user_dir (str): directory to be imported
    """
    from mmf.common.registry import registry
    from mmf.utils.general import get_absolute_path  # noqa

    logger = logging.getLogger(__name__)
    if user_dir:
        if registry.get("__mmf_user_dir_imported__", no_warning=True):
            logger.info(f"User dir {user_dir} already imported. Skipping.")
            return

        # Allow loading of files as user source
        if user_dir.endswith(".py"):
            user_dir = user_dir[:-3]

        dot_path = ".".join(user_dir.split(os.path.sep))
        # In case of abspath which start from "/" the first char
        # will be "." which turns it into relative module which
        # find_spec doesn't like
        if os.path.isabs(user_dir):
            dot_path = dot_path[1:]

        try:
            dot_spec = importlib.util.find_spec(dot_path)
        except ModuleNotFoundError:
            dot_spec = None
        abs_user_dir = get_absolute_path(user_dir)
        module_parent, module_name = os.path.split(abs_user_dir)

        # If dot path is found in sys.modules, or path can be directly
        # be imported, we don't need to play jugglery with actual path
        if dot_path in sys.modules or dot_spec is not None:
            module_name = dot_path
        else:
            user_dir = abs_user_dir

        logger.info(f"Importing from {user_dir}")
        if module_name != dot_path:
            # Since dot path hasn't been found or can't be imported,
            # we can try importing the module by changing sys path
            # to the parent
            sys.path.insert(0, module_parent)

        importlib.import_module(module_name)
        sys.modules["mmf_user_dir"] = sys.modules[module_name]

        # Register config for user's model and dataset config
        # relative path resolution
        config = registry.get("config")
        if config is None:
            registry.register(
                "config", OmegaConf.create({"env": {"user_dir": user_dir}})
            )
        else:
            with open_dict(config):
                config.env.user_dir = user_dir

        registry.register("__mmf_user_dir_imported__", True)
Example #26
    def save(self, update, iteration=None, update_best=False):
        # Only save in the main process.
        # For XLA we use the xm.save method, which ensures that the actual
        # checkpoint saving happens only on the master node and also takes
        # care of all the necessary synchronization.
        if not is_master() and not is_xla():
            return

        logger.info("Checkpoint save operation started!")
        if not iteration:
            iteration = update

        ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)
        best_ckpt_filepath = os.path.join(
            self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
        )
        current_ckpt_filepath = os.path.join(
            self.ckpt_foldername, self.ckpt_prefix + "current.ckpt"
        )

        best_iteration = (
            self.trainer.early_stop_callback.early_stopping.best_monitored_iteration
        )
        best_update = (
            self.trainer.early_stop_callback.early_stopping.best_monitored_update
        )
        best_metric = (
            self.trainer.early_stop_callback.early_stopping.best_monitored_value
        )
        model = self.trainer.model
        data_parallel = registry.get("data_parallel") or registry.get("distributed")
        fp16_scaler = getattr(self.trainer, "scaler", None)
        fp16_scaler_dict = None

        if fp16_scaler is not None:
            fp16_scaler_dict = fp16_scaler.state_dict()

        if data_parallel is True:
            model = model.module

        ckpt = {
            "model": model.state_dict(),
            "optimizer": self.trainer.optimizer.state_dict(),
            "best_iteration": best_iteration,
            "current_iteration": iteration,
            "current_epoch": self.trainer.current_epoch,
            "num_updates": update,
            "best_update": best_update,
            "best_metric_value": best_metric,
            "fp16_scaler": fp16_scaler_dict,
            # Convert to container to avoid any dependencies
            "config": OmegaConf.to_container(self.config, resolve=True),
        }

        lr_scheduler = self.trainer.lr_scheduler_callback._scheduler
        if lr_scheduler is not None:
            ckpt["lr_scheduler"] = lr_scheduler.state_dict()

        if self.git_repo:
            git_metadata_dict = self._get_vcs_fields()
            ckpt.update(git_metadata_dict)

        with PathManager.open(ckpt_filepath, "wb") as f:
            self.save_func(ckpt, f)

        if update_best:
            logger.info("Saving best checkpoint")
            with PathManager.open(best_ckpt_filepath, "wb") as f:
                self.save_func(ckpt, f)

        # Save current always

        logger.info("Saving current checkpoint")
        with PathManager.open(current_ckpt_filepath, "wb") as f:
            self.save_func(ckpt, f)

        # Remove old checkpoints if max_to_keep is set
        if self.max_to_keep > 0:
            if len(self.saved_iterations) == self.max_to_keep:
                self.remove(self.saved_iterations.pop(0))
            self.saved_iterations.append(update)

        logger.info("Checkpoint save operation finished!")
Example #27
    def __init__(self, config):
        super().__init__(config)
        self.mmt_config = BertConfig(**self.config.mmt)
        self._datasets = registry.get("config").datasets.split(",")
Example #28
    def __init__(
        self,
        num_train_data,
        max_updates,
        max_epochs,
        config=None,
        optimizer=None,
        update_frequency=1,
        batch_size=1,
        batch_size_per_device=None,
        fp16=False,
        on_update_end_fn=None,
        scheduler_config=None,
        grad_clipping_config=None,
    ):
        if config is None:
            self.config = OmegaConf.create(
                {
                    "training": {
                        "detect_anomaly": False,
                        "evaluation_interval": 10000,
                        "update_frequency": update_frequency,
                        "fp16": fp16,
                        "batch_size": batch_size,
                        "batch_size_per_device": batch_size_per_device,
                    }
                }
            )
            self.training_config = self.config.training
        else:
            self.training_config = config.training
            self.config = config

        # Load batch size with custom config and cleanup
        original_config = registry.get("config")
        registry.register("config", self.config)
        batch_size = get_batch_size()
        registry.register("config", original_config)

        if max_updates is not None:
            self.training_config["max_updates"] = max_updates
        if max_epochs is not None:
            self.training_config["max_epochs"] = max_epochs

        self.model = SimpleModel({"in_dim": 1})
        self.model.build()
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.distributed = False

        self.dataset_loader = MagicMock()
        self.dataset_loader.seed_sampler = MagicMock(return_value=None)
        self.dataset_loader.prepare_batch = lambda x: SampleList(x)
        if optimizer is None:
            self.optimizer = MagicMock()
            self.optimizer.step = MagicMock(return_value=None)
            self.optimizer.zero_grad = MagicMock(return_value=None)
        else:
            self.optimizer = optimizer

        if scheduler_config:
            config.training.lr_scheduler = True
            config.scheduler = scheduler_config
            self.lr_scheduler_callback = LRSchedulerCallback(config, self)
            self.callbacks.append(self.lr_scheduler_callback)
            on_update_end_fn = (
                on_update_end_fn
                if on_update_end_fn
                else self.lr_scheduler_callback.on_update_end
            )

        if grad_clipping_config:
            self.training_config.clip_gradients = True
            self.training_config.max_grad_l2_norm = grad_clipping_config[
                "max_grad_l2_norm"
            ]
            self.training_config.clip_norm_mode = grad_clipping_config["clip_norm_mode"]

        dataset = NumbersDataset(num_train_data)
        self.train_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=1,
            drop_last=False,
        )
        self.train_loader.current_dataset = dataset
        self.on_batch_start = MagicMock(return_value=None)
        self.on_update_start = MagicMock(return_value=None)
        self.logistics_callback = MagicMock(return_value=None)
        self.logistics_callback.log_interval = MagicMock(return_value=None)
        self.on_batch_end = MagicMock(return_value=None)
        self.on_update_end = (
            on_update_end_fn if on_update_end_fn else MagicMock(return_value=None)
        )
        self.meter = Meter()
        self.after_training_loop = MagicMock(return_value=None)
        self.on_validation_start = MagicMock(return_value=None)
        self.evaluation_loop = MagicMock(return_value=(None, None))
        self.scaler = torch.cuda.amp.GradScaler(enabled=False)
        self.val_loader = MagicMock(return_value=None)
        self.early_stop_callback = MagicMock(return_value=None)
        self.on_validation_end = MagicMock(return_value=None)
        self.metrics = MagicMock(return_value=None)
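
The batch-size computation above temporarily registers a custom config and then restores the original. Wrapping that register/restore dance in a context manager makes it exception-safe; this helper is a sketch, not part of MMF:

from contextlib import contextmanager

from mmf.common.registry import registry


@contextmanager
def temporary_registry_entry(key, value):
    # Register `value` under `key`, then restore whatever was there
    # before, even if the body raises.
    previous = registry.get(key, no_warning=True)
    registry.register(key, value)
    try:
        yield
    finally:
        registry.register(key, previous)


# Usage mirroring the snippet above:
# with temporary_registry_entry("config", self.config):
#     batch_size = get_batch_size()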
Example #29
    def _build_word_embedding(self):
        assert len(self._datasets) > 0
        text_processor = registry.get(self._datasets[0] + "_text_processor")
        vocab = text_processor.vocab
        self.word_embedding = vocab.get_embedding(torch.nn.Embedding,
                                                  embedding_dim=300)
Example #30
    def _load(self, file, force=False, load_pretrained=False):
        tp = self.config.training
        self.trainer.writer.write("Loading checkpoint")

        ckpt = self._torch_load(file)

        data_parallel = registry.get("data_parallel") or registry.get(
            "distributed")

        if "model" in ckpt:
            ckpt_model = ckpt["model"]
        else:
            ckpt_model = ckpt
            ckpt = {"model": ckpt}

        pretrained_mapping = tp.pretrained_mapping

        if load_pretrained is False or force is True:
            pretrained_mapping = {}

        new_dict = {}

        # TODO: Move to separate function
        for attr in ckpt_model:
            new_attr = attr
            if "fa_history" in attr:
                new_attr = new_attr.replace("fa_history", "fa_context")

            if data_parallel is False and attr.startswith("module."):
                # In case the ckpt was actually a data parallel model
                # replace first module. from dataparallel with empty string
                new_dict[new_attr.replace("module.", "", 1)] = ckpt_model[attr]
            elif data_parallel is not False and not attr.startswith("module."):
                new_dict["module." + new_attr] = ckpt_model[attr]
            else:
                new_dict[new_attr] = ckpt_model[attr]

        if len(pretrained_mapping.items()) == 0:
            final_dict = new_dict
            self.trainer.model.load_state_dict(final_dict, strict=False)

            if "optimizer" in ckpt:
                self.trainer.optimizer.load_state_dict(ckpt["optimizer"])
            else:
                warnings.warn("'optimizer' key is not present in the "
                              "checkpoint asked to be loaded. Skipping.")

            self.trainer.early_stopping.init_from_checkpoint(ckpt)

            self.trainer.writer.write("Checkpoint loaded")

            if "best_update" in ckpt:
                if tp.resume_best:
                    self.trainer.num_updates = ckpt.get(
                        "best_update", self.trainer.num_updates)
                    self.trainer.current_iteration = ckpt.get(
                        "best_iteration", self.trainer.current_iteration)
                else:
                    self.trainer.num_updates = ckpt.get(
                        "num_updates", self.trainer.num_updates)
                    self.trainer.current_iteration = ckpt.get(
                        "current_iteration", self.trainer.current_iteration)

                self.trainer.current_epoch = ckpt.get(
                    "current_epoch", self.trainer.current_epoch)
            elif "best_iteration" in ckpt:
                # Preserve old behavior for old checkpoints where we always
                # load best iteration
                if tp.resume_best and "current_iteration" in ckpt:
                    self.trainer.current_iteration = ckpt["current_iteration"]
                else:
                    self.trainer.current_iteration = ckpt.get(
                        "best_iteration", self.trainer.current_iteration)

                self.trainer.num_updates = self.trainer.current_iteration

            registry.register("current_iteration",
                              self.trainer.current_iteration)
            registry.register("num_updates", self.trainer.num_updates)

            self.trainer.current_epoch = ckpt.get("best_epoch",
                                                  self.trainer.current_epoch)
            registry.register("current_epoch", self.trainer.current_epoch)
        else:
            final_dict = {}
            model = self.trainer.model
            own_state = model.state_dict()

            for key, value in pretrained_mapping.items():
                key += "."
                value += "."
                for attr in new_dict:
                    for own_attr in own_state:
                        formatted_attr = model.format_state_key(attr)
                        if (key in formatted_attr and value in own_attr
                                and formatted_attr.replace(
                                    key, "") == own_attr.replace(value, "")):
                            self.trainer.writer.write("Copying " + attr + " " +
                                                      own_attr)
                            own_state[own_attr].copy_(new_dict[attr])
            self.trainer.writer.write("Pretrained model loaded")