Example 1
import os

def load_averaged_checkpoint_warm_start(checkpoint_path, model, optimizer, scheduler):
    # checkpoint_path is a directory of checkpoints in this function;
    # last_n_checkpoints, average_checkpoints and args come from the surrounding module.
    assert os.path.isdir(checkpoint_path)
    list_checkpoints = last_n_checkpoints(checkpoint_path, args.average_checkpoint)
    iteration = 0  # warm start: the iteration counter is reset
    model_for_loading = average_checkpoints(list_checkpoints, args.epsilon)['model']
    model_dict = model.state_dict()
    # Keep only the averaged parameters whose names and shapes match the current model.
    pretrained_dict = {k: v for k, v in model_for_loading.items() if
                       (k in model_dict) and (model_dict[k].shape == model_for_loading[k].shape)}
    missing_and_unexpected_keys = model.load_state_dict(pretrained_dict, strict=False)
    print("WARNING: only part of the model was loaded. The missing and unexpected keys "
          "are listed below; make sure they are expected:")
    print(missing_and_unexpected_keys)
    print("Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration))
    return model, optimizer, scheduler, iteration
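Examples 1 and 2 (and Example 3 below) assume a last_n_checkpoints helper that returns the most recent n checkpoint files in a directory, plus a module-level args namespace; neither is shown in these snippets. A minimal sketch of such a helper, assuming checkpoints are named checkpoint_<iteration> so they can be sorted numerically (the naming convention and the optional upper_bound are assumptions):

import os
import re

def last_n_checkpoints(checkpoint_dir, n, upper_bound=None):
    # Return the paths of the n newest files named checkpoint_<iteration>,
    # optionally skipping iterations above upper_bound.
    pattern = re.compile(r'checkpoint_(\d+)$')
    entries = []
    for fname in os.listdir(checkpoint_dir):
        match = pattern.match(fname)
        if match is None:
            continue
        iteration = int(match.group(1))
        if upper_bound is None or iteration <= upper_bound:
            entries.append((iteration, os.path.join(checkpoint_dir, fname)))
    entries.sort(reverse=True)  # newest first
    return [path for _, path in entries[:n]]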
Example 2
import os
import torch

def load_averaged_checkpoint(checkpoint_path, model, num_avg_ckpt):
    # checkpoint_path is a directory of checkpoints in this function;
    # last_n_checkpoints, average_checkpoints and args come from the surrounding module.
    assert os.path.isdir(checkpoint_path)
    list_checkpoints = last_n_checkpoints(checkpoint_path, num_avg_ckpt)
    # Record the iteration stored in the first returned checkpoint (the most recent one).
    iteration = torch.load(list_checkpoints[0], map_location='cpu')['iteration']
    model_for_loading = average_checkpoints(list_checkpoints, args.epsilon)['model']
    try:
        model.load_state_dict(model_for_loading)
    except RuntimeError:
        # Checkpoints saved from a DataParallel model prefix every key with "module.".
        print("DataParallel weights detected. Loading...")
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in model_for_loading.items():
            name = k.replace("module.", "")  # strip the `module.` prefix
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    print("Loaded averaged checkpoint from '{}' (last iteration {})".format(
          checkpoint_path, iteration))
    return model, iteration
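A hypothetical call, resuming from the average of the five most recent checkpoints (the directory name is illustrative):

model, iteration = load_averaged_checkpoint('outdir/checkpoints', model, num_avg_ckpt=5)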
Example 3
import torch

def avg_ckpts(path, n, output, upper_bound=None):
    # Average the last n checkpoints found under `path` and save the result to `output`.
    inputs = last_n_checkpoints(path, n, upper_bound)
    new_state = average_checkpoints(inputs)
    torch.save(new_state, output)
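All four examples depend on an average_checkpoints helper (fairseq ships one in scripts/average_checkpoints.py; Examples 1 and 2 pass an extra epsilon argument from a project-specific variant, and Example 4 appears to use a variant that returns the averaged state dict directly rather than a checkpoint dict). A minimal sketch of plain parameter averaging, following the {'model': ...} convention used in Examples 1-3 and without the epsilon term:

import torch

def average_checkpoints(inputs):
    # Element-wise average of the 'model' state dicts stored in the given checkpoint files.
    avg_state = None
    for path in inputs:
        state = torch.load(path, map_location='cpu')['model']
        if avg_state is None:
            avg_state = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg_state[k] += v.float()
    avg_state = {k: v / len(inputs) for k, v in avg_state.items()}
    return {'model': avg_state}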
Example 4
# First save: the leading arguments are assumed to mirror the save_model call
# for the averaged model below.
save_model(model,
           None,
           loss=0,
           src_field=SRC,
           tgt_field=TGT,
           updates=0,
           epoch=0,
           prefix='last_model')

model.cuda()
model.eval()

# test_bleu = validate(model, test_iter, SRC, TGT,
#                         BOS_WORD, EOS_WORD, BLANK_WORD, config['max_len'], logging=True)
# print(f"Test Bleu score: {test_bleu}")

paths = [
    'en-de__Sep-30-2019_09-19.pt', 'en-de__Sep-29-2019_15-35.pt',
    'en-de__Sep-29-2019_19-13.pt', 'en-de__Sep-29-2019_23-55.pt',
    'en-de__Sep-30-2019_04-37.pt'
]
average_model_state = average_checkpoints(paths)

model.load_state_dict(average_model_state)
print(f'Model loaded with averaged parameters from {len(paths)} models.')
save_model(model,
           None,
           loss=0,
           src_field=SRC,
           tgt_field=TGT,
           updates=0,
           epoch=0,
           prefix='average')

model.cuda()
model.eval()