Example #1
def test_pretrain_model():
    device = torch.device("cuda")

    # Configuration parameters
    opt = config.MobileNetV3Config()

    # Validation set
    identity_list = dataset.get_lfw_list(opt.lfw_test_list)
    lfw_img_paths = [
        os.path.join(opt.lfw_root, each) for each in identity_list
    ]  # paths of all images

    # Load the model
    model = mobileNetV3.MobileNetV3(n_class=opt.embedding,
                                    input_size=opt.input_shape[2],
                                    dropout=opt.dropout_rate)
    model.to(device)
    model = DataParallel(model)

    # Load the pretrained weights
    state = torch.load(MODEL)
    model.load_state_dict(state['state_dict'])

    # Evaluate on the LFW dataset
    accuracy, threshold = lfw_test(model, lfw_img_paths, identity_list, opt)
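The checkpoint above is loaded into a DataParallel-wrapped model, so its state_dict keys are expected to already carry the "module." prefix. A minimal sketch for the opposite case, assuming the checkpoint was saved from a plain, unwrapped model and stores its weights under the same 'state_dict' key (not part of the original test):

    # Sketch: remap keys before loading into the DataParallel-wrapped model.
    # Assumption: the saved keys lack the "module." prefix.
    state = torch.load(MODEL, map_location=device)
    remapped = {("module." + k if not k.startswith("module.") else k): v
                for k, v in state['state_dict'].items()}
    model.load_state_dict(remapped)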
Example #2
def test_inference_once():
    device = torch.device("cuda")

    # Configuration parameters
    opt = config.MobileNetV3Config()

    # Load the model
    model = mobileNetV3_MixNet.MobileNetV3_MixNet(n_class=opt.embedding, input_size=opt.input_shape[2], dropout=opt.dropout_rate)
    model.to(device)
    model = DataParallel(model)

    x_image = torch.randn(1, 3, 224, 224).to(device)
    y = model(x_image)
    print(y)
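Beyond a single dummy forward pass, the embedding produced here is typically compared with cosine similarity during face verification on LFW. A minimal sketch under that assumption, reusing the 224x224 input size from the example above (the normalization and comparison steps are illustrative, not part of the original test):

    import torch.nn.functional as F

    model.eval()
    with torch.no_grad():
        # L2-normalize the embeddings so the dot product equals cosine similarity.
        emb1 = F.normalize(model(torch.randn(1, 3, 224, 224).to(device)), dim=1)
        emb2 = F.normalize(model(torch.randn(1, 3, 224, 224).to(device)), dim=1)
        cos_sim = (emb1 * emb2).sum(dim=1)  # values near 1 suggest the same identity
    print(cos_sim.item())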
Example #3
    txt_info = "{} {} {} {}\n".format(epoch, iter, lr, iters)
    lr_filename = os.path.join(save_path, pretrain_info_name)
    with open(lr_filename, 'a') as fout:
        fout.write(txt_info)

    return save_name
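
# --- Illustrative sketch, not part of the original script: reading back the
# pretrain-info file written above. It assumes the space-separated
# "epoch iter lr iters" layout of the format string; the helper name is hypothetical.
def read_last_pretrain_info(lr_filename):
    with open(lr_filename, 'r') as fin:
        lines = [line.strip() for line in fin if line.strip()]
    if not lines:
        return None
    epoch, it, lr, iters = lines[-1].split()
    return int(epoch), int(it), float(lr), int(iters)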


if __name__ == '__main__':

    print("You have ", torch.cuda.device_count(), " GPUs")
    device = torch.device("cuda")

    # Configuration parameters
    opt = config.MobileNetV3Config()

    # Set up paths for saving the data produced during training
    date = time.strftime("%Y-%m-%d", time.localtime())
    save_path = os.path.join(opt.checkpoints_path, date)  # folder that holds the saved files
    os.makedirs(save_path, exist_ok=True)
    log_filename = os.path.join(save_path, 'Console_Log.txt')  # path of the console log

    # Validation set
    identity_list = dataset.get_lfw_list(opt.lfw_test_list)
    lfw_img_paths = [os.path.join(opt.lfw_root, each) for each in identity_list]  # paths of all images

    # Load the training dataset
    train_dataset = dataset.Dataset(opt.train_root, opt.path_split, phase='train', input_shape=opt.input_shape)
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=opt.train_batch_size,