예제 #1
0
def load_model(model_path, with_gpu):  # 载入模型、模型参数
    logger.info("Loading checkpoint: {} ...".format(model_path))
    checkpoint = torch.load(model_path)
    if not checkpoint:
        raise RuntimeError('No checkpoint found.')
    config = checkpoint['config']

    model = FOTSModel(config)

    pretrained_dict = checkpoint['state_dict']  # 预训练模型的state_dict
    model_dict = model.state_dict()  # 当前用来训练的模型的state_dict

    if pretrained_dict.keys() != model_dict.keys():  # 需要进行参数的适配
        print('Parameters are inconsistant, adapting model parameters ...')
        # 在合并前(update),需要去除pretrained_dict一些不需要的参数
        # 只含有识别分支的预训练模型参数字典中键'0', '1'对应全模型参数字典中键'2', '3'
        pretrained_dict['2'] = transfer_state_dict(pretrained_dict['0'],
                                                   model_dict['2'])
        pretrained_dict['3'] = transfer_state_dict(pretrained_dict['1'],
                                                   model_dict['3'])
        del pretrained_dict['0']  # 把原本预训练模型中的键值对删掉,以免错误地更新当前模型中的键值对
        del pretrained_dict['1']
        model_dict.update(pretrained_dict)  # 更新(合并)模型的参数
        self.model.load_state_dict(model_dict)
    else:
        print('Parameters are consistant, load state dict directly ...\n')
        model.load_state_dict(pretrained_dict)

    if with_gpu:
        model.to(torch.device("cuda:0"))
        model.parallelize()

    model.eval()
    return model
예제 #2
0
def load_model(model_path, with_gpu):
    logger.info("Loading checkpoint: {} ...".format(model_path))
    checkpoints = torch.load(model_path)
    if not checkpoints:
        raise RuntimeError('No checkpoint found.')
    config = checkpoints['config']
    state_dict = checkpoints['state_dict']

    model = FOTSModel(config)
    model.load_state_dict(state_dict)

    if with_gpu:
        model.to(torch.device("cuda:0"))
        model.parallelize()

    model.eval()
    return model