def pretty_print(self):
    if not self.config.training.log_detailed_config:
        return

    self.writer = registry.get("writer")

    self.writer.write("===== Training Parameters =====", "info")
    self.writer.write(self._convert_node_to_json(self.config.training), "info")

    self.writer.write("====== Dataset Attributes ======", "info")
    datasets = self.config.datasets.split(",")

    for dataset in datasets:
        if dataset in self.config.dataset_config:
            self.writer.write("======== {} ========".format(dataset), "info")
            dataset_config = self.config.dataset_config[dataset]
            self.writer.write(self._convert_node_to_json(dataset_config), "info")
        else:
            self.writer.write(
                "No dataset named '{}' in config. Skipping".format(dataset),
                "warning",
            )

    self.writer.write("====== Optimizer Attributes ======", "info")
    self.writer.write(self._convert_node_to_json(self.config.optimizer), "info")

    if self.config.model not in self.config.model_config:
        raise ValueError(
            "{} not present in model attributes".format(self.config.model)
        )

    self.writer.write(
        "====== Model ({}) Attributes ======".format(self.config.model), "info"
    )
    self.writer.write(
        self._convert_node_to_json(self.config.model_config[self.config.model]),
        "info",
    )
def _init_extras(self, config, *args, **kwargs):
    self.writer = registry.get("writer")
    self.preprocessor = None

    if hasattr(config, "max_length"):
        self.max_length = config.max_length
    else:
        warnings.warn(
            "No 'max_length' parameter in Processor's configuration. "
            "Setting to {}.".format(self.MAX_LENGTH_DEFAULT)
        )
        self.max_length = self.MAX_LENGTH_DEFAULT

    if "preprocessor" in config:
        self.preprocessor = Processor(config.preprocessor, *args, **kwargs)

        # Only validate when a preprocessor was actually requested;
        # a config without a "preprocessor" key is valid.
        if self.preprocessor is None:
            raise ValueError(
                f"No text processor named {config.preprocessor} is defined."
            )
def __init__(
    self,
    multi_task_instance,
    test_reporter_config: TestReporterConfigType = None,
):
    if not isinstance(test_reporter_config, TestReporter.TestReporterConfigType):
        test_reporter_config = TestReporter.TestReporterConfigType(
            **test_reporter_config
        )
    self.test_task = multi_task_instance
    self.task_type = multi_task_instance.dataset_type
    self.config = registry.get("config")
    self.report = []
    self.timer = Timer()
    self.training_config = self.config.training
    self.num_workers = self.training_config.num_workers
    self.batch_size = self.training_config.batch_size
    self.report_folder_arg = get_mmf_env(key="report_dir")
    self.experiment_name = self.training_config.experiment_name
    self.test_reporter_config = test_reporter_config

    self.datasets = []
    for dataset in self.test_task.get_datasets():
        self.datasets.append(dataset)

    self.current_dataset_idx = -1
    self.current_dataset = self.datasets[self.current_dataset_idx]

    self.save_dir = get_mmf_env(key="save_dir")
    self.report_folder = ckpt_name_from_core_args(self.config)
    self.report_folder += foldername_from_config_override(self.config)
    self.report_folder = os.path.join(self.save_dir, self.report_folder)
    self.report_folder = os.path.join(self.report_folder, "reports")

    if self.report_folder_arg:
        self.report_folder = self.report_folder_arg

    self.candidate_fields = DEFAULT_CANDIDATE_FIELDS
    if test_reporter_config.candidate_fields != MISSING:
        self.candidate_fields = test_reporter_config.candidate_fields

    PathManager.mkdirs(self.report_folder)
def __init__(
    self,
    # Annotated as a dict since it's keyed by dataset name below
    # (assumes ``from typing import Dict`` at module level).
    datamodules: Dict[str, pl.LightningDataModule],
    config: Config = None,
    dataset_type: str = "train",
):
    self.test_reporter_config = OmegaConf.merge(
        OmegaConf.structured(self.Config), config
    )
    self.datamodules = datamodules
    self.dataset_type = dataset_type
    self.config = registry.get("config")
    self.report = []
    self.timer = Timer()
    self.training_config = self.config.training
    self.num_workers = self.training_config.num_workers
    self.batch_size = self.training_config.batch_size
    self.report_folder_arg = get_mmf_env(key="report_dir")
    self.experiment_name = self.training_config.experiment_name

    self.current_datamodule_idx = -1
    self.dataset_names = list(self.datamodules.keys())
    self.current_datamodule = self.datamodules[
        self.dataset_names[self.current_datamodule_idx]
    ]
    self.current_dataloader = None

    self.save_dir = get_mmf_env(key="save_dir")
    self.report_folder = ckpt_name_from_core_args(self.config)
    self.report_folder += foldername_from_config_override(self.config)
    self.report_folder = os.path.join(self.save_dir, self.report_folder)
    self.report_folder = os.path.join(self.report_folder, "reports")

    if self.report_folder_arg:
        self.report_folder = self.report_folder_arg

    self.candidate_fields = self.test_reporter_config.candidate_fields

    PathManager.mkdirs(self.report_folder)

    log_class_usage("TestReporter", self.__class__)
def _load_state_dict_mapping(self, ckpt_model):
    model = self.trainer.model
    attr_mapping = {
        "image_feature_encoders": "img_feat_encoders",
        "image_feature_embeddings_list": "img_embeddings_list",
        "image_text_multi_modal_combine_layer": "multi_modal_combine_layer",
        "text_embeddings": "text_embeddings",
        "classifier": "classifier",
    }

    data_parallel = registry.get("data_parallel")

    if not data_parallel:
        # Strip any "module." prefix left by DataParallel. Rebuild the
        # dict instead of popping keys while iterating over it, which
        # raises a RuntimeError in Python 3.
        attr_mapping = {
            key.replace("module.", ""): value
            for key, value in attr_mapping.items()
        }

    for key in attr_mapping:
        getattr(model, key).load_state_dict(ckpt_model[attr_mapping[key]])
def forward(self, image_feat, embedding):
    image_feat_mean = image_feat.mean(1)

    # Get LSTM state
    state = registry.get(f"{image_feat.device}_lstm_state")
    h1, c1 = state["td_hidden"]
    h2, c2 = state["lm_hidden"]

    h1, c1 = self.top_down_lstm(
        torch.cat([h2, image_feat_mean, embedding], dim=1), (h1, c1)
    )

    state["td_hidden"] = (h1, c1)

    image_fa = self.fa_image(image_feat)
    hidden_fa = self.fa_hidden(h1)

    joint_feature = self.relu(image_fa + hidden_fa.unsqueeze(1))
    joint_feature = self.dropout(joint_feature)

    return joint_feature
def __init__(self, trainer):
    """
    Generates a path for saving the model, which can also be used to
    resume from a checkpoint.
    """
    self.trainer = trainer

    self.config = self.trainer.config
    self.save_dir = get_mmf_env(key="save_dir")
    self.model_name = self.config.model

    self.ckpt_foldername = self.save_dir

    self.device = registry.get("current_device")

    self.ckpt_prefix = ""

    if hasattr(self.trainer.model, "get_ckpt_name"):
        self.ckpt_prefix = self.trainer.model.get_ckpt_name() + "_"

    self.pth_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + self.model_name + "_final.pth"
    )

    self.models_foldername = os.path.join(self.ckpt_foldername, "models")
    if not PathManager.exists(self.models_foldername):
        PathManager.mkdirs(self.models_foldername)

    self.save_config()

    self.repo_path = updir(os.path.abspath(__file__), n=3)
    self.git_repo = None
    if git and self.config.checkpoint.save_git_details:
        try:
            self.git_repo = git.Repo(self.repo_path)
        except git.exc.InvalidGitRepositoryError:
            # Not a git repo, don't do anything
            pass

    self.max_to_keep = self.config.checkpoint.max_to_keep
    self.saved_iterations = []
def ready_trainer(trainer):
    from mmf.common.registry import registry
    from mmf.utils.logger import Logger, TensorboardLogger

    trainer.run_type = trainer.config.get("run_type", "train")

    writer = registry.get("writer", no_warning=True)
    if writer:
        trainer.writer = writer
    else:
        trainer.writer = Logger(trainer.config)
        registry.register("writer", trainer.writer)

    trainer.configure_device()
    trainer.configure_seed()
    trainer.load_model()

    from mmf.trainers.callbacks.checkpoint import CheckpointCallback
    from mmf.trainers.callbacks.early_stopping import EarlyStoppingCallback

    trainer.checkpoint_callback = CheckpointCallback(trainer.config, trainer)
    trainer.early_stop_callback = EarlyStoppingCallback(trainer.config, trainer)
    trainer.callbacks.append(trainer.checkpoint_callback)

    trainer.on_init_start()
def load(self):
    # Set run type
    self.run_type = self.config.get("run_type", "train")

    # Print configuration
    configuration = registry.get("configuration", no_warning=True)
    if configuration:
        configuration.pretty_print()

    # Configure device and cudnn deterministic
    self.configure_device()
    self.configure_seed()

    # Load dataset, model, optimizer and metrics
    self.load_datasets()
    self.load_model()
    self.load_optimizer()
    self.load_metrics()

    # Initialize Callbacks
    self.configure_callbacks()
def build_processors(
    processors_config: mmf_typings.DictConfig,
    registry_key: str = None,
    *args,
    **kwargs,
) -> ProcessorDict:
    """Given a processor config, builds the processors present and returns
    a dict mapping keys to processors as per the config.

    Args:
        processors_config (mmf_typings.DictConfig): OmegaConf DictConfig
            describing the parameters and type of each processor passed here
        registry_key (str, optional): If passed, the function will look into
            the registry for this particular key and return it.
            ``.format`` with ``processor_key`` will be called on this string.
            Defaults to None.

    Returns:
        ProcessorDict: Dictionary containing key to processor mapping
    """
    from mmf.datasets.processors.processors import Processor

    processor_dict = {}

    for processor_key, processor_params in processors_config.items():
        if not processor_params:
            continue

        processor_instance = None
        if registry_key is not None:
            full_key = registry_key.format(processor_key)
            processor_instance = registry.get(full_key, no_warning=True)

        if processor_instance is None:
            processor_instance = Processor(processor_params, *args, **kwargs)
            # We don't register back here as in case of hub interface, we
            # want the processors to be instantiated every time. BaseDataset
            # can register at its own end
        processor_dict[processor_key] = processor_instance

    return processor_dict
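# A minimal usage sketch for build_processors, separate from the MMF source
# above. The "vocab" processor type and its param layout are illustrative
# assumptions, and "vocab.txt" stands in for a real vocabulary file;
# substitute whatever processors your install registers.
def _example_build_processors():
    from omegaconf import OmegaConf

    processors_config = OmegaConf.create(
        {
            "text_processor": {
                "type": "vocab",
                "params": {
                    "max_length": 14,
                    # Hypothetical vocab file, for illustration only
                    "vocab": {"type": "random", "vocab_file": "vocab.txt"},
                },
            }
        }
    )
    # Without registry_key, every processor is instantiated fresh rather
    # than looked up in the registry.
    processors = build_processors(processors_config)
    return processors["text_processor"]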
def load(self):
    self._set_device()

    self.run_type = self.config.get("run_type", "train")
    self.dataset_loader = DatasetLoader(self.config)
    self._datasets = self.config.datasets

    # Check if writer is already defined, else init it
    writer = registry.get("writer", no_warning=True)
    if writer:
        self.writer = writer
    else:
        self.writer = Logger(self.config)
        registry.register("writer", self.writer)

    self.configuration.pretty_print()

    self.config_based_setup()

    self.load_datasets()
    self.load_model_and_optimizer()
    self.load_metrics()
def __init__(self, config, *args, **kwargs):
    self.writer = registry.get("writer")

    if not hasattr(config, "type"):
        raise AttributeError(
            "Config must have 'type' attribute to specify type of processor"
        )

    processor_class = registry.get_processor_class(config.type)

    params = {}
    if not hasattr(config, "params"):
        warnings.warn(
            "Config doesn't have 'params' attribute to "
            "specify parameters of the processor "
            "of type {}. Setting to default {{}}".format(config.type)
        )
    else:
        params = config.params

    self.processor = processor_class(params, *args, **kwargs)

    self._dir_representation = dir(self)
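# Hedged sketch of the config shape the Processor wrapper above expects:
# a "type" key naming a registered processor plus an optional "params"
# mapping forwarded to its constructor. "simple_word" is assumed to be a
# registered processor name here.
def _example_processor_wrapper():
    from omegaconf import OmegaConf

    # No "params" key, so the wrapper warns and falls back to defaults.
    processor = Processor(OmegaConf.create({"type": "simple_word"}))
    return processor({"text": "what is this?"})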
def change_dataloader(self):
    if self.num_datasets <= 1:
        return
    choice = 0

    if self._is_master:
        choice = np.random.choice(
            self.num_datasets, 1, p=self._dataset_probabilities
        )[0]

        while choice in self._finished_iterators:
            choice = np.random.choice(
                self.num_datasets, 1, p=self._dataset_probabilities
            )[0]

    choice = broadcast_scalar(choice, 0, device=registry.get("current_device"))
    self.current_index = choice
    self.current_dataset = self.datasets[self.current_index]
    self.current_loader = self.loaders[self.current_index]
    self._chosen_iterator = self.iterators[self.current_index]
def calculate(self, sample_list, model_output, *args, **kwargs):
    answer_processor = registry.get(sample_list.dataset_name + "_answer_processor")

    batch_size = sample_list.context_tokens.size(0)
    pred_answers = model_output["scores"].argmax(dim=-1)
    context_tokens = sample_list.context_tokens.cpu().numpy()
    answers = sample_list.get(self.gt_key).cpu().numpy()
    answer_space_size = answer_processor.get_true_vocab_size()

    predictions = []
    from mmf.utils.distributed import byte_tensor_to_object
    from mmf.utils.text import word_tokenize

    for idx in range(batch_size):
        tokens = byte_tensor_to_object(context_tokens[idx])
        answer_words = []
        for answer_id in pred_answers[idx].tolist():
            if answer_id >= answer_space_size:
                # Ids beyond the fixed vocabulary point into the context
                # (OCR) tokens via the copy mechanism, offset by vocab size
                answer_id -= answer_space_size
                answer_words.append(word_tokenize(tokens[answer_id]))
            else:
                if answer_id == answer_processor.EOS_IDX:
                    break
                answer_words.append(
                    answer_processor.answer_vocab.idx2word(answer_id)
                )

        pred_answer = " ".join(answer_words).replace(" 's", "'s")
        gt_answers = byte_tensor_to_object(answers[idx])
        predictions.append({"pred_answer": pred_answer, "gt_answers": gt_answers})

    accuracy = self.evaluator.eval_pred_list(predictions)
    accuracy = torch.tensor(accuracy).to(sample_list.context_tokens.device)

    return accuracy
def __init__(self, params=None):
    super().__init__()
    if params is None:
        params = {}

    self.writer = registry.get("writer")
    is_mapping = isinstance(params, collections.abc.MutableMapping)

    if is_mapping:
        if "type" not in params:
            raise ValueError(
                "Parameters to loss must have 'type' field to "
                "specify type of loss to instantiate"
            )
        else:
            loss_name = params["type"]
    else:
        assert isinstance(
            params, str
        ), "loss must be a string or dictionary with 'type' key"
        loss_name = params

    self.name = loss_name

    loss_class = registry.get_loss_class(loss_name)

    if loss_class is None:
        raise ValueError(
            "No loss named {} is registered to registry".format(loss_name)
        )
    # Special case of multi as it requires an array
    if loss_name == "multi":
        assert is_mapping
        self.loss_criterion = loss_class(params)
    else:
        if is_mapping:
            loss_params = params.get("params", {})
        else:
            loss_params = {}
        self.loss_criterion = loss_class(**loss_params)
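# Hedged sketch of the two param forms accepted by the branches above.
# "cross_entropy" is assumed to be a registered loss name; substitute any
# name known to registry.get_loss_class.
def _example_mmf_loss_forms():
    # Plain string: picks the loss class, constructed with no extra params.
    loss_from_str = MMFLoss("cross_entropy")
    # Mapping: "type" selects the class, "params" is forwarded as kwargs;
    # the "multi" type instead receives the whole mapping.
    loss_from_dict = MMFLoss({"type": "cross_entropy", "params": {}})
    return loss_from_str, loss_from_dict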
def change_dataloader(self):
    if self.num_datasets <= 1:
        return
    choice = 0

    if self._is_master:
        choice = np.random.choice(
            self.num_datasets, 1, p=self._dataset_probabilities
        )[0]

        # self._finished_iterators will always be empty in case of
        # non-proportional (equal) sampling
        while choice in self._finished_iterators:
            choice = np.random.choice(
                self.num_datasets, 1, p=self._dataset_probabilities
            )[0]

    choice = broadcast_scalar(choice, 0, device=registry.get("current_device"))
    self.current_index = choice
    self.current_dataset = self.datasets[self.current_index]
    self.current_loader = self.loaders[self.current_index]
    self._chosen_iterator = self.iterators[self.current_index]
def __init__(self, config, *args, **kwargs):
    self.writer = registry.get("writer")
    if not hasattr(config, "vocab_file"):
        raise AttributeError(
            "'vocab_file' argument required, but not "
            "present in AnswerProcessor's config"
        )

    self.answer_vocab = VocabDict(config.vocab_file, *args, **kwargs)
    self.PAD_IDX = self.answer_vocab.word2idx("<pad>")
    self.BOS_IDX = self.answer_vocab.word2idx("<s>")
    self.EOS_IDX = self.answer_vocab.word2idx("</s>")
    self.UNK_IDX = self.answer_vocab.UNK_INDEX

    # Set EOS to something not achievable if it is not there
    if self.EOS_IDX == self.UNK_IDX:
        self.EOS_IDX = len(self.answer_vocab)

    self.preprocessor = None
    if hasattr(config, "preprocessor"):
        self.preprocessor = Processor(config.preprocessor)

        # Only validate when a preprocessor was actually requested;
        # a config without one is valid.
        if self.preprocessor is None:
            raise ValueError(
                f"No processor named {config.preprocessor} is defined."
            )

    if hasattr(config, "num_answers"):
        self.num_answers = config.num_answers
    else:
        self.num_answers = self.DEFAULT_NUM_ANSWERS
        warnings.warn(
            "'num_answers' not defined in the config. "
            "Setting to default of {}".format(self.DEFAULT_NUM_ANSWERS)
        )
def __init__(self, config, *args, **kwargs):
    self.writer = registry.get("writer")
    if not hasattr(config, "vocab_file"):
        raise AttributeError(
            "'vocab_file' argument required, but not "
            "present in AnswerProcessor's config"
        )

    self.answer_vocab = VocabDict(config.vocab_file, *args, **kwargs)

    self.preprocessor = None
    if hasattr(config, "preprocessor"):
        self.preprocessor = Processor(config.preprocessor)

        # Only validate when a preprocessor was actually requested;
        # a config without one is valid.
        if self.preprocessor is None:
            raise ValueError(
                f"No processor named {config.preprocessor} is defined."
            )

    if hasattr(config, "num_answers"):
        self.num_answers = config.num_answers
    else:
        self.num_answers = self.DEFAULT_NUM_ANSWERS
        warnings.warn(
            "'num_answers' not defined in the config. "
            "Setting to default of {}".format(self.DEFAULT_NUM_ANSWERS)
        )
def __init__(self, config):
    super().__init__(config)
    self.config = config
    self._global_config = registry.get("config")
    self._datasets = self._global_config.datasets.split(",")
def save(self, update, iteration=None, update_best=False):
    # Only save in main process
    if not is_master():
        return

    if not iteration:
        iteration = update

    ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)
    best_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
    )
    current_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "current.ckpt"
    )

    best_iteration = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_iteration
    )
    best_update = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_update
    )
    best_metric = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_value
    )
    model = self.trainer.model
    data_parallel = registry.get("data_parallel") or registry.get("distributed")

    if data_parallel is True:
        model = model.module

    ckpt = {
        "model": model.state_dict(),
        "optimizer": self.trainer.optimizer.state_dict(),
        "best_iteration": best_iteration,
        "current_iteration": iteration,
        "current_epoch": self.trainer.current_epoch,
        "num_updates": update,
        "best_update": best_update,
        "best_metric_value": best_metric,
        # Convert to container to avoid any dependencies
        "config": OmegaConf.to_container(self.config, resolve=True),
    }

    lr_scheduler = self.trainer.lr_scheduler_callback._scheduler

    if lr_scheduler is not None:
        ckpt["lr_scheduler"] = lr_scheduler.state_dict()

    if self.git_repo:
        git_metadata_dict = self._get_vcs_fields()
        ckpt.update(git_metadata_dict)

    with PathManager.open(ckpt_filepath, "wb") as f:
        torch.save(ckpt, f)

    if update_best:
        with PathManager.open(best_ckpt_filepath, "wb") as f:
            torch.save(ckpt, f)

    # Save current always
    with PathManager.open(current_ckpt_filepath, "wb") as f:
        torch.save(ckpt, f)

    # Remove old checkpoints if max_to_keep is set
    if self.max_to_keep > 0:
        if len(self.saved_iterations) == self.max_to_keep:
            self.remove(self.saved_iterations.pop(0))
        self.saved_iterations.append(update)
def _init_classifier(self):
    num_hidden = self.config.text_embedding.num_hidden
    num_choices = registry.get(self._datasets[0] + "_num_final_outputs")
    dropout = self.config.classifier.dropout
    self.classifier = WeightNormClassifier(
        num_hidden, num_choices, num_hidden * 2, dropout
    )
def __init__(self, loss_list):
    super().__init__()
    self.losses = []
    self._evaluation_predict = registry.get("config").evaluation.predict

    for loss in loss_list:
        self.losses.append(MMFLoss(loss))
def is_xla():
    return registry.get("is_xla", no_warning=True)
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the MMF logger and set its verbosity level to "INFO".
    Outside libraries shouldn't call this in case they have set their own
    logging handlers and setup. If they do, and don't want to clear the
    handlers, pass the clear_handlers option.

    The initial version of this function was taken from D2 and adapted
    for MMF.

    Args:
        output (str): a file name or a directory to save log.
            If ends with ".txt" or ".log", assumed to be a file name.
            Default: Saved to file <save_dir/logs/log_[timestamp].txt>
        color (bool): If false, won't log colored logs. Default: true
        name (str): the root module name of this logger. Defaults to "mmf".
        clear_handlers (bool): If false, won't clear existing handlers.

    Returns:
        logging.Logger: a logger
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    logger.propagate = False

    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    logging_level = registry.get("config").training.logger_level.upper()

    if distributed_rank == 0:
        logger.setLevel(logging_level)
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging_level)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # file logging: all workers
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        PathManager.mkdirs(os.path.dirname(filename))

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging_level)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process
        if "train.log" not in filename and distributed_rank == 0:
            save_dir = get_mmf_env(key="save_dir")
            filename = os.path.join(save_dir, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(filename))
            sh.setLevel(logging_level)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add MMF specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers.
    logging.basicConfig(level=logging_level, handlers=handlers)

    registry.register("writer", logger)

    return logger
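# Hedged usage sketch for setup_logger. It assumes a config exposing
# training.logger_level has been registered beforehand, since the function
# reads it from the registry; the "./logs" directory is illustrative.
def _example_setup_logger():
    from omegaconf import OmegaConf

    from mmf.common.registry import registry

    registry.register(
        "config", OmegaConf.create({"training": {"logger_level": "info"}})
    )
    # Logs to stdout on rank 0 and to ./logs/train.log; pass a path ending
    # in ".txt" or ".log" to log to that exact file instead.
    return setup_logger(output="./logs", color=False)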
def import_user_module(user_dir: str):
    """Given a user dir, this function imports it as a module.

    This user_module is expected to have an __init__.py at its root.
    You can use import_files to import your python files easily in
    __init__.py

    Args:
        user_dir (str): directory which has to be imported
    """
    from mmf.common.registry import registry
    from mmf.utils.general import get_absolute_path  # noqa

    logger = logging.getLogger(__name__)
    if user_dir:
        if registry.get("__mmf_user_dir_imported__", no_warning=True):
            logger.info(f"User dir {user_dir} already imported. Skipping.")
            return

        # Allow loading of files as user source
        if user_dir.endswith(".py"):
            user_dir = user_dir[:-3]

        dot_path = ".".join(user_dir.split(os.path.sep))
        # In case of an abspath that starts with "/", the first char will
        # be ".", which turns it into a relative module that find_spec
        # doesn't like
        if os.path.isabs(user_dir):
            dot_path = dot_path[1:]

        try:
            dot_spec = importlib.util.find_spec(dot_path)
        except ModuleNotFoundError:
            dot_spec = None
        abs_user_dir = get_absolute_path(user_dir)
        module_parent, module_name = os.path.split(abs_user_dir)

        # If the dot path is found in sys.modules, or the path can be
        # directly imported, we don't need to play jugglery with the
        # actual path
        if dot_path in sys.modules or dot_spec is not None:
            module_name = dot_path
        else:
            user_dir = abs_user_dir

        logger.info(f"Importing from {user_dir}")
        if module_name != dot_path:
            # Since the dot path hasn't been found or can't be imported,
            # we can try importing the module by changing sys.path to the
            # parent
            sys.path.insert(0, module_parent)

        importlib.import_module(module_name)
        sys.modules["mmf_user_dir"] = sys.modules[module_name]

        # Register config for user's model and dataset config
        # relative path resolution
        config = registry.get("config")
        if config is None:
            registry.register(
                "config", OmegaConf.create({"env": {"user_dir": user_dir}})
            )
        else:
            with open_dict(config):
                config.env.user_dir = user_dir

        registry.register("__mmf_user_dir_imported__", True)
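# Hedged usage sketch for import_user_module. "./my_project" is a
# hypothetical directory containing an __init__.py that imports your
# custom models, datasets, or processors.
def _example_import_user_module():
    import_user_module("./my_project")
    # After this call, anything registered inside my_project/__init__.py
    # (e.g. via registry.register_model) is visible to MMF, and
    # config.env.user_dir points at the directory.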
def save(self, update, iteration=None, update_best=False):
    # Only save in the main process. For XLA we use the xm.save method,
    # which ensures that actual checkpoint saving happens only on the
    # master node and also takes care of all necessary synchronization.
    if not is_master() and not is_xla():
        return

    logger.info("Checkpoint save operation started!")
    if not iteration:
        iteration = update

    ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)
    best_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
    )
    current_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "current.ckpt"
    )

    best_iteration = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_iteration
    )
    best_update = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_update
    )
    best_metric = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_value
    )

    model = self.trainer.model
    data_parallel = registry.get("data_parallel") or registry.get("distributed")
    fp16_scaler = getattr(self.trainer, "scaler", None)
    fp16_scaler_dict = None

    if fp16_scaler is not None:
        fp16_scaler_dict = fp16_scaler.state_dict()

    if data_parallel is True:
        model = model.module

    ckpt = {
        "model": model.state_dict(),
        "optimizer": self.trainer.optimizer.state_dict(),
        "best_iteration": best_iteration,
        "current_iteration": iteration,
        "current_epoch": self.trainer.current_epoch,
        "num_updates": update,
        "best_update": best_update,
        "best_metric_value": best_metric,
        "fp16_scaler": fp16_scaler_dict,
        # Convert to container to avoid any dependencies
        "config": OmegaConf.to_container(self.config, resolve=True),
    }

    lr_scheduler = self.trainer.lr_scheduler_callback._scheduler

    if lr_scheduler is not None:
        ckpt["lr_scheduler"] = lr_scheduler.state_dict()

    if self.git_repo:
        git_metadata_dict = self._get_vcs_fields()
        ckpt.update(git_metadata_dict)

    with PathManager.open(ckpt_filepath, "wb") as f:
        self.save_func(ckpt, f)

    if update_best:
        logger.info("Saving best checkpoint")
        with PathManager.open(best_ckpt_filepath, "wb") as f:
            self.save_func(ckpt, f)

    # Save current always
    logger.info("Saving current checkpoint")
    with PathManager.open(current_ckpt_filepath, "wb") as f:
        self.save_func(ckpt, f)

    # Remove old checkpoints if max_to_keep is set
    if self.max_to_keep > 0:
        if len(self.saved_iterations) == self.max_to_keep:
            self.remove(self.saved_iterations.pop(0))
        self.saved_iterations.append(update)

    logger.info("Checkpoint save operation finished!")
def __init__(self, config):
    super().__init__(config)
    self.mmt_config = BertConfig(**self.config.mmt)
    self._datasets = registry.get("config").datasets.split(",")
def __init__(
    self,
    num_train_data,
    max_updates,
    max_epochs,
    config=None,
    optimizer=None,
    update_frequency=1,
    batch_size=1,
    batch_size_per_device=None,
    fp16=False,
    on_update_end_fn=None,
    scheduler_config=None,
    grad_clipping_config=None,
):
    if config is None:
        self.config = OmegaConf.create(
            {
                "training": {
                    "detect_anomaly": False,
                    "evaluation_interval": 10000,
                    "update_frequency": update_frequency,
                    "fp16": fp16,
                    "batch_size": batch_size,
                    "batch_size_per_device": batch_size_per_device,
                }
            }
        )
        self.training_config = self.config.training
    else:
        self.training_config = config.training
        self.config = config

    # Load batch size with custom config and cleanup original
    original_config = registry.get("config")
    registry.register("config", self.config)
    batch_size = get_batch_size()
    registry.register("config", original_config)

    if max_updates is not None:
        self.training_config["max_updates"] = max_updates
    if max_epochs is not None:
        self.training_config["max_epochs"] = max_epochs

    self.model = SimpleModel({"in_dim": 1})
    self.model.build()
    if torch.cuda.is_available():
        self.model = self.model.cuda()
        self.device = "cuda"
    else:
        self.device = "cpu"

    self.distributed = False

    self.dataset_loader = MagicMock()
    self.dataset_loader.seed_sampler = MagicMock(return_value=None)
    self.dataset_loader.prepare_batch = lambda x: SampleList(x)

    if optimizer is None:
        self.optimizer = MagicMock()
        self.optimizer.step = MagicMock(return_value=None)
        self.optimizer.zero_grad = MagicMock(return_value=None)
    else:
        self.optimizer = optimizer

    if scheduler_config:
        config.training.lr_scheduler = True
        config.scheduler = scheduler_config
        self.lr_scheduler_callback = LRSchedulerCallback(config, self)
        self.callbacks.append(self.lr_scheduler_callback)
        on_update_end_fn = (
            on_update_end_fn
            if on_update_end_fn
            else self.lr_scheduler_callback.on_update_end
        )

    if grad_clipping_config:
        self.training_config.clip_gradients = True
        self.training_config.max_grad_l2_norm = grad_clipping_config[
            "max_grad_l2_norm"
        ]
        self.training_config.clip_norm_mode = grad_clipping_config["clip_norm_mode"]

    dataset = NumbersDataset(num_train_data)
    self.train_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
        drop_last=False,
    )
    self.train_loader.current_dataset = dataset
    self.on_batch_start = MagicMock(return_value=None)
    self.on_update_start = MagicMock(return_value=None)
    self.logistics_callback = MagicMock(return_value=None)
    self.logistics_callback.log_interval = MagicMock(return_value=None)
    self.on_batch_end = MagicMock(return_value=None)
    self.on_update_end = (
        on_update_end_fn if on_update_end_fn else MagicMock(return_value=None)
    )
    self.meter = Meter()
    self.after_training_loop = MagicMock(return_value=None)
    self.on_validation_start = MagicMock(return_value=None)
    self.evaluation_loop = MagicMock(return_value=(None, None))
    self.scaler = torch.cuda.amp.GradScaler(enabled=False)
    self.val_loader = MagicMock(return_value=None)
    self.early_stop_callback = MagicMock(return_value=None)
    self.on_validation_end = MagicMock(return_value=None)
    self.metrics = MagicMock(return_value=None)
def _build_word_embedding(self):
    assert len(self._datasets) > 0
    text_processor = registry.get(self._datasets[0] + "_text_processor")
    vocab = text_processor.vocab
    self.word_embedding = vocab.get_embedding(
        torch.nn.Embedding, embedding_dim=300
    )
def _load(self, file, force=False, load_pretrained=False):
    tp = self.config.training
    self.trainer.writer.write("Loading checkpoint")
    ckpt = self._torch_load(file)

    data_parallel = registry.get("data_parallel") or registry.get("distributed")

    if "model" in ckpt:
        ckpt_model = ckpt["model"]
    else:
        ckpt_model = ckpt
        ckpt = {"model": ckpt}

    pretrained_mapping = tp.pretrained_mapping

    if load_pretrained is False or force is True:
        pretrained_mapping = {}

    new_dict = {}

    # TODO: Move to separate function
    for attr in ckpt_model:
        new_attr = attr
        if "fa_history" in attr:
            new_attr = new_attr.replace("fa_history", "fa_context")

        if data_parallel is False and attr.startswith("module."):
            # In case the ckpt was actually a data parallel model,
            # replace the first "module." from dataparallel with an
            # empty string
            new_dict[new_attr.replace("module.", "", 1)] = ckpt_model[attr]
        elif data_parallel is not False and not attr.startswith("module."):
            new_dict["module." + new_attr] = ckpt_model[attr]
        else:
            new_dict[new_attr] = ckpt_model[attr]

    if len(pretrained_mapping) == 0:
        final_dict = new_dict

        self.trainer.model.load_state_dict(final_dict, strict=False)

        if "optimizer" in ckpt:
            self.trainer.optimizer.load_state_dict(ckpt["optimizer"])
        else:
            warnings.warn(
                "'optimizer' key is not present in the "
                "checkpoint asked to be loaded. Skipping."
            )

        self.trainer.early_stopping.init_from_checkpoint(ckpt)

        self.trainer.writer.write("Checkpoint loaded")

        if "best_update" in ckpt:
            if tp.resume_best:
                self.trainer.num_updates = ckpt.get(
                    "best_update", self.trainer.num_updates
                )
                self.trainer.current_iteration = ckpt.get(
                    "best_iteration", self.trainer.current_iteration
                )
            else:
                self.trainer.num_updates = ckpt.get(
                    "num_updates", self.trainer.num_updates
                )
                self.trainer.current_iteration = ckpt.get(
                    "current_iteration", self.trainer.current_iteration
                )
            self.trainer.current_epoch = ckpt.get(
                "current_epoch", self.trainer.current_epoch
            )
        elif "best_iteration" in ckpt:
            # Preserve old behavior for old checkpoints where we always
            # load best iteration
            if tp.resume_best and "current_iteration" in ckpt:
                self.trainer.current_iteration = ckpt["current_iteration"]
            else:
                self.trainer.current_iteration = ckpt.get(
                    "best_iteration", self.trainer.current_iteration
                )
            self.trainer.num_updates = self.trainer.current_iteration

        registry.register("current_iteration", self.trainer.current_iteration)
        registry.register("num_updates", self.trainer.num_updates)

        self.trainer.current_epoch = ckpt.get(
            "best_epoch", self.trainer.current_epoch
        )
        registry.register("current_epoch", self.trainer.current_epoch)
    else:
        final_dict = {}
        model = self.trainer.model
        own_state = model.state_dict()

        for key, value in pretrained_mapping.items():
            key += "."
            value += "."
            for attr in new_dict:
                for own_attr in own_state:
                    formatted_attr = model.format_state_key(attr)
                    if (
                        key in formatted_attr
                        and value in own_attr
                        and formatted_attr.replace(key, "")
                        == own_attr.replace(value, "")
                    ):
                        self.trainer.writer.write(
                            "Copying " + attr + " " + own_attr
                        )
                        own_state[own_attr].copy_(new_dict[attr])
        self.trainer.writer.write("Pretrained model loaded")