def __init__(self, datamanager, model, optimizer, weight_t=1, weight_x=1,
             scheduler=None, use_cpu=False, label_smooth=True):
    """Initialize the center-loss engine: a center loss on the feature
    embeddings plus a (label-smoothed) cross-entropy loss on the logits.

    Args:
        datamanager: data manager exposing ``num_train_pids``.
        model: network to train.
        optimizer: optimizer over model (and loss) parameters.
        weight_t (float): weight on the center-loss term.
        weight_x (float): weight on the cross-entropy term.
        scheduler: optional learning-rate scheduler.
        use_cpu (bool): force CPU even when a GPU is available.
        label_smooth (bool): enable label smoothing in the softmax loss.
    """
    # BUGFIX: the base engine's fifth positional argument is ``use_gpu``
    # (see the sibling triplet engine, which passes ``use_gpu`` in the
    # same slot), so ``use_cpu`` must be inverted before forwarding —
    # passing it straight through flipped the device selection.
    super(ImageCenterEngine, self).__init__(
        datamanager, model, optimizer, scheduler, not use_cpu)
    self.weight_t = weight_t
    self.weight_x = weight_x
    # NOTE(review): feat_dim=2048 assumes a ResNet-50-style backbone
    # output dimension — confirm against the model used.
    self.criterion_t = CenterLoss(
        num_classes=self.datamanager.num_train_pids,
        feat_dim=2048,
        use_gpu=self.use_gpu,
    )
    self.criterion_x = CrossEntropyLoss(
        num_classes=self.datamanager.num_train_pids,
        use_gpu=self.use_gpu,
        label_smooth=label_smooth,
    )
def __init__(self, datamanager, model, optimizer, margin=0.3, weight_t=1,
             weight_x=1, scheduler=None, use_gpu=True, label_smooth=True):
    """Initialize the triplet engine: hard-mining triplet loss plus
    (label-smoothed) cross-entropy, with two auxiliary center losses.

    Args:
        datamanager: data manager exposing ``num_train_pids``.
        model: network to train.
        optimizer: optimizer over model (and loss) parameters.
        margin (float): margin for the triplet loss.
        weight_t (float): weight on the triplet-loss term.
        weight_x (float): weight on the cross-entropy term.
        scheduler: optional learning-rate scheduler.
        use_gpu (bool): run the losses on GPU.
        label_smooth (bool): enable label smoothing in the softmax loss.
    """
    super(ImageTripletEngine, self).__init__(
        datamanager, model, optimizer, scheduler, use_gpu)
    self.weight_t = weight_t
    self.weight_x = weight_x
    self.criterion_t = TripletLoss(margin=margin)
    self.criterion_x = CrossEntropyLoss(
        num_classes=self.datamanager.num_train_pids,
        use_gpu=self.use_gpu,
        label_smooth=label_smooth,
    )
    # BUGFIX/generalization: the auxiliary center losses were hard-coded
    # to 751 classes (Market-1501's identity count), silently breaking
    # any other dataset; use the actual number of training identities.
    # NOTE(review): feat_dims 2048/512 presumably match two feature
    # stages of the backbone — confirm against the model.
    num_pids = self.datamanager.num_train_pids
    self.criterion_c1 = CenterLoss(num_classes=num_pids, feat_dim=2048)
    self.criterion_c2 = CenterLoss(num_classes=num_pids, feat_dim=512)
def main():
    """Script entry point: build data/model/losses, optionally load or
    resume weights, then either evaluate or train with periodic testing
    and checkpointing. Reads the module-level ``args`` namespace.
    """
    global args

    # Reproducibility and device selection.
    torch.manual_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False
    # Tee stdout to a log file (separate names for test vs. train runs).
    log_name = 'log_test.txt' if args.evaluate else 'log_train.txt'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU, however, GPU is highly recommended")

    print("Initializing image data manager")
    dm = ImageDataManager(use_gpu, **image_dataset_kwargs(args))
    trainloader, testloader_dict = dm.return_dataloaders()

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch, num_classes=dm.num_train_pids,
                              loss={'xent'}, use_gpu=use_gpu)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    # Main classification loss plus four auxiliary center losses.
    # NOTE(review): num_classes=2 and 1452 and feat_dim=512 are
    # unexplained magic numbers — presumably tied to specific attribute/
    # identity heads of this model; confirm against train()/the model.
    criterion = CrossEntropyLoss(num_classes=dm.num_train_pids,
                                 use_gpu=use_gpu,
                                 label_smooth=args.label_smooth)
    center_loss1 = CenterLoss(num_classes=2, feat_dim=512, use_gpu=use_gpu)
    center_loss2 = CenterLoss(num_classes=2, feat_dim=512, use_gpu=use_gpu)
    center_loss3 = CenterLoss(num_classes=1452, feat_dim=512, use_gpu=use_gpu)
    center_loss4 = CenterLoss(num_classes=1452, feat_dim=512, use_gpu=use_gpu)
    # Center-loss parameters (the class centers) are learned jointly with
    # the model, so they all go into one optimizer.
    params = list(center_loss1.parameters()) + list(model.parameters()) + list(
        center_loss2.parameters()) + list(center_loss3.parameters()) + list(
            center_loss4.parameters())
    optimizer = init_optimizer(params, **optimizer_kwargs(args))
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize,
                                         gamma=args.gamma)

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        if (use_gpu):
            checkpoint = torch.load(args.load_weights)
        else:
            checkpoint = torch.load(args.load_weights, map_location='cpu')
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    if args.resume and check_isfile(args.resume):
        # Full resume: restore weights and continue from the next epoch.
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch'] + 1
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch,
                                                      checkpoint['rank1']))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        # Evaluation-only mode: test on every target set, then exit.
        print("Evaluate only")
        for name in args.target_names:
            print("Evaluating {} ...".format(name))
            queryloader = testloader_dict[name]['query']
            galleryloader = testloader_dict[name]['gallery']
            distmat = test(model, queryloader, galleryloader, use_gpu,
                           return_distmat=True)
            if args.visualize_ranks:
                visualize_ranked_results(
                    distmat,
                    dm.return_testdataset_by_name(name),
                    save_dir=osp.join(args.save_dir, 'ranked_results', name),
                    topk=20)
        return

    start_time = time.time()
    ranklogger = RankLogger(args.source_names, args.target_names)
    train_time = 0
    print("=> Start training")

    if args.fixbase_epoch > 0:
        # Warm-up phase: train only ``args.open_layers`` with the base
        # frozen, then restore the optimizer state for full training.
        print("Train {} for {} epochs while keeping other layers frozen".format(
            args.open_layers, args.fixbase_epoch))
        initial_optim_state = optimizer.state_dict()
        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch, model, criterion, center_loss1, center_loss2,
                  center_loss3, center_loss4, optimizer, trainloader, use_gpu,
                  fixbase=True)
            train_time += round(time.time() - start_train_time)
        print("Done. All layers are open to train for {} epochs".format(
            args.max_epoch))
        optimizer.load_state_dict(initial_optim_state)

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, center_loss1, center_loss2,
              center_loss3, center_loss4, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)
        scheduler.step()
        # Evaluate (and checkpoint) every ``eval_freq`` epochs once past
        # ``start_eval``, and always on the final epoch.
        if (epoch + 1) > args.start_eval and args.eval_freq > 0 and (
                epoch + 1) % args.eval_freq == 0 or (epoch + 1) == args.max_epoch:
            print("=> Test")
            for name in args.target_names:
                print("Evaluating {} ...".format(name))
                queryloader = testloader_dict[name]['query']
                galleryloader = testloader_dict[name]['gallery']
                rank1 = test(model, queryloader, galleryloader, use_gpu)
                ranklogger.write(name, epoch + 1, rank1)
            # Unwrap DataParallel so the checkpoint loads without the
            # ``module.`` prefix.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            # NOTE(review): 'rank1' here is the result of the LAST target
            # dataset in the loop above — confirm that is intended.
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, False,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
    ranklogger.show_summary()
class ImageCenterEngine(engine.Engine):
    """Center-loss engine for image re-id.

    Jointly optimizes a center loss on the 2048-d feature embeddings and
    a (label-smoothed) cross-entropy loss on the classifier outputs.
    (The previous docstring incorrectly called this a triplet-loss engine.)
    """

    def __init__(self, datamanager, model, optimizer, weight_t=1, weight_x=1,
                 scheduler=None, use_cpu=False, label_smooth=True):
        """See class docstring.

        Args:
            datamanager: data manager exposing ``num_train_pids``.
            model: network to train.
            optimizer: optimizer over model (and loss) parameters.
            weight_t (float): weight on the center-loss term.
            weight_x (float): weight on the cross-entropy term.
            scheduler: optional learning-rate scheduler.
            use_cpu (bool): force CPU even when a GPU is available.
            label_smooth (bool): enable label smoothing in the softmax loss.
        """
        # BUGFIX: the base engine's fifth positional argument is
        # ``use_gpu`` (the triplet engine passes ``use_gpu`` in the same
        # slot), so ``use_cpu`` must be inverted before forwarding.
        super(ImageCenterEngine, self).__init__(
            datamanager, model, optimizer, scheduler, not use_cpu)
        self.weight_t = weight_t
        self.weight_x = weight_x
        # NOTE(review): feat_dim=2048 assumes a ResNet-50-style backbone
        # output dimension — confirm against the model used.
        self.criterion_t = CenterLoss(
            num_classes=self.datamanager.num_train_pids,
            feat_dim=2048,
            use_gpu=self.use_gpu,
        )
        self.criterion_x = CrossEntropyLoss(
            num_classes=self.datamanager.num_train_pids,
            use_gpu=self.use_gpu,
            label_smooth=label_smooth,
        )

    def train(self, epoch, trainloader, fixbase=False, open_layers=None,
              print_freq=10):
        """Trains the model for one epoch on source datasets using the
        weighted center + softmax loss.

        Args:
            epoch (int): current epoch.
            trainloader (Dataloader): training dataloader.
            fixbase (bool, optional): whether to fix base layers. Default
                is False.
            open_layers (str or list, optional): layers open for training.
            print_freq (int, optional): print frequency. Default is 10.
        """
        losses_t = AverageMeter()
        losses_x = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        # During the warm-up phase only the specified layers are trained.
        if fixbase and (open_layers is not None):
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        end = time.time()
        for batch_idx, data in enumerate(trainloader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            self.optimizer.zero_grad()
            outputs, features = self.model(imgs)
            loss_t = self._compute_loss(self.criterion_t, features, pids)
            loss_x = self._compute_loss(self.criterion_x, outputs, pids)
            loss = self.weight_t * loss_t + self.weight_x * loss_x
            loss.backward()
            # Rescale the center parameters' gradients to cancel the loss
            # weighting, so the centers move at their own effective rate.
            # NOTE(review): the 0.5 factor is presumably the intended
            # center learning rate — confirm.
            for param in self.criterion_t.parameters():
                param.grad.data *= (0.5 / self.weight_t)
            self.optimizer.step()

            batch_time.update(time.time() - end)
            losses_t.update(loss_t.item(), pids.size(0))
            losses_x.update(loss_x.item(), pids.size(0))
            accs.update(metrics.accuracy(outputs, pids)[0].item())

            if (batch_idx + 1) % print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Center {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
                      'Softmax {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
                      'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                          epoch + 1,
                          batch_idx + 1,
                          len(trainloader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss_t=losses_t,
                          loss_x=losses_x,
                          acc=accs))

            end = time.time()

        # Step the LR schedule once per epoch (skipped during warm-up).
        if (self.scheduler is not None) and (not fixbase):
            self.scheduler.step()