Example #1
def load_archive(archive_file: str,
                 device=None,
                 weights_file: str = None) -> Archive:
    """
    Instantiates an Archive from an archived `tar.gz` file.
    Parameters
    ----------
    archive_file: ``str``
        The archive file to load the model from.
    weights_file: ``str``, optional (default = None)
        The weights file to use.  If unspecified, weights.th in the archive_file will be used.
    device: ``None`` or PyTorch device object.
    """
    # redirect to the cache, if necessary
    resolved_archive_file = cached_path(archive_file)

    if resolved_archive_file == archive_file:
        logger.info(f"loading archive file {archive_file}")
    else:
        logger.info(
            f"loading archive file {archive_file} from cache at {resolved_archive_file}"
        )

    tempdir = None
    if os.path.isdir(resolved_archive_file):
        serialization_dir = resolved_archive_file
    else:
        # Extract archive to temp dir
        tempdir = tempfile.mkdtemp()
        logger.info(
            f"extracting archive file {resolved_archive_file} to temp dir {tempdir}"
        )
        with tarfile.open(resolved_archive_file, 'r:gz') as archive:
            archive.extractall(tempdir)

        serialization_dir = tempdir

    # Load config
    config = Params.from_file(os.path.join(serialization_dir, CONFIG_NAME))
    config.loading_from_archive = True

    if weights_file:
        weights_path = weights_file
    else:
        weights_path = os.path.join(serialization_dir, _WEIGHTS_NAME)

    # Instantiate model. Use a duplicate of the config, as it will get consumed.
    model = Model.load(config.duplicate(),
                       weights_file=weights_path,
                       serialization_dir=serialization_dir,
                       device=device)

    if tempdir:
        # Clean up temp dir
        shutil.rmtree(tempdir)

    return Archive(model=model, config=config)
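
A minimal usage sketch for the function above; the archive path is a placeholder, not a file that ships with this code:

# Hypothetical usage: load a trained archive onto the CPU and switch to eval mode.
import torch

archive = load_archive("model.tar.gz", device=torch.device("cpu"))
model = archive.model      # the instantiated Model
config = archive.config    # the Params the model was trained with
model.eval()
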
Example #2
def create_serialization_dir(params: Params) -> None:
    """
    This function creates the serialization directory if it doesn't exist.  If it already exists
    and is non-empty, then it verifies that we're recovering from a training with an identical configuration.
    Parameters
    ----------
    params: ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir: ``str``
        The directory in which to save results and logs.
    recover: ``bool``
        If ``True``, we will try to recover from an existing serialization directory, and crash if
        the directory doesn't exist, or doesn't match the configuration we're given.
    """
    serialization_dir = params['environment']['serialization_dir']
    recover = params['environment']['recover']
    if os.path.exists(serialization_dir) and os.listdir(serialization_dir):
        if not recover:
            raise ConfigurationError(
                f"Serialization directory ({serialization_dir}) already exists and is "
                f"not empty. Specify --recover to recover training from existing output."
            )

        logger.info(f"Recovering from prior training at {serialization_dir}.")

        recovered_config_file = os.path.join(serialization_dir, CONFIG_NAME)
        if not os.path.exists(recovered_config_file):
            raise ConfigurationError(
                "The serialization directory already exists but doesn't "
                "contain a config.json. You probably gave the wrong directory."
            )
        else:
            loaded_params = Params.from_file(recovered_config_file)

            if params != loaded_params:
                raise ConfigurationError(
                    "Training configuration does not match the configuration we're "
                    "recovering from.")

            # In the recover mode, we don't need to reload the pre-trained embeddings.
            remove_pretrained_embedding_params(params)
    else:
        if recover:
            raise ConfigurationError(
                f"--recover specified but serialization_dir ({serialization_dir}) "
                "does not exist.  There is nothing to recover from.")
        os.makedirs(serialization_dir, exist_ok=True)
        params.to_file(os.path.join(serialization_dir, CONFIG_NAME))
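
A minimal sketch of how this function might be driven; the dict layout below only reflects the keys the function actually reads (``environment.serialization_dir`` and ``environment.recover``) and is otherwise a placeholder, not a documented schema:

# Hypothetical driver: build a Params object carrying the keys that
# create_serialization_dir reads, then call it before training starts.
params = Params({
    "environment": {
        "serialization_dir": "/tmp/my_experiment",  # placeholder path
        "recover": False,                           # start a fresh run
    },
    # ... model, data, and trainer settings would follow here ...
})
create_serialization_dir(params)  # creates the directory and writes config.json
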
Example #3
        # if we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            logger.info("Training interrupted by the user. Attempting to create "
                         "a model archive using the current best epoch weights.")
            archive_model(serialization_dir)
        raise

    # Now tar up results
    archive_model(serialization_dir)

    logger.info("Loading the best epoch weights.")
    best_model_state_path = os.path.join(serialization_dir, 'best.th')
    best_model_state = torch.load(best_model_state_path)
    best_model = model
    if not isinstance(best_model, torch.nn.DataParallel):
        # Checkpoints saved from a DataParallel wrapper prefix every key with
        # "module."; strip that prefix before loading into the unwrapped model.
        best_model_state = {re.sub(r'^module\.', '', k): v
                            for k, v in best_model_state.items()}
    best_model.load_state_dict(best_model_state)

    return best_model


if __name__ == "__main__":
    parser = argparse.ArgumentParser('train.py')
    parser.add_argument('params', help='Parameters YAML file.')
    args = parser.parse_args()

    params = Params.from_file(args.params)
    logger.info(params)

    train_model(params)
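
The ``module.``-prefix handling when loading the best epoch weights can be illustrated in isolation.  A standalone sketch, where the tiny ``Linear`` model is a stand-in for the real ``Model`` class above:

import re
import torch

plain_model = torch.nn.Linear(4, 2)

# Simulate a checkpoint saved from a DataParallel wrapper: every parameter
# key gains a leading "module." prefix (e.g. "module.weight").
wrapped = torch.nn.DataParallel(torch.nn.Linear(4, 2))
state = wrapped.state_dict()

# Strip the prefix before loading into the unwrapped model, mirroring the
# recovery logic in the training code above.
state = {re.sub(r'^module\.', '', k): v for k, v in state.items()}
plain_model.load_state_dict(state)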