def main():
    """Fine-tune the backbone chunk by chunk against a stored prototype matrix.

    Configuration is read from the module-level ``args`` namespace and the
    module-level ``device``.  The workflow is:

    1. Build the backbone selected by ``args.network`` plus a twin used for
       evaluation, wrap both in ``DataParallel`` and save an initial
       checkpoint under ``args.save_path``.
    2. Read the training list (JSON with ``image_names`` / ``image_labels``),
       load backbone weights from ``args.pre_model`` and the prototype matrix
       (the ``"weight"`` entry of the .mat file ``args.new_weight``).
    3. Every epoch, shuffle the (image, label) pairs and split them into
       consecutive chunks of ``args.num_class`` images.  For each chunk a
       fresh classifier head is built (heads that accept initial weights get
       the prototype rows of the chunk's labels), the backbone is trained on
       the chunk, the head's weights are written back into the prototype
       matrix, and the .mat file, the backbone checkpoint and the evaluation
       are refreshed.

    Raises:
        ValueError: if ``args.network`` is not a supported architecture.
        KeyError: if ``args.classifier_type`` is not a known classifier head.
    """
    # --------------------------------------model----------------------------------------
    # Compare by value: ``is`` tests object identity, and a string parsed from
    # the command line is generally not the same object as the literal.
    if args.network == 'sphere20':
        model = net_msra.sphere(type=20)
        model_eval = net_msra.sphere(type=20)
    elif args.network == 'sphere64':
        model = net_msra.sphere(type=64)
        model_eval = net_msra.sphere(type=64)
    elif args.network == 'LResNet50E_IR':
        model = net_msra.LResNet50E_IR()
        # The evaluation net is fed checkpoints of ``model``, so it must have
        # the same architecture (was sphere64 here).
        model_eval = net_msra.LResNet50E_IR()
    else:
        raise ValueError("NOT SUPPORT NETWORK! ")
    model = torch.nn.DataParallel(model).to(device)
    model_eval = torch.nn.DataParallel(model_eval).to(device)
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    model.module.save(args.save_path + 'Sphere64_0_checkpoint.pth')

    # ------------------------------------load image---------------------------------------
    with open(args.train_list) as json_file:
        data = json.load(json_file)
    train_list = data["image_names"]
    train_label = data["image_labels"]

    train_num = len(train_list)
    print('length of train Database: ' + str(len(train_list)))
    print('Number of Identities: ' + str(args.num_class))

    # --------------------------------Updated prototype matrix-----------------------------
    # Initialize the feature layer from the pre-trained backbone.
    checkpoint = torch.load(args.pre_model)
    model.module.load_state_dict(checkpoint)

    data = sio.loadmat(args.new_weight)
    new_weight = data["weight"]
    # torch.from_numpy shares memory with ``new_weight``: writes through
    # ``new_weight_`` are visible in the numpy array that gets saved below.
    new_weight_ = torch.from_numpy(new_weight)

    list_label = list(zip(train_list, train_label))
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(128),
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5,
                                  0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])

    # Classifier heads are built lazily: only the selected one is constructed
    # and moved to the device for each chunk (the original dict built all of
    # them every time).  ``w`` is the prototype rows for the chunk's labels.
    classifier_factories = {
        'MCP': lambda w: layer.MarginCosineProduct(512, args.num_class),
        'AL': lambda w: layer.AngleLinear(512, args.num_class, w),
        'L': lambda w: torch.nn.Linear(512, args.num_class, bias=False),
        "ARC": lambda w: layer.ArcMarginProduct(512,
                                                args.num_class,
                                                w,
                                                s=30,
                                                m=0.3,
                                                easy_margin=True),
        'DUM': lambda w: layer.DumLoss(512, args.num_class, w),
        "DWI_AM": lambda w: layer.DIAMSoftmaxLoss(512, args.num_class, w,
                                                  device),
        "DWI_AL": lambda w: layer.DWIAngleLinear(512, args.num_class, w),
    }
    # Fail fast on an unknown head instead of after the first epoch starts.
    build_classifier = classifier_factories[args.classifier_type]

    # --------------------------------Train-----------------------------
    for epoch in range(1, args.epochs + 1):
        print_with_time('Epoch {} start training'.format(epoch))
        random.shuffle(list_label)
        train_list[:], train_label[:] = zip(*list_label)

        # The dataset is split into chunks of ``args.num_class`` images; the
        # trailing ``train_num % args.num_class`` images are skipped this epoch.
        for subpart in range(int(train_num / args.num_class)):
            print("subpart:", subpart)
            start = subpart * args.num_class
            stop = start + args.num_class
            sub_train_list = train_list[start:stop]
            sub_train_label = train_label[start:stop]
            sub_weight = new_weight_[sub_train_label, :]

            classifier = build_classifier(sub_weight).to(device)
            train_loader = torch.utils.data.DataLoader(
                PImageList(root=args.root_path,
                           train_root=sub_train_list,
                           train_label=sub_train_label,
                           transform=train_transform),
                batch_size=int(args.batch_size / 2),
                shuffle=True,
                num_workers=args.workers,
                pin_memory=False,
                drop_last=True)
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
            train(train_loader, model, classifier, optimizer, epoch)

            # Write the head's weights back into the prototype matrix.
            head_weight = classifier.weight.data.cpu()
            new_weight_[sub_train_label, :] = head_weight
            # Row-by-row pass makes "last occurrence wins" explicit when the
            # same label appears more than once in a chunk.
            for idx, name in enumerate(sub_train_label):
                new_weight[name, :] = head_weight[idx].numpy()

            # File names depend only on the epoch, so each chunk overwrites
            # the previous chunk's files within the same epoch.
            model_path = args.save_path + 'DummpyFace_' + str(
                epoch) + '_checkpoint.pth'
            mat_path = args.save_path + 'DummpyFace_' + str(
                epoch) + '_checkpoint.mat'

            sio.savemat(mat_path, {"weight": new_weight, "label": train_label})
            model.module.save(model_path)
            eval(model_eval,
                 epoch,
                 model_path,
                 args.eval_root,
                 args.eval_list,
                 device,
                 batch_size=500,
                 workers=12)
    print('Finished Training')
def main():
    """Train a backbone jointly with a classifier head on an image list.

    Configuration is read from the module-level ``args`` namespace and the
    module-level ``device``.  After every epoch the backbone and classifier
    are saved to ``args.save_path`` (``<epoch>_1.pth`` / ``<epoch>_2.pth``)
    and the backbone is evaluated with ``lfw_eval`` and
    ``evaluate_prototype``.

    Raises:
        ValueError: if ``args.network`` is not a supported architecture.
        KeyError: if ``args.classifier_type`` is not a known classifier head.
    """
    # --------------------------------------model----------------------------------------
    # Compare by value: ``is`` tests object identity and generally fails for
    # strings parsed from the command line.
    if args.network == 'sphere20':
        model = net.sphere(type=20)
        model_eval = net.sphere(type=20)
    elif args.network == 'sphere64':
        # The old unused ``pre_model`` was only bound in this branch but
        # wrapped unconditionally, which raised NameError for the other
        # networks; it has been removed.
        model = net_msra.sphere(type=64)
        model_eval = net_msra.sphere(type=64)
    elif args.network == 'LResNet50E_IR':
        model = net.LResNet50E_IR()
        model_eval = net.LResNet50E_IR()
    else:
        raise ValueError("NOT SUPPORT NETWORK! ")

    model = torch.nn.DataParallel(model).to(device)
    model_eval = torch.nn.DataParallel(model_eval).to(device)
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    model.module.save(args.save_path + '0_1.pth')

    # 512 is the dimension of the backbone feature.  Heads are built lazily so
    # only the selected one is constructed and moved to the device.
    classifier_factories = {
        'MCP': lambda: layer.MarginCosineProduct(512, args.num_class),
        'AL': lambda: layer.AngleLinear(512, args.num_class),
        'L': lambda: torch.nn.Linear(512, args.num_class, bias=False),
    }
    classifier = classifier_factories[args.classifier_type]().to(device)

    # ------------------------------------load image---------------------------------------
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(128),
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5,
                                  0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    train_loader = torch.utils.data.DataLoader(
        ImageList(root=args.root_path,
                  fileList=args.train_list,
                  transform=train_transform),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)

    print('length of train Database: ' + str(len(train_loader.dataset)))
    print('Number of Identities: ' + str(args.num_class))

    # --------------------------------loss function and optimizer-----------------------------
    criterion = torch.nn.CrossEntropyLoss().to(device)
    # The classifier head learns with a 10x larger learning rate than the backbone.
    optimizer = torch.optim.SGD([{
        'params': model.parameters(),
        'lr': args.lr
    }, {
        'params': classifier.parameters(),
        'lr': args.lr * 10
    }],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.pre_model:
        checkpoint = torch.load(args.pre_model)
        model.module.load_state_dict(checkpoint)
        # NOTE(review): the classifier checkpoint path is hard-coded; it is
        # probably meant to be configurable alongside ``args.pre_model``.
        classifier_checkpoint = torch.load(
            "./models/checkpoint_msra_AL_2/3_2.pth")
        classifier.load_state_dict(classifier_checkpoint)

    # ----------------------------------------train----------------------------------------
    for epoch in range(args.start_epoch, args.epochs + 1):
        train(train_loader, model, classifier, criterion, optimizer, epoch)
        model_name = args.save_path + str(epoch) + '_1.pth'
        classifier_name = args.save_path + str(epoch) + '_2.pth'
        model.module.save(model_name)
        if hasattr(classifier, 'save'):
            classifier.save(classifier_name)
        else:
            # torch.nn.Linear ('L') has no .save(); store its state_dict,
            # which is the format the classifier loader above reads back.
            torch.save(classifier.state_dict(), classifier_name)
        lfw_eval.eval(model_eval, model_name)
        evaluate_prototype.eval(model_eval,
                                epoch,
                                model_name,
                                args.eval_root,
                                args.eval_list,
                                device,
                                batch_size=400,
                                workers=12)
    print('Finished Training')
Example #3
0
    label = np.vstack((predicts_T[1], predicts_F[1]))
    predict = np.reshape(predict, -1, 1)
    label = np.reshape(label, -1, 1)

    TPR_001, TPR_0001, TPR_00001, TPR_000001 = evaluate_xjc(predict, label)
    line_vote_acc = "{}\t{:.6}\t{:.6}\t{:.6f}\t{:.6}\n".format(
        epoch, TPR_001, TPR_0001, TPR_00001, TPR_000001)

    f2 = open("TPR_NJN_DWI_A-softmax.txt", "a+")
    f2.write(line_vote_acc)


if __name__ == '__main__':
    # Stand-alone run: evaluate one saved sphere64 checkpoint on a test list.
    device = torch.device("cuda")
    # Wrap the 64-layer sphere net for multi-GPU use and place it on the GPU.
    model = torch.nn.DataParallel(net_msra.sphere(type=64)).to(device)

    # Hard-coded locations of the checkpoint and evaluation data.
    model_path = '/home/yanghuiwen/project/3_low_shot_learning/Large-scale_Bisample_Learning_on_ID_vs_Spot_Face_Recognition/models/checkpoint_NJN_RP/DummpyFace_5_checkpoint.pth'
    eval_root = '/home3/yhw_datasets/face_recognition/NJN_crop/test'
    eval_list = '/home/yanghuiwen/project/3_low_shot_learning/Large-scale_Bisample_Learning_on_ID_vs_Spot_Face_Recognition/Data/NJN_Random_prototype_test.json'
    # Passed through to eval() as the epoch argument; presumably only a tag
    # for logged results — confirm against eval()'s definition.
    epoch = 4000

    result = eval(model, epoch, model_path, eval_root, eval_list, device,
                  batch_size=400, workers=12)
            # print(target)
            input_id, input_spot, target = input_id.to(device), input_spot.to(device), target.to(device)
            output_id = model(input_id)
            if args.prototype == "ID":
                output = output_id
            elif args.prototype == "AVG":
                output_spot = model(input_spot)
                output = (output_id + output_spot) / 2  # w = (Wid + Wspot)/2
            for idx in range(args.batch_size):
                tmp = output[idx, :].cpu()
                tmp1 = tmp / tmp.norm(p=2)
                new_weight[target[idx], :] = tmp1
    return new_weight

# Build the training network and its evaluation twin selected by
# ``args.network``.  Strings are compared with ``==``: ``is`` tests object
# identity and generally fails for values parsed from the command line.
if args.network == 'sphere20':
    model = net_msra.sphere(type=20)
    model_eval = net_msra.sphere(type=20)
elif args.network == 'sphere64':
    model = net_msra.sphere(type=64)
    model_eval = net_msra.sphere(type=64)
elif args.network == 'LResNet50E_IR':
    model = net_msra.LResNet50E_IR()
    # The evaluation twin must share the training architecture so it can load
    # the trained weights (was sphere64 here).
    model_eval = net_msra.LResNet50E_IR()
else:
    raise ValueError("NOT SUPPORT NETWORK! ")

model = torch.nn.DataParallel(model).to(device)

# Training list: JSON object holding (at least) an "image_names" entry.
with open(args.train_list) as json_file:
    data = json.load(json_file)
train_list = data["image_names"]