Example #1
import torch
from torch.nn import DataParallel

# `find_latest_checkpoint`, `device_mapping` and `move_optimizer_to_cuda` are assumed to
# be AllenNLP/project utilities imported elsewhere in the original module.
def restore_checkpoint(model, optimizer, serialization_dir, learning_rate_scheduler=None):
    """
    Restores a model from a serialization_dir to the last saved checkpoint.
    This includes an epoch count and optimizer state, which is serialized separately
    from model parameters. This function should only be used to continue training -
    if you wish to load a model for inference/load parts of a model into a new
    computation graph, you should use the native Pytorch functions:
    ``model.load_state_dict(torch.load("/path/to/model/weights.th"))``
    If ``serialization_dir`` does not exist or does not contain any checkpointed weights,
    this function will do nothing and return ``(0, [])``.
    Returns
    -------
    epoch: int
        The epoch at which to resume training, which should be one after the epoch
        in the saved training state.
    val_metric_per_epoch: List[float]
        The per-epoch validation metrics stored in the training state, or an empty
        list if the checkpoint predates their being saved.
    """
    latest_checkpoint = find_latest_checkpoint(serialization_dir)

    if latest_checkpoint is None:
        # No checkpoint to restore, start at 0
        return 0, []

    model_path, training_state_path = latest_checkpoint

    # Load the parameters onto CPU, then transfer to GPU.
    # This avoids potential OOM on GPU for large models that
    # load parameters onto GPU then make a new GPU copy into the parameter
    # buffer. The GPU transfer happens implicitly in load_state_dict.
    model_state = torch.load(model_path, map_location=device_mapping(-1))
    training_state = torch.load(training_state_path, map_location=device_mapping(-1))
    if isinstance(model, DataParallel):
        model.module.load_state_dict(model_state)
    else:
        model.load_state_dict(model_state)

    # Restore the optimizer state (e.g. momentum buffers) saved with the checkpoint.
    optimizer.load_state_dict(training_state["optimizer"])

    if learning_rate_scheduler is not None and "learning_rate_scheduler" in training_state:
        learning_rate_scheduler.lr_scheduler.load_state_dict(
            training_state["learning_rate_scheduler"])
    # The optimizer state was loaded onto the CPU above; move its tensors back to the GPU.
    move_optimizer_to_cuda(optimizer)

    # We didn't always save `val_metric_per_epoch`, so we can't assume
    # that it's part of the trainer state. If it's not there, an empty list is all
    # we can do.
    if "val_metric_per_epoch" not in training_state:
        print("trainer state `val_metric_per_epoch` not found, using empty list")
        val_metric_per_epoch = []
    else:
        val_metric_per_epoch = training_state["val_metric_per_epoch"]

    # `epoch` may be stored as a dotted string (e.g. "3.150") for checkpoints saved
    # mid-epoch; only the integer epoch part matters when deciding where to resume.
    if isinstance(training_state["epoch"], int):
        epoch_to_return = training_state["epoch"] + 1
    else:
        epoch_to_return = int(training_state["epoch"].split('.')[0]) + 1
    return epoch_to_return, val_metric_per_epoch
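
A minimal sketch (not part of the original example) of how restore_checkpoint would typically be called to resume training; model, optimizer, num_epochs, train_one_epoch and validate are hypothetical placeholders standing in for the caller's own objects.

epoch_to_start, val_metric_per_epoch = restore_checkpoint(
    model, optimizer, "/path/to/serialization_dir")

for epoch in range(epoch_to_start, num_epochs):
    train_one_epoch(model, optimizer)                 # hypothetical training step
    val_metric_per_epoch.append(validate(model))      # hypothetical validation step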
# This usage example assumes the AllenNLP 0.x imports below, the project-specific
# `TestReader`, `ModelOptions`, `Text_Embedding` and `CitationRanker` classes, and an
# initial `reader` constructed earlier in the original script.
import torch
from allennlp.data import Vocabulary
from allennlp.data.iterators import BasicIterator
from allennlp.nn.util import move_optimizer_to_cuda
from allennlp.training.trainer import Trainer

dataset = reader.read("")
vocab = Vocabulary.from_instances(dataset)

print(vocab.get_vocab_size())

opts = ModelOptions()

reader = TestReader(vocab)
reader.set_compute_nnrank_features(True)
dataset = reader.read("")
text_embedder = Text_Embedding(opts, vocab)

nnrank = CitationRanker(vocab, opts, text_embedder)

# BasicIterator creates fixed-size batches; index_with(vocab) tells it which vocabulary
# to use when converting instances to tensors.
iterator = BasicIterator()
iterator.index_with(vocab)

# Plain SGD over the ranker's parameters; move_optimizer_to_cuda moves any existing
# optimizer state tensors onto the GPU so they match the model on cuda_device=0 below.
optimizer = torch.optim.SGD(nnrank.parameters(), lr=0.1)
move_optimizer_to_cuda(optimizer)

# Train for 10 epochs on GPU 0, validating on the same dataset; with patience=1000 the
# early-stopping criterion can never fire in such a short run.
trainer = Trainer(model=nnrank,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=dataset,
                  validation_dataset=dataset,
                  patience=1000,
                  num_epochs=10,
                  summary_interval=2,
                  cuda_device=0)

trainer.train()
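
Note that this Trainer is not given a serialization_dir, so (under the usual AllenNLP 0.x behaviour) it will not write checkpoints; passing one would allow a later run to resume from the last saved epoch via restore_checkpoint from the first example.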