# Imports used by the training scripts below. Project-local modules
# (net, net_msra, net_triplet, layer, the dataset/list classes,
# print_with_time, lfw_eval, evaluate_prototype) plus the global `args`
# and `device` are assumed to be defined elsewhere in this repository.
import json
import os
import random
import time

import scipy.io as sio
import torch
import torch.utils.data
from torchvision import transforms


def main():
    # --------------------------------------model----------------------------------------
    if args.network == 'sphere20':
        model = net_triplet.sphere(type=20)
        model_eval = net_triplet.sphere(type=20)
    elif args.network == 'sphere64':
        model = net_triplet.sphere(type=64)
        model_eval = net_triplet.sphere(type=64)
    elif args.network == 'LResNet50E_IR':
        model = net_triplet.LResNet50E_IR()
        model_eval = net_triplet.LResNet50E_IR()
    else:
        raise ValueError("NOT SUPPORT NETWORK! ")

    model = torch.nn.DataParallel(model).to(device)
    model_eval = torch.nn.DataParallel(model_eval).to(device)
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # ------------------------------------load image---------------------------------------
    #train set
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(128),
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5,
                                  0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    train_loader = torch.utils.data.DataLoader(ContrastiveList(
        root=args.train_root,
        train_list=args.train_list,
        transform=train_transform),
                                               batch_size=args.batch_size,
                                               shuffle=True)
    # --------------------------------loss function and optimizer-----------------------------
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    criterion = ContrastiveLoss()
    # --------------------------------Load pretrained model parameters-----------------------------
    if args.pre_model:
        checkpoint = torch.load(args.pre_model)
        model.module.load_state_dict(checkpoint)
    # --------------------------------Train-----------------------------
    print("start training")
    for epoch in range(1, args.epochs + 1):
        train(train_loader, model, model_eval, optimizer, criterion, epoch)
        model_name = args.save_path + str(epoch) + '.pth'
        model.module.save(model_name)
        eval(model_eval,
             epoch,
             model_name,
             args.eval_root,
             args.eval_list,
             device,
             batch_size=400,
             workers=12)
    print('Finished Training')
def train(train_loader, model, model_eval, optimizer, criterion, epoch):
    model.train()
    print_with_time('Epoch {} start training'.format(epoch))
    time_curr = time.time()
    loss_display = 0.0
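    # Each batch yields an anchor, a positive (same identity) and a negative
    # (different identity) image; all three are embedded by the same backbone
    # and the criterion operates on the three embeddings.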
    for batch_idx, (data_a, data_p, data_n) in enumerate(train_loader):
        iteration = (epoch - 1) * int(len(train_loader) / 2) + batch_idx
        # adjust_learning_rate(optimizer, iteration, args.step_size)
        data_a, data_p, data_n = data_a.to(device), data_p.to(
            device), data_n.to(device)
        out_a, out_p, out_n = model(data_a), model(data_p), model(data_n)
        # loss = F.triplet_margin_loss(out_a,out_p,out_n,margin=2,swap=True)
        loss = criterion(out_a, out_p, out_n)
        loss_display += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Log the averaged loss once every args.log_interval batches.
        if batch_idx % args.log_interval == 0 and batch_idx > 0:
            time_used = time.time() - time_curr
            loss_display /= args.log_interval
            print_with_time(
                'Train Epoch: {} [{}/{} ({:.0f}%)]{}, Loss: {:.6f}, Elapsed time: {:.4f}s({} iters)'
                .format(epoch, batch_idx * len(data_a),
                        len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), iteration,
                        loss_display, time_used, args.log_interval))
            time_curr = time.time()
            loss_display = 0.0
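        # Every 200 iterations, checkpoint the backbone and evaluate the saved
        # weights.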
        if iteration % 200 == 0 and iteration > 0:
            model_name = args.save_path + str(iteration) + '.pth'
            model.module.save(model_name)
            eval(model_eval,
                 epoch,
                 model_name,
                 args.eval_root,
                 args.eval_list,
                 device,
                 batch_size=400,
                 workers=12)
def main():
    # --------------------------------------model----------------------------------------
    if args.network == 'sphere20':
        model = net_msra.sphere(type=20)
        model_eval = net_msra.sphere(type=20)
    elif args.network == 'sphere64':
        model = net_msra.sphere(type=64)
        model_eval = net_msra.sphere(type=64)
    elif args.network == 'LResNet50E_IR':
        model = net_msra.LResNet50E_IR()
        model_eval = net_msra.LResNet50E_IR()
    else:
        raise ValueError("NOT SUPPORT NETWORK! ")
    # pretrain_model = torch.nn.DataParallel(pretrain_model).to(device)
    model = torch.nn.DataParallel(model).to(device)
    model_eval = torch.nn.DataParallel(model_eval).to(device)
    # print(model)
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    model.module.save(args.save_path + 'Sphere64_0_checkpoint.pth')

    # ------------------------------------load image---------------------------------------
    # train set
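    # The training list is a JSON file with two parallel arrays:
    # "image_names" (image paths) and "image_labels" (identity ids).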
    with open(args.train_list) as json_file:
        data = json.load(json_file)
    train_list = data["image_names"]
    train_label = data["image_labels"]

    train_num = len(train_list)
    print('Length of training set: ' + str(train_num))
    print('Number of Identities: ' + str(args.num_class))
    # --------------------------------Update the prototype matrix-----------------------------
    # Initialize the backbone from the pretrained checkpoint
    checkpoint = torch.load(args.pre_model)
    model.module.load_state_dict(checkpoint)

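    # new_weight holds the full prototype matrix: one 512-D classifier weight
    # row per identity, indexed by label. It is kept outside the model and
    # loaded from a .mat file so it can be updated one subpart at a time.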
    data = sio.loadmat(args.new_weight)
    new_weight = data["weight"]
    new_weight_ = torch.from_numpy(new_weight)

    list_label = list(zip(train_list, train_label))
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(128),
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5,
                                  0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    # --------------------------------Train-----------------------------
    for epoch in range(1, args.epochs + 1):
        print_with_time('Epoch {} start training'.format(epoch))
        random.shuffle(list_label)
        train_list[:], train_label[:] = zip(*list_label)

        # The whole training list is split into train_num / args.num_class parts;
        # the prototype matrix is updated one part (subpart) at a time.
        for subpart in range(int(train_num / args.num_class)):
            print("subpart:", subpart)
            sub_train_list = train_list[subpart *
                                        args.num_class:(subpart + 1) *
                                        args.num_class]
            sub_train_label = train_label[subpart *
                                          args.num_class:(subpart + 1) *
                                          args.num_class]
            sub_weight = new_weight_[sub_train_label, :]

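            # Rebuild the classifier head for this subpart; where the layer
            # supports it, its weights are initialised from the stored
            # prototypes (sub_weight) of the selected labels.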
            classifier = {
                'MCP':
                layer.MarginCosineProduct(512, args.num_class).to(device),
                'AL':
                layer.AngleLinear(512, args.num_class, sub_weight).to(device),
                'L':
                torch.nn.Linear(512, args.num_class, bias=False).to(device),
                "ARC":
                layer.ArcMarginProduct(512,
                                       args.num_class,
                                       sub_weight,
                                       s=30,
                                       m=0.3,
                                       easy_margin=True).to(device),
                'DUM':
                layer.DumLoss(512, args.num_class, sub_weight).to(device),
                "DWI_AM":
                layer.DIAMSoftmaxLoss(512, args.num_class, sub_weight,
                                      device).to(device),
                "DWI_AL":
                layer.DWIAngleLinear(512, args.num_class,
                                     sub_weight).to(device),
            }[args.classifier_type]
            train_loader = torch.utils.data.DataLoader(
                PImageList(root=args.root_path,
                           train_root=sub_train_list,
                           train_label=sub_train_label,
                           transform=train_transform),
                batch_size=int(args.batch_size / 2),
                shuffle=True,
                num_workers=args.workers,
                pin_memory=False,
                drop_last=True)
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
            train(train_loader, model, classifier, optimizer, epoch)

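            # Write the trained head's weights back into both copies of the
            # prototype matrix (the torch tensor and the numpy array that is
            # saved to the .mat file).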
            new_weight_[sub_train_label, :] = classifier.weight.data.cpu()
            temp = classifier.weight.data.cpu()
            # new_weight[sub_train_label,:] =
            for idx, name in enumerate(sub_train_label):
                new_weight[name, :] = temp[idx].numpy()
            # new_weight_ = torch.from_numpy(new_weight)
            model_path = args.save_path + 'DummyFace_' + str(
                epoch) + '_checkpoint.pth'
            mat_path = args.save_path + 'DummyFace_' + str(
                epoch) + '_checkpoint.mat'

            sio.savemat(mat_path, {"weight": new_weight, "label": train_label})
            model.module.save(model_path)
            eval(model_eval,
                 epoch,
                 model_path,
                 args.eval_root,
                 args.eval_list,
                 device,
                 batch_size=500,
                 workers=12)
    print('Finished Training')
def main():
    # --------------------------------------model----------------------------------------
    if args.network == 'sphere20':
        pre_model = None
        model = net.sphere(type=20)
        model_eval = net.sphere(type=20)
    elif args.network == 'sphere64':
        pre_model = net.sphere(type=64)
        model = net_msra.sphere(type=64)
        model_eval = net_msra.sphere(type=64)
    elif args.network == 'LResNet50E_IR':
        pre_model = None
        model = net.LResNet50E_IR()
        model_eval = net.LResNet50E_IR()
    else:
        raise ValueError("Unsupported network: {}".format(args.network))

    # pre_model is only constructed for sphere64; wrap it only when present.
    if pre_model is not None:
        pre_model = torch.nn.DataParallel(pre_model).to(device)

    model = torch.nn.DataParallel(model).to(device)
    model_eval = torch.nn.DataParallel(model_eval).to(device)
    # print(model)
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    model.module.save(args.save_path + '0_1.pth')

    # 512 is dimension of feature
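    # Classifier heads: MCP is a large-margin cosine (CosFace-style) product,
    # AL an angular-margin (SphereFace-style) linear layer, and L a plain
    # bias-free linear (softmax) layer.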
    classifier = {
        'MCP': layer.MarginCosineProduct(512, args.num_class).to(device),
        'AL': layer.AngleLinear(512, args.num_class).to(device),
        'L': torch.nn.Linear(512, args.num_class, bias=False).to(device),
        # "ARC": layer.ArcMarginProduct(512, args.num_class, s=30, m=0.4, easy_margin=False).to(device),
    }[args.classifier_type]

    # classifier = torch.nn.DataParallel(classifier).to(device)
    # classifier.save(args.save_path + '0_2.pth')

    # ------------------------------------load image---------------------------------------

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(128),
        # transforms.RandomCrop(128),
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5,
                                  0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    train_loader = torch.utils.data.DataLoader(ImageList(
        root=args.root_path,
        fileList=args.train_list,
        transform=train_transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    print('Length of training set: ' + str(len(train_loader.dataset)))
    print('Number of Identities: ' + str(args.num_class))

    # --------------------------------loss function and optimizer-----------------------------
    criterion = torch.nn.CrossEntropyLoss().to(device)
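    # The classifier head is trained with 10x the backbone learning rate;
    # momentum and weight decay are shared by both parameter groups.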
    optimizer = torch.optim.SGD([{
        'params': model.parameters(),
        'lr': args.lr
    }, {
        'params': classifier.parameters(),
        'lr': args.lr * 10
    }],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    #
    # optimizer = torch.optim.SGD([{'params': filter(lambda  p: p.requires_grad,model.parameters())}, {'params': classifier.parameters()}],
    #                             lr=args.lr,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)

    if args.pre_model:
        # print(args.pre_model)
        checkpoint = torch.load(args.pre_model)
        model.module.load_state_dict(checkpoint)

        # print(pre_model.state_dict())
        # model_dict = model.state_dict()
        # pre_model = {k:v for k,v in pre_model.state_dict().items() if k in model_dict}
        # model_dict.update(pre_model)
        # model.load_state_dict(model_dict)
        # print(model.state_dict())
        classifier_checkpoint = torch.load(
            "./models/checkpoint_msra_AL_2/3_2.pth")
        classifier.load_state_dict(classifier_checkpoint)

    # ----------------------------------------train----------------------------------------
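    # Each epoch trains the backbone and classifier head jointly, checkpoints
    # both, then evaluates the saved backbone with lfw_eval and on the
    # prototype evaluation list.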
    for epoch in range(args.start_epoch, args.epochs + 1):
        train(train_loader, model, classifier, criterion, optimizer, epoch)
        model_name = args.save_path + str(epoch) + '_1.pth'
        classifier_name = args.save_path + str(epoch) + '_2.pth'
        model.module.save(model_name)
        classifier.save(classifier_name)
        lfw_eval.eval(model_eval, model_name)
        evaluate_prototype.eval(model_eval,
                                epoch,
                                model_name,
                                args.eval_root,
                                args.eval_list,
                                device,
                                batch_size=400,
                                workers=12)
    print('Finished Training')