def main(args):
    """Train the model and its invariance module, or only evaluate the model.

    When ``args.evaluate`` is set, the query/gallery evaluation runs once and
    the function returns. Otherwise training runs from ``start_epoch`` (0, or
    the epoch stored in the ``args.resume`` checkpoint) up to ``args.epochs``,
    a checkpoint is written to ``args.logs_dir`` after every epoch, and a
    final evaluation is performed once the loop ends.
    """
    # For fast training.
    cudnn.benchmark = True
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Redirect print to both console and log file (training runs only).
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
        print('log_dir=', args.logs_dir)

    # Record the full configuration.
    print(args)

    # Data: one dataset object, the class count and four loaders.
    data_bundle = get_data(args.data_dir, args.source, args.target,
                           args.height, args.width, args.batch_size,
                           args.re, args.workers)
    (dataset, num_classes, source_train_loader, target_train_loader,
     query_loader, gallery_loader) = data_bundle

    # Main network.
    model = models.create(args.arch,
                          num_features=args.features,
                          dropout=args.dropout,
                          num_classes=num_classes)

    # Invariance learning model, sized by the number of target training items.
    num_tgt = len(dataset.target_train)
    model_inv = InvNet(args.features, num_tgt,
                       beta=args.inv_beta, knn=args.knn, alpha=args.inv_alpha)

    # Optionally restore both networks and the epoch counter.
    start_epoch = 0
    if args.resume:
        ckpt = load_checkpoint(args.resume)
        model.load_state_dict(ckpt['state_dict'])
        model_inv.load_state_dict(ckpt['state_dict_inv'])
        start_epoch = ckpt['epoch']
        print("=> Start epoch {} ".format(start_epoch))

    # Weights are loaded before wrapping, so checkpoints hold the bare module.
    model = nn.DataParallel(model).to(device)
    model_inv = model_inv.to(device)

    # Evaluation-only mode stops here.
    evaluator = Evaluator(model)
    if args.evaluate:
        print("Test:")
        evaluator.evaluate(query_loader, gallery_loader,
                           dataset.query, dataset.gallery,
                           args.output_feature)
        return

    # Optimizer: parameters under ``base`` that require grad get a 0.1 LR
    # multiplier; every parameter outside ``base`` gets 1.0.
    backbone = model.module.base
    backbone_ids = {id(p) for p in backbone.parameters()}
    trainable_backbone = [p for p in backbone.parameters() if p.requires_grad]
    head_params = [p for p in model.parameters()
                   if id(p) not in backbone_ids]
    optimizer = torch.optim.SGD(
        [
            {'params': trainable_backbone, 'lr_mult': 0.1},
            {'params': head_params, 'lr_mult': 1.0},
        ],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=True,
    )

    # Trainer
    trainer = Trainer(model, model_inv,
                      lmd=args.lmd, include_mmd=args.include_mmd)

    def adjust_lr(epoch):
        """Apply a x0.1 step decay every ``args.epochs_decay`` epochs."""
        completed_steps = epoch // args.epochs_decay
        stepped_lr = args.lr * (0.1 ** completed_steps)
        for group in optimizer.param_groups:
            group['lr'] = stepped_lr * group.get('lr_mult', 1)

    # Training loop: the same checkpoint file is overwritten every epoch.
    checkpoint_path = osp.join(args.logs_dir, 'checkpoint.pth.tar')
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, source_train_loader, target_train_loader,
                      optimizer)

        snapshot = {
            'state_dict': model.module.state_dict(),
            'state_dict_inv': model_inv.state_dict(),
            'epoch': epoch + 1,
        }
        save_checkpoint(snapshot, fpath=checkpoint_path)

        print('\n * Finished epoch {:3d} \n'.format(epoch))

    # Final test
    print('Test with best model:')
    evaluator = Evaluator(model)
    evaluator.evaluate(query_loader, gallery_loader,
                       dataset.query, dataset.gallery,
                       args.output_feature)
def main(args):
    """Train the Encoder/Transfer/CamDis networks, or only evaluate them.

    With ``args.evaluate`` set, the viper/prid/grid/ilids evaluations run once
    and the function returns without training. Otherwise the three networks
    are trained from ``start_epoch`` (0, or the epoch stored in the
    ``args.resume`` checkpoint) up to ``args.epochs``; a checkpoint is written
    to ``args.logs_dir`` after every epoch and the same four evaluations are
    run once training ends.

    Args:
        args: Parsed options. Attributes read here: ``seed``, ``evaluate``,
            ``logs_dir``, ``data_dir``, ``height``, ``width``, ``batch_size``,
            ``num_instance``, ``re``, ``workers``, ``arch``, ``features``,
            ``dropout``, ``beta``, ``knn``, ``alpha``, ``resume``,
            ``output_feature``, ``margin``, ``lr``, ``epochs``,
            ``tri_weight``, ``adv_weight`` and ``mem_weight``.
    """
    fix(args.seed)

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    print(args)

    # Create data loaders
    (dataset, test_dataset, num_classes, source_train_loader,
     grid_query_loader, grid_gallery_loader,
     prid_query_loader, prid_gallery_loader,
     viper_query_loader, viper_gallery_loader,
     ilid_query_loader, ilid_gallery_loader) = get_data(
        args.data_dir, args.height, args.width, args.batch_size,
        args.num_instance, args.re, args.workers)

    # Create model
    Encoder, Transfer, CamDis = models.create(
        args.arch, num_features=args.features, dropout=args.dropout,
        num_classes=num_classes)

    invNet = InvNet(args.features, num_classes, args.batch_size,
                    beta=args.beta, knn=args.knn, alpha=args.alpha).cuda()

    # Load from checkpoint
    start_epoch = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        Encoder.load_state_dict(checkpoint['Encoder'])
        Transfer.load_state_dict(checkpoint['Transfer'])
        CamDis.load_state_dict(checkpoint['CamDis'])
        invNet.load_state_dict(checkpoint['InvNet'])
        start_epoch = checkpoint['epoch']

    Encoder = Encoder.cuda()
    Transfer = Transfer.cuda()
    CamDis = CamDis.cuda()

    model = [Encoder, Transfer, CamDis]

    def run_benchmarks(evaluator):
        """Run the viper, prid, grid and ilids evaluations with fixed seeds."""
        evaluator.eval_viper(viper_query_loader, viper_gallery_loader,
                             test_dataset.viper_query,
                             test_dataset.viper_gallery,
                             args.output_feature, seed=97)
        evaluator.eval_prid(prid_query_loader, prid_gallery_loader,
                            test_dataset.prid_query,
                            test_dataset.prid_gallery,
                            args.output_feature, seed=40)
        evaluator.eval_grid(grid_query_loader, grid_gallery_loader,
                            test_dataset.grid_query,
                            test_dataset.grid_gallery,
                            args.output_feature, seed=28)
        # The ilids evaluation takes only the query loader and query set.
        evaluator.eval_ilids(ilid_query_loader, test_dataset.ilid_query,
                             args.output_feature, seed=24)

    # Evaluator
    evaluator = Evaluator(model)
    if args.evaluate:
        run_benchmarks(evaluator)
        # Evaluation-only run: stop here rather than falling through into
        # training, which would overwrite checkpoint.pth.tar in logs_dir.
        return

    criterion = [
        nn.CrossEntropyLoss().cuda(),
        TripletLoss(margin=args.margin),
    ]

    def make_sgd(param_groups):
        """Build the SGD optimizer shared by all three networks."""
        return torch.optim.SGD(param_groups, lr=args.lr, momentum=0.9,
                               weight_decay=5e-4, nesterov=True)

    def base_and_new_groups(net):
        """Split ``net`` into ``net.base`` params (x0.1 LR) and the rest (x1)."""
        base_param_ids = {id(p) for p in net.base.parameters()}
        new_params = [p for p in net.parameters()
                      if id(p) not in base_param_ids]
        return [
            {'params': net.base.parameters(), 'lr_mult': 0.1},
            {'params': new_params, 'lr_mult': 1.0},
        ]

    # Optimizer
    optimizer_Encoder = make_sgd(base_and_new_groups(Encoder))
    optimizer_Transfer = make_sgd(base_and_new_groups(Transfer))
    optimizer_Cam = make_sgd([{'params': CamDis.parameters(), 'lr_mult': 1.0}])
    optimizer = [optimizer_Encoder, optimizer_Transfer, optimizer_Cam]

    # Trainer
    trainer = Trainer(model, criterion, InvNet=invNet)

    # Schedule learning rate
    def adjust_lr(epoch):
        """Apply a x0.1 step decay every 40 epochs to all three optimizers."""
        step_size = 40
        lr = args.lr * (0.1 ** (epoch // step_size))
        for opt in optimizer:
            for group in opt.param_groups:
                group['lr'] = lr * group.get('lr_mult', 1)

    # Start training
    checkpoint_path = osp.join(args.logs_dir, 'checkpoint.pth.tar')
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, source_train_loader, optimizer,
                      args.tri_weight, args.adv_weight, args.mem_weight)

        # The same file is overwritten after every epoch.
        save_checkpoint({
            'Encoder': Encoder.state_dict(),
            'Transfer': Transfer.state_dict(),
            'CamDis': CamDis.state_dict(),
            'InvNet': invNet.state_dict(),
            'epoch': epoch + 1,
        }, fpath=checkpoint_path)

        print('\n * Finished epoch {:3d} \n'.format(epoch))

    # Final test. This function only has the benchmark loaders returned by
    # get_data (no generic query/gallery loader), so the final evaluation
    # runs the same four benchmarks as the evaluation-only mode.
    print('Test with best model:')
    run_benchmarks(Evaluator(model))