Example #1
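Both main() functions below are excerpted from larger training scripts, so their module-level setup is not shown. A rough sketch of what Example #1 assumes at module level (the project-local modules net_sia, layer, dset and utils, plus the configurations dict and the train/valid/save_ckpt helpers, are defined elsewhere in the repository; the argparse and device lines are illustrative, not the original code):

import os

import torch
import torchvision.transforms as transforms

import net_sia      # project-local: backbone with the siamese ('sia') branch
import layer        # project-local: MarginCosineProduct / AngleLinear heads
import dset         # project-local: ImageList dataset
import utils        # project-local: run-name and plotting helpers

# `args` is assumed to come from an argparse parser defined at module level, and
# `device` from e.g. torch.device('cuda' if torch.cuda.is_available() else 'cpu').
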
def main():
    # --------------------------------------model----------------------------------------
    model = net_sia.LResNet50E_IR_Sia(is_gray=args.is_gray)
    model_eval = net_sia.LResNet50E_IR_Sia(is_gray=args.is_gray)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    # 512 is the feature embedding dimension
    classifier = {
        'MCP': layer.MarginCosineProduct(512, args.num_class),
        'AL': layer.AngleLinear(512, args.num_class),
        'L': torch.nn.Linear(512, args.num_class, bias=False)
    }[args.classifier_type]
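    # args.classifier_type selects the classification head: 'MCP' is the
    # CosFace-style MarginCosineProduct, 'AL' the SphereFace-style AngleLinear,
    # and 'L' a plain linear (softmax) layer.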

    classifier.load_state_dict(torch.load(args.weight_fc))

    print(os.environ['CUDA_VISIBLE_DEVICES'], args.cuda)

    pretrained = torch.load(args.weight_model)
    pretrained_dict = pretrained['model_state_dict']
    model_dict = model.state_dict()
    model_eval_dict = model_eval.state_dict()
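    # state_dict() returns references to the model's own tensors, so copy_()
    # below writes each matching pretrained weight directly into `model`.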
    for k, v in pretrained_dict.items():
        if k in model_dict:
            model_dict[k].copy_(v)

    del pretrained
    del pretrained_dict
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        # the classifier checkpoint stores only the classifier's state_dict
        classifier.load_state_dict(torch.load(args.resume_fc))
    print(model)
    model = torch.nn.DataParallel(model).to(device)
    model_eval = model_eval.to(device)
    classifier = classifier.to(device)

    args.run_name = utils.get_run_name()
    output_dir = os.path.join(args.save_path, args.run_name.split("_")[0])
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # ------------------------------------load image---------------------------------------
    if args.is_gray:
        train_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
            transforms.Normalize(mean=(0.5, ), std=(0.5, ))  # -> [-1.0, 1.0]
        ])
        valid_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
            transforms.Normalize(mean=(0.5, ), std=(0.5, ))  # -> [-1.0, 1.0]
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5))  # -> [-1.0, 1.0]
        ])
        valid_transform = transforms.Compose([
            transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5))  # -> [-1.0, 1.0]
        ])
    train_loader = torch.utils.data.DataLoader(dset.ImageList(
        root=args.root_path,
        fileList=args.train_list,
        transform=train_transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(dset.ImageList(
        root=args.root_path,
        fileList=args.valid_list,
        transform=valid_transform),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=False,
                                             drop_last=False)

    print('length of train Database: ' + str(len(train_loader.dataset)) +
          ' Batches: ' + str(len(train_loader)))
    print('length of valid Database: ' + str(len(val_loader.dataset)) +
          ' Batches: ' + str(len(val_loader)))
    print('Number of Identities: ' + str(args.num_class))
    # Get a batch of training data, (img, img_occ, label)
    ''' 
    inputs, inputs_occ, imgPair, targets = next(iter(train_loader)) 
    out = torchvision.utils.make_grid(inputs)
    out_occ = torchvision.utils.make_grid(inputs_occ)
    
    mean = torch.tensor((0.5,0.5,0.5), dtype=torch.float32)
    std = torch.tensor((0.5,0.5,0.5), dtype=torch.float32)
    utils.imshow(out, mean, std, title=str(targets))
    plt.savefig(output_dir + '/train.png')
    utils.imshow(out_occ, mean, std, title=str(targets))
    plt.savefig(output_dir + '/train_occ.png')
    '''
    #---------------------------------------params setting-----------------------------------
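    # Freeze the pretrained backbone (stem conv1/bn1/prelu1 and the residual
    # 'layer' blocks); everything else, in particular the siamese ('sia')
    # parameters, remains trainable.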
    for name, param in model.named_parameters():
        if 'layer' in name or 'conv1' in name or 'bn1' in name or 'prelu1' in name:
            param.requires_grad = False
        else:
            param.requires_grad = True

    print("Params to learn:")
    params_to_update = []
    params_to_stay = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if 'sia' in name:
                params_to_update.append(param)
                print("Update \t", name)
            else:
                params_to_stay.append(param)
                print("Stay \t", name)

    for name, param in classifier.named_parameters():
        param.requires_grad = True
        params_to_stay.append(param)
        print("Stay \t", name)
    #--------------------------------loss function and optimizer-----------------------------
    cfg = configurations[args.config]
    criterion = torch.nn.CrossEntropyLoss().to(device)
    criterion2 = torch.nn.L1Loss(reduction='mean').to(device)
    optimizer = torch.optim.SGD([{
        'params': params_to_stay,
        'lr': 0,
        'weight_decay': 0,
        'momentum': 0
    }, {
        'params': params_to_update
    }],
                                lr=cfg['lr'],
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
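    # The first parameter group (lr=0, no momentum or weight decay) keeps the
    # classifier and the remaining trainable non-'sia' weights registered but
    # effectively frozen; only the 'sia' parameters in the second group are
    # updated, at cfg['lr'] with the given momentum and weight decay.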
    start_epoch = 1
    if args.resume:
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        start_epoch = checkpoint['epoch']
        del checkpoint
    # ----------------------------------------train----------------------------------------
    # Epoch-0 checkpoint: the pretrained weights before any fine-tuning.
    save_ckpt(model, 0, optimizer, output_dir + '/CosFace_0_checkpoint.pth')
    for epoch in range(start_epoch, cfg['epochs'] + 1):
        train(train_loader, model, classifier, criterion, criterion2,
              optimizer, epoch, cfg['step_size'], cfg['lr'])
        save_ckpt(model, epoch, optimizer,
                  output_dir + '/CosFace_' + str(epoch) + '_checkpoint.pth')
        print('Validating on valid set...')
        valid(val_loader, model_eval,
              output_dir + '/CosFace_' + str(epoch) + '_checkpoint.pth',
              classifier, criterion, criterion2)
    print('Finished Training')
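
The save_ckpt, train and valid helpers called above are defined elsewhere in the script. A minimal sketch of what save_ckpt presumably does, assuming it stores exactly the keys the resume branch reads back ('epoch', 'model_state_dict', 'optim_state_dict'):

def save_ckpt(model, epoch, optimizer, save_path):
    # Unwrap DataParallel so the checkpoint can later be loaded into a bare
    # (non-parallel) model, as the resume and validation code above expects.
    model_state = model.module.state_dict() if isinstance(
        model, torch.nn.DataParallel) else model.state_dict()
    torch.save({
        'epoch': epoch,
        'model_state_dict': model_state,
        'optim_state_dict': optimizer.state_dict(),
    }, save_path)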

Example #2

def main():
    # --------------------------------------model----------------------------------------
    if args.network == 'sphere20':
        model = net.sphere(type=20, is_gray=args.is_gray)
        model_eval = net.sphere(type=20, is_gray=args.is_gray)
    elif args.network == 'sphere64':
        model = net.sphere(type=64, is_gray=args.is_gray)
        model_eval = net.sphere(type=64, is_gray=args.is_gray)
    elif args.network == 'LResNet50E_IR':
        model = net.LResNet50E_IR(is_gray=args.is_gray)
        model_eval = net.LResNet50E_IR(is_gray=args.is_gray)
    else:
        raise ValueError('Unsupported network: ' + args.network)

    model = torch.nn.DataParallel(model).to(device)
    model_eval = model_eval.to(device)
    print(model)
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    model.module.save(os.path.join(args.save_path, 'CosFace_0_checkpoint.pth'))

    # 512 is the feature embedding dimension
    classifier = {
        'MCP': layer.MarginCosineProduct(512, args.num_class).to(device),
        'AL': layer.AngleLinear(512, args.num_class).to(device),
        'L': torch.nn.Linear(512, args.num_class, bias=False).to(device)
    }[args.classifier_type]

    # ------------------------------------load image---------------------------------------
    if args.is_gray:
        train_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean=(0.5, ), std=(0.5, ))
        ])  # gray
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5,
                                      0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
        ])
    train_loader = torch.utils.data.DataLoader(ImageList(
        root=args.root_path,
        fileList=args.train_list,
        transform=train_transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    print('length of train Database: ' + str(len(train_loader.dataset)))
    print('Number of Identities: ' + str(args.num_class))

    # --------------------------------loss function and optimizer-----------------------------
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD([{
        'params': model.parameters()
    }, {
        'params': classifier.parameters()
    }],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
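    # Unlike Example #1, backbone and classifier are trained jointly here,
    # sharing a single learning rate, momentum and weight decay.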

    # ----------------------------------------train----------------------------------------
    # lfw_eval.eval(args.save_path + 'CosFace_0_checkpoint.pth')
    for epoch in range(1, args.epochs + 1):
        train(train_loader, model, classifier, criterion, optimizer, epoch)
        ckpt_path = os.path.join(args.save_path,
                                 'CosFace_' + str(epoch) + '_checkpoint.pth')
        model.module.save(ckpt_path)
        lfw_eval.eval(model_eval, ckpt_path, args.is_gray)
    print('Finished Training')
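
Example #2 likewise relies on module-level state that is not shown: the argparse args namespace, the device object, and the project-local net, layer, ImageList and lfw_eval modules. A hypothetical entry point, included only to show how main() would be invoked:

if __name__ == '__main__':
    # `args` is assumed to be produced by an argparse parser defined at module
    # level; `device` by torch.device('cuda' if torch.cuda.is_available() else 'cpu').
    main()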