Example #1
# Assumes module-level imports os and torch, plus a checkpoint_paths helper.
    def save(self, epoch, valid_ppl, itr=None):

        opt = self.opt
        model = self.model
        dicts = self.dicts

        # unwrap DistributedDataParallel so we save the underlying module's weights
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model_state_dict = model.module.state_dict()
        else:
            model_state_dict = model.state_dict()
        optim_state_dict = self.optim.state_dict()

        if itr is not None:
            itr_state_dict = itr.state_dict()
        else:
            itr_state_dict = None

        # write a checkpoint
        checkpoint = {
            'model': model_state_dict,
            'dicts': dicts,
            'opt': opt,
            'epoch': epoch,
            'itr': itr_state_dict,
            'optim': optim_state_dict,
            'scaler': self.grad_scaler.state_dict()
        }

        file_name = '%s_ppl_%.6f_e%.2f.pt' % (opt.save_model, valid_ppl, epoch)
        print('Writing to %s' % file_name)
        torch.save(checkpoint, file_name)

        # prune old checkpoints: keep only the best opt.keep_save_files files
        checkpoint_dir = os.path.dirname(opt.save_model)
        existed_save_files = checkpoint_paths(checkpoint_dir)
        for save_file in existed_save_files[opt.keep_save_files:]:
            print(" * Deleting old save file %s ...." % save_file)
            os.remove(save_file)
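Both save() variants and the averaging scripts below call a checkpoint_paths helper whose definition is not part of the snippet. A minimal sketch, assuming file names follow the '%s_ppl_%.6f_e%.2f.pt' pattern used above and that lower validation perplexity ranks first (the regex and the ordering are assumptions, not the source's definition):

import os
import re

def checkpoint_paths(path, pattern=r'.*_ppl_(\d+\.\d+)_e\d+\.\d+\.pt'):
    # Hypothetical helper: scan `path` for checkpoint files written by
    # save() above and return their full paths sorted best-first
    # (ascending validation ppl), so that slicing [opt.keep_save_files:]
    # yields the files eligible for deletion.
    regexp = re.compile(pattern)
    scored = []
    for fname in os.listdir(path):
        match = regexp.fullmatch(fname)
        if match is not None:
            scored.append((float(match.group(1)), os.path.join(path, fname)))
    return [fpath for _, fpath in sorted(scored)]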
Example #2
# Variant where a model wrapper bundles the model, optimizer and AMP states.
    def save(self, epoch, valid_ppl, itr=None):

        opt = self.opt
        model = self.model
        dicts = self.dicts

        # the wrapper returns the model, optimizer and AMP states together
        model_state_dict, optim_state_dict, amp_state_dict = \
            self.model_wrapper.state_dict()

        if itr is not None:
            itr_state_dict = itr.state_dict()
        else:
            itr_state_dict = None

        # write a checkpoint
        checkpoint = {
            'model': model_state_dict,
            'dicts': dicts,
            'opt': opt,
            'epoch': epoch,
            'itr': itr_state_dict,
            'optim': optim_state_dict,
            'amp': amp_state_dict
        }

        file_name = '%s_ppl_%.6f_e%.2f.pt' % (opt.save_model, valid_ppl, epoch)
        print('Writing to %s' % file_name)
        torch.save(checkpoint, file_name)

        # prune old checkpoints: keep only the best opt.keep_save_files files
        checkpoint_dir = os.path.dirname(opt.save_model)
        existed_save_files = checkpoint_paths(checkpoint_dir)
        for save_file in existed_save_files[opt.keep_save_files:]:
            print(" * Deleting old save file %s ...." % save_file)
            os.remove(save_file)
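The matching resume path is not shown either. A hedged sketch for the checkpoint layout written above; load_checkpoint and its arguments are illustrative names, and only torch.load, load_state_dict, and torch.cuda.amp.GradScaler are confirmed PyTorch APIs:

import torch

def load_checkpoint(file_name, model, optimizer, grad_scaler=None):
    # Restore the pieces written by save(); the keys follow the checkpoint
    # dict assembled in Examples #1 and #2.
    checkpoint = torch.load(file_name, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    if checkpoint.get('optim') is not None:
        optimizer.load_state_dict(checkpoint['optim'])
    # Example #1 stores torch.cuda.amp.GradScaler state under 'scaler';
    # Example #2 stores an AMP state dict under 'amp' instead.
    if grad_scaler is not None and checkpoint.get('scaler') is not None:
        grad_scaler.load_state_dict(checkpoint['scaler'])
    return checkpoint['epoch'], checkpoint['itr']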
Example #3
# Averages the parameters of the top-N checkpoints. Assumes module-level
# imports and helpers: torch, the argparse `parser`, `checkpoint_paths`,
# and `custom_build_model`.
def main():

    opt = parser.parse_args()

    opt.cuda = opt.gpu > -1

    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    path = opt.models

    # collect checkpoint files, best first, and keep the top opt.top of them
    existed_save_files = checkpoint_paths(path)
    models = existed_save_files[:opt.top]

    n_models = len(models)
    print("Firstly load the best model from %s ..." % models[0])
    checkpoint = torch.load(models[0],
                            map_location=lambda storage, loc: storage)

    # delete optimizer state to save memory
    if 'optim' in checkpoint:
        del checkpoint['optim']

    model_opt = checkpoint['opt']
    # load_state_dict below loads each of the best five models in turn, so
    # here we only need to construct the model object
    model_opt.not_load_bert_state = True
    dicts = checkpoint['dicts']
    main_model = custom_build_model(model_opt, checkpoint['dicts'], lm=opt.lm)

    main_model.load_state_dict(checkpoint['model'])

    if opt.cuda:
        main_model = main_model.cuda()
    print("Then load the other %d models ..." % (n_models - 1))
    for i in range(1, len(models)):

        model = models[i]
        print("Loading model from %s ..." % models[i])
        checkpoint = torch.load(model,
                                map_location=lambda storage, loc: storage)

        model_opt = checkpoint['opt']
        model_opt.not_load_bert_state = True

        # delete optim information to save GPU memory
        if 'optim' in checkpoint:
            del checkpoint['optim']

        current_model = custom_build_model(model_opt,
                                           checkpoint['dicts'],
                                           lm=opt.lm)

        current_model.load_state_dict(checkpoint['model'])

        if opt.cuda:
            current_model = current_model.cuda()

        if opt.method == 'mean':
            # accumulate the sum; divided by n_models after the loop
            for (main_param, param) in zip(main_model.parameters(),
                                           current_model.parameters()):
                main_param.data.add_(param.data)
        elif opt.method == 'gmean':
            # accumulate the product; the n-th root is taken after the loop
            for (main_param, param) in zip(main_model.parameters(),
                                           current_model.parameters()):
                main_param.data.mul_(param.data)
        else:
            raise NotImplementedError

    # Normalizing
    if opt.method == 'mean':
        for main_param in main_model.parameters():
            main_param.data.div_(n_models)
    elif opt.method == 'gmean':
        for main_param in main_model.parameters():
            main_param.data.pow_(1. / n_models)

    # Saving
    model_state_dict = main_model.state_dict()

    save_checkpoint = {
        'model': model_state_dict,
        'dicts': dicts,
        'opt': model_opt,
        'epoch': -1,
        'iteration': -1,
        'batchOrder': None,
        'optim': None
    }

    print("Saving averaged model to %s" % opt.output)

    torch.save(save_checkpoint, opt.output)
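A quick sanity check of the two reduction modes on a toy tensor: 'mean' computes the arithmetic mean of corresponding parameters, while 'gmean' computes their element-wise geometric mean, which is only well-defined when the corresponding values share a sign.

import torch

a, b = torch.tensor([2.0, 8.0]), torch.tensor([4.0, 2.0])

mean = a.clone()
mean.add_(b)
mean.div_(2)          # tensor([3., 5.]) == (a + b) / 2

gmean = a.clone()
gmean.mul_(b)
gmean.pow_(1.0 / 2)   # tensor([2.8284, 4.0000]) == (a * b) ** 0.5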
Example #4
# Same as Example #3, with an optional mode that ranks checkpoints by
# modification time. Additionally assumes module-level imports glob and os.
def main():

    opt = parser.parse_args()

    opt.cuda = opt.gpu > -1

    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    path = opt.models

    if not opt.sort_by_date:
        existed_save_files = checkpoint_paths(path)
    else:
        existed_save_files = glob.glob(path + "/" + "*.pt")
        existed_save_files.sort(key=os.path.getmtime)
        print("\n".join(existed_save_files))

    # keep the top opt.top checkpoints
    models = existed_save_files[:opt.top]

    n_models = len(models)

    # load the checkpoint of the best (first) model
    checkpoint = torch.load(models[0],
                            map_location=lambda storage, loc: storage)

    if 'optim' in checkpoint:
        del checkpoint['optim']

    model_opt = checkpoint['opt']

    dicts = checkpoint['dicts']

    # only construct the model object; weights are loaded below
    model_opt.enc_state_dict = None
    model_opt.dec_state_dict = None

    print(model_opt.layers)

    main_model = custom_build_model(model_opt, checkpoint['dicts'], lm=opt.lm)

    print("Loading main model from %s ..." % models[0])
    main_model.load_state_dict(checkpoint['model'])

    if opt.cuda:
        main_model = main_model.cuda()

    for i in range(1, len(models)):

        model = models[i]
        # load the checkpoint for models[i]
        checkpoint = torch.load(model,
                                map_location=lambda storage, loc: storage)

        model_opt = checkpoint['opt']

        # only construct the model object; weights are loaded below
        model_opt.enc_state_dict = None
        model_opt.dec_state_dict = None

        # delete optim information to save GPU memory
        if 'optim' in checkpoint:
            del checkpoint['optim']

        current_model = custom_build_model(model_opt,
                                           checkpoint['dicts'],
                                           lm=opt.lm)

        print("Loading model from %s ..." % models[i])
        current_model.load_state_dict(checkpoint['model'])

        if opt.cuda:
            current_model = current_model.cuda()

        if opt.method == 'mean':
            # accumulate the sum; divided by n_models after the loop
            for (main_param, param) in zip(main_model.parameters(),
                                           current_model.parameters()):
                main_param.data.add_(param.data)
        elif opt.method == 'gmean':
            # accumulate the product; the n-th root is taken after the loop
            for (main_param, param) in zip(main_model.parameters(),
                                           current_model.parameters()):
                main_param.data.mul_(param.data)
        else:
            raise NotImplementedError

    # Normalizing
    if opt.method == 'mean':
        for main_param in main_model.parameters():
            main_param.data.div_(n_models)
    elif opt.method == 'gmean':
        for main_param in main_model.parameters():
            main_param.data.pow_(1. / n_models)

    # Saving
    model_state_dict = main_model.state_dict()

    save_checkpoint = {
        'model': model_state_dict,
        'dicts': dicts,
        'opt': model_opt,
        'epoch': -1,
        'iteration': -1,
        'batchOrder': None,
        'optim': None
    }

    print("Saving averaged model to %s" % opt.output)

    torch.save(save_checkpoint, opt.output)
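Both main() variants rely on an argparse parser defined at module level. A hypothetical definition consistent with the opt attributes accessed above (every flag name, default, and help string here is inferred from the code, not taken from the source):

import argparse

parser = argparse.ArgumentParser(description='Average the top-N checkpoints')
parser.add_argument('-models', required=True,
                    help='directory containing the *.pt checkpoint files')
parser.add_argument('-output', required=True,
                    help='path of the averaged checkpoint to write')
parser.add_argument('-top', type=int, default=5,
                    help='number of best checkpoints to average')
parser.add_argument('-method', default='mean', choices=['mean', 'gmean'],
                    help='arithmetic or element-wise geometric averaging')
parser.add_argument('-gpu', type=int, default=-1,
                    help='GPU id; -1 runs on the CPU')
parser.add_argument('-lm', action='store_true',
                    help='build a language model instead of a seq2seq model')
parser.add_argument('-sort_by_date', action='store_true',
                    help='rank checkpoints by modification time (Example #4 only)')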