def load_archive(archive_file: str, device=None, weights_file: str = None) -> Archive:
    """
    Instantiates an Archive from an archived `tar.gz` file.

    The path is first resolved through ``cached_path``.  If the resolved path
    is a directory it is used directly as the serialization directory;
    otherwise it is opened as a gzipped tar file and extracted into a
    temporary directory, which is removed again before this function returns
    or raises.

    Parameters
    ----------
    archive_file: ``str``
        The archive file (or already-extracted directory) to load the model from.
    device: ``None`` or PyTorch device object.
        Passed through unchanged to ``Model.load``.
    weights_file: ``str``, optional (default = None)
        The weights file to use.  If unspecified (or otherwise falsy), the
        ``_WEIGHTS_NAME`` file inside the serialization directory is used.

    Returns
    -------
    ``Archive``
        An archive holding the loaded model and the config it was loaded with.
    """
    # redirect to the cache, if necessary
    resolved_archive_file = cached_path(archive_file)

    if resolved_archive_file == archive_file:
        logger.info(f"loading archive file {archive_file}")
    else:
        logger.info(
            f"loading archive file {archive_file} from cache at {resolved_archive_file}"
        )

    tempdir = None
    try:
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info(
                f"extracting archive file {resolved_archive_file} to temp dir {tempdir}"
            )
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                # NOTE(security): extractall() trusts the member paths stored in
                # the archive.  Only load archives from sources you trust.
                archive.extractall(tempdir)
            serialization_dir = tempdir

        # Load config
        config = Params.from_file(os.path.join(serialization_dir, CONFIG_NAME))
        config.loading_from_archive = True

        if weights_file:
            weights_path = weights_file
        else:
            weights_path = os.path.join(serialization_dir, _WEIGHTS_NAME)

        # Instantiate model.
        # NOTE(review): an earlier comment here said a duplicate of the config
        # is used because loading consumes it, but ``config`` itself is passed
        # and later returned in the Archive -- confirm whether Model.load
        # mutates it.
        model = Model.load(config,
                           weights_file=weights_path,
                           serialization_dir=serialization_dir,
                           device=device)
    finally:
        # Clean up the temp dir even when extraction or loading fails, so a
        # broken archive does not leave its extracted contents behind.
        if tempdir:
            shutil.rmtree(tempdir)

    return Archive(model=model, config=config)
def create_serialization_dir(params: Params) -> None:
    """
    Prepare the serialization directory named in ``params`` for a training run.

    Two values are read from ``params['environment']``: ``serialization_dir``
    (the directory in which to save results and logs) and ``recover``
    (whether we are resuming from an existing directory).

    * If the directory is missing or empty, it is created and the
      configuration is written to ``CONFIG_NAME`` inside it.  Asking to
      recover in this situation is an error.
    * If the directory exists and is non-empty, ``recover`` must be set and
      the stored configuration must equal ``params``; the pre-trained
      embedding parameters are then removed from ``params``.

    Parameters
    ----------
    params: ``Params``
        A parameter object specifying an AllenNLP Experiment.

    Raises
    ------
    ConfigurationError
        If the directory state and the ``recover`` flag disagree, if the
        existing directory has no stored config, or if the stored config
        differs from ``params``.
    """
    serialization_dir = params['environment']['serialization_dir']
    recover = params['environment']['recover']

    has_previous_output = (os.path.exists(serialization_dir)
                           and os.listdir(serialization_dir))

    if not has_previous_output:
        # Fresh run: nothing on disk yet, so there is nothing to resume.
        if recover:
            raise ConfigurationError(f"--recover specified but serialization_dir ({serialization_dir}) "
                                     "does not exist. There is nothing to recover from.")
        os.makedirs(serialization_dir, exist_ok=True)
        params.to_file(os.path.join(serialization_dir, CONFIG_NAME))
        return

    # From here on the directory already holds output from an earlier run.
    if not recover:
        raise ConfigurationError(f"Serialization directory ({serialization_dir}) already exists and is "
                                 f"not empty. Specify --recover to recover training from existing output.")

    logger.info(f"Recovering from prior training at {serialization_dir}.")

    recovered_config_file = os.path.join(serialization_dir, CONFIG_NAME)
    if not os.path.exists(recovered_config_file):
        raise ConfigurationError("The serialization directory already exists but doesn't "
                                 "contain a config.json. You probably gave the wrong directory.")

    stored_params = Params.from_file(recovered_config_file)
    if params != stored_params:
        raise ConfigurationError("Training configuration does not match the configuration we're "
                                 "recovering from.")

    # In the recover mode, we don't need to reload the pre-trained embeddings.
    remove_pretrained_embedding_params(params)
# if we have completed an epoch, try to create a model archive. if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)): logger.info("Training interrupted by the user. Attempting to create " "a model archive using the current best epoch weights.") archive_model(serialization_dir) raise # Now tar up results archive_model(serialization_dir) logger.info("Loading the best epoch weights.") best_model_state_path = os.path.join(serialization_dir, 'best.th') best_model_state = torch.load(best_model_state_path) best_model = model if not isinstance(best_model, torch.nn.DataParallel): best_model_state = {re.sub(r'^module\.', '', k):v for k, v in best_model_state.items()} best_model.load_state_dict(best_model_state) return best_model if __name__ == "__main__": parser = argparse.ArgumentParser('train.py') parser.add_argument('params', help='Parameters YAML file.') args = parser.parse_args() params = Params.from_file(args.params) logger.info(params) train_model(params)