def main():
    """Fine-tune the siamese LResNet50E_IR model on occluded-face data.

    Loads pretrained backbone + classifier weights, freezes the backbone
    layers, trains only the 'sia' parameters, and validates after each epoch.
    Relies on module-level globals: args, device, configurations, net_sia,
    layer, dset, utils, train, valid, save_ckpt, transforms, torch.
    """
    # --------------------------------------model----------------------------------------
    model = net_sia.LResNet50E_IR_Sia(is_gray=args.is_gray)
    model_eval = net_sia.LResNet50E_IR_Sia(is_gray=args.is_gray)
    # NOTE(review): set before any CUDA context is created — confirm no CUDA
    # call happens earlier in the program, otherwise this has no effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    # 512 is dimension of feature
    classifier = {
        'MCP': layer.MarginCosineProduct(512, args.num_class),
        'AL': layer.AngleLinear(512, args.num_class),
        'L': torch.nn.Linear(512, args.num_class, bias=False)
    }[args.classifier_type]
    classifier.load_state_dict(torch.load(args.weight_fc))
    print(os.environ['CUDA_VISIBLE_DEVICES'], args.cuda)

    # Copy pretrained backbone weights in-place; keys the new (siamese)
    # model does not have are silently skipped.
    pretrained = torch.load(args.weight_model)
    pretrained_dict = pretrained['model_state_dict']
    model_dict = model.state_dict()
    for k, v in pretrained_dict.items():
        if k in model_dict:
            model_dict[k].copy_(v)
    del pretrained
    del pretrained_dict

    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        # classifier ckpt only save model info
        classifier.load_state_dict(torch.load(args.resume_fc))
    print(model)
    model = torch.nn.DataParallel(model).to(device)
    model_eval = model_eval.to(device)
    classifier = classifier.to(device)

    args.run_name = utils.get_run_name()
    output_dir = os.path.join(args.save_path, args.run_name.split("_")[0])
    # exist_ok avoids the check-then-create race of exists()/makedirs().
    os.makedirs(output_dir, exist_ok=True)

    # ------------------------------------load image---------------------------------------
    if args.is_gray:
        train_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean=(0.5, ), std=(0.5, ))
        ])  # gray
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
        ])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    train_loader = torch.utils.data.DataLoader(
        dset.ImageList(root=args.root_path,
                       fileList=args.train_list,
                       transform=train_transform),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=False,
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        dset.ImageList(root=args.root_path,
                       fileList=args.valid_list,
                       transform=valid_transform),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False,
        drop_last=False)
    print('length of train Database: ' + str(len(train_loader.dataset)) +
          ' Batches: ' + str(len(train_loader)))
    print('length of valid Database: ' + str(len(val_loader.dataset)) +
          ' Batches: ' + str(len(val_loader)))
    print('Number of Identities: ' + str(args.num_class))

    #---------------------------------------params setting-----------------------------------
    # Freeze the pretrained backbone (conv/bn/prelu stem and residual layers);
    # everything else stays trainable.
    for name, param in model.named_parameters():
        if 'layer' in name or 'conv1' in name or 'bn1' in name or 'prelu1' in name:
            param.requires_grad = False
        else:
            param.requires_grad = True
    print("Params to learn:")
    # 'sia' params get the real learning rate; the rest are passed to the
    # optimizer with lr=0 so they are tracked but not updated.
    params_to_update = []
    params_to_stay = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if 'sia' in name:
                params_to_update.append(param)
                print("Update \t", name)
            else:
                params_to_stay.append(param)
                print("Stay \t", name)
    for name, param in classifier.named_parameters():
        param.requires_grad = True
        params_to_stay.append(param)
        print("Stay \t", name)

    #--------------------------------loss function and optimizer-----------------------------
    cfg = configurations[args.config]
    criterion = torch.nn.CrossEntropyLoss().to(device)
    criterion2 = torch.nn.L1Loss(reduction='mean').to(device)
    optimizer = torch.optim.SGD(
        [{
            'params': params_to_stay,
            'lr': 0,
            'weight_decay': 0,
            'momentum': 0
        }, {
            'params': params_to_update
        }],
        lr=cfg['lr'],
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    start_epoch = 1
    if args.resume:
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        start_epoch = checkpoint['epoch']
        del checkpoint

    # ----------------------------------------train----------------------------------------
    save_ckpt(model, 0, optimizer,
              output_dir + '/CosFace_0_checkpoint.pth')  # Not resumed, pretrained~
    for epoch in range(start_epoch, cfg['epochs'] + 1):
        train(train_loader, model, classifier, criterion, criterion2,
              optimizer, epoch, cfg['step_size'], cfg['lr'])
        save_ckpt(model, epoch, optimizer,
                  output_dir + '/CosFace_' + str(epoch) + '_checkpoint.pth')
        print('Validating on valid set...')
        valid(val_loader, model_eval,
              output_dir + '/CosFace_' + str(epoch) + '_checkpoint.pth',
              classifier, criterion, criterion2)
    print('Finished Training')
def main():
    """Train a CosFace model (sphere20/sphere64/LResNet50E_IR) from scratch.

    Builds the network selected by args.network, trains it with the chosen
    margin classifier, checkpoints each epoch, and evaluates on LFW.
    Relies on module-level globals: args, device, net, layer, ImageList,
    train, lfw_eval, transforms, torch.
    """
    # --------------------------------------model----------------------------------------
    # BUG FIX: the original used `args.network is 'sphere20'` — identity
    # comparison on strings is unreliable (and a SyntaxWarning in CPython 3.8+);
    # equality is the correct test.
    if args.network == 'sphere20':
        model = net.sphere(type=20, is_gray=args.is_gray)
        model_eval = net.sphere(type=20, is_gray=args.is_gray)
    elif args.network == 'sphere64':
        model = net.sphere(type=64, is_gray=args.is_gray)
        model_eval = net.sphere(type=64, is_gray=args.is_gray)
    elif args.network == 'LResNet50E_IR':
        model = net.LResNet50E_IR(is_gray=args.is_gray)
        model_eval = net.LResNet50E_IR(is_gray=args.is_gray)
    else:
        raise ValueError("NOT SUPPORT NETWORK! ")
    model = torch.nn.DataParallel(model).to(device)
    model_eval = model_eval.to(device)
    print(model)
    # exist_ok avoids the check-then-create race of exists()/makedirs().
    os.makedirs(args.save_path, exist_ok=True)
    model.module.save(args.save_path + 'CosFace_0_checkpoint.pth')

    # 512 is dimension of feature
    classifier = {
        'MCP': layer.MarginCosineProduct(512, args.num_class).to(device),
        'AL': layer.AngleLinear(512, args.num_class).to(device),
        'L': torch.nn.Linear(512, args.num_class, bias=False).to(device)
    }[args.classifier_type]

    # ------------------------------------load image---------------------------------------
    if args.is_gray:
        train_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean=(0.5, ), std=(0.5, ))
        ])  # gray
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
        ])
    train_loader = torch.utils.data.DataLoader(
        ImageList(root=args.root_path,
                  fileList=args.train_list,
                  transform=train_transform),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)
    print('length of train Database: ' + str(len(train_loader.dataset)))
    print('Number of Identities: ' + str(args.num_class))

    # --------------------------------loss function and optimizer-----------------------------
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(
        [{
            'params': model.parameters()
        }, {
            'params': classifier.parameters()
        }],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # ----------------------------------------train----------------------------------------
    # lfw_eval.eval(args.save_path + 'CosFace_0_checkpoint.pth')
    for epoch in range(1, args.epochs + 1):
        train(train_loader, model, classifier, criterion, optimizer, epoch)
        model.module.save(args.save_path + 'CosFace_' + str(epoch) +
                          '_checkpoint.pth')
        lfw_eval.eval(
            model_eval,
            args.save_path + 'CosFace_' + str(epoch) + '_checkpoint.pth',
            args.is_gray)
    print('Finished Training')