def get_transform(mode, args):
    """Build the augmentation pipeline for one dataset split.

    Args:
        mode: One of ``'train'``, ``'val'`` or ``'test'``.
        args: Namespace; only ``args.img_dim`` is read and forwarded to
            ``A.Scale``.

    Returns:
        A ``transforms.Compose`` of: random-sized crop (224), scale,
        colour jitter (train only) and tensor conversion.

    Raises:
        ValueError: If ``mode`` is not one of the supported splits.
    """
    # Validate first: previously an unknown mode fell through both branches
    # and crashed with UnboundLocalError at the return statement.
    if mode not in ('train', 'val', 'test'):
        raise ValueError(
            "unknown mode %r; expected 'train', 'val' or 'test'" % (mode,))

    # The crop and scale steps are shared by every split (the val/test
    # pipeline also uses the random-sized crop, exactly as before).
    steps = [
        A.RandomSizedCrop(size=224, consistent=True, bottom_area=0.2),
        A.Scale(args.img_dim),
    ]
    if mode == 'train':
        # Colour jitter is the only train-specific augmentation.
        steps.append(
            A.ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.3, consistent=True))
    steps.append(A.ToTensor())
    return transforms.Compose(steps)
def _load_or_compute_retrieval_features(split, dataset, loader, model, tr,
                                        device, feature_dir, args):
    """Return ``(feature, label)`` tensors for one split, using the disk cache.

    If ``<feature_dir>/<dataset>_<split>_feature.pth.tar`` exists, the feature
    and label tensors are loaded from disk. Otherwise every sample of
    ``loader`` is passed through ``model``, one feature row per sample is
    stored, and features, labels and video names are written to
    ``feature_dir``.
    """
    prefix = '%s_%s' % (args.dataset, split)
    feature_path = os.path.join(feature_dir, prefix + '_feature.pth.tar')
    label_path = os.path.join(feature_dir, prefix + '_label.pth.tar')
    vname_path = os.path.join(feature_dir, prefix + '_vname.pkl')

    # Only the feature file is probed; the label file is expected alongside.
    if os.path.exists(feature_path):
        return (torch.load(feature_path).to(device),
                torch.load(label_path).to(device))

    os.makedirs(feature_dir, exist_ok=True)
    print('Computing %s set feature ... ' % split)
    feature_bank = None
    labels = []
    vnames = []
    for sample_id, (input_seq, target) in tqdm(enumerate(loader),
                                               total=len(loader)):
        input_seq = tr(input_seq.to(device, non_blocking=True))
        current_target, vname = target
        current_target = current_target.to(device, non_blocking=True)

        if args.other is not None:
            input_seq = input_seq.squeeze(1)
        _, feature = model(input_seq)

        # Allocate lazily: the feature width is only known after the first
        # forward pass.
        if feature_bank is None:
            feature_bank = torch.zeros(len(dataset), feature.size(-1),
                                       device=feature.device)

        # One row per video: average over dim 0 of the model output. Without
        # ``args.other`` only the last index of dim 1 is kept first.
        if args.other is not None:
            feature_bank[sample_id, :] = feature.mean(0)
        else:
            feature_bank[sample_id, :] = feature[:, -1, :].mean(0)
        labels.append(current_target)
        vnames.append(vname)

    print(feature_bank.size())
    labels = torch.cat(labels, dim=0)

    torch.save(feature_bank, feature_path)
    torch.save(labels, label_path)
    with open(vname_path, 'wb') as fp:
        pickle.dump(vnames, fp)
    return feature_bank, labels


def test_retrieval(model, criterion, transforms_cuda, device, epoch, args):
    """Evaluate nearest-neighbour video retrieval (test queries vs. train set).

    Features for split 1 of the train and test sets are loaded from, or
    computed and cached in, ``<dirname(args.test)>/<args.dirname or 'feature'>``.
    Both feature sets are mean-centred and L2-normalised, the test/train
    similarity matrix is saved, and top-k retrieval accuracy is printed and
    logged for k in (1, 5, 10, 20, 50). The process then exits via
    ``sys.exit(0)``.

    Args:
        model: Network returning ``(logit, feature)``; put into eval mode.
        criterion: Unused; kept for interface compatibility.
        transforms_cuda: Callable applied to the raw input batch.
        device: Device that inputs and cached tensors are moved to.
        epoch: Unused; kept for interface compatibility.
        args: Namespace; reads ``seq_len``, ``num_seq``, ``img_dim``,
            ``dataset``, ``ds``, ``workers``, ``dirname``, ``test``, ``other``
            and ``logger``.

    Raises:
        ValueError: If ``args.dataset`` is not a supported retrieval dataset.
    """
    model.eval()

    def tr(x):
        # The loader is created with batch_size=1 below.
        assert x.size(0) == 1
        test_sample = x.size(2) // (args.seq_len * args.num_seq)
        return transforms_cuda(x)\
            .view(3, test_sample, args.num_seq, args.seq_len,
                  args.img_dim, args.img_dim)\
            .permute(1, 2, 0, 3, 4, 5)

    with torch.no_grad():
        transform = transforms.Compose([
            A.CenterCrop(size=(224, 224)),
            A.Scale(size=(args.img_dim, args.img_dim)),
            A.ColorJitter(0.2, 0.2, 0.2, 0.1, p=0.3, consistent=True),
            A.ToTensor()
        ])

        if args.dataset == 'ucf101':
            d_class = UCF101LMDB
        elif args.dataset == 'ucf101-f':
            d_class = UCF101Flow_LMDB
        elif args.dataset == 'hmdb51':
            d_class = HMDB51LMDB
        elif args.dataset == 'hmdb51-f':
            d_class = HMDB51Flow_LMDB
        else:
            # Previously an unknown dataset surfaced as an UnboundLocalError.
            raise ValueError(
                'dataset %r not supported for retrieval' % (args.dataset,))

        train_dataset = d_class(mode='train',
                                transform=transform,
                                num_frames=args.num_seq * args.seq_len,
                                ds=args.ds,
                                which_split=1,
                                return_label=True,
                                return_path=True)
        print('train dataset size: %d' % len(train_dataset))

        test_dataset = d_class(mode='test',
                               transform=transform,
                               num_frames=args.num_seq * args.seq_len,
                               ds=args.ds,
                               which_split=1,
                               return_label=True,
                               return_path=True)
        print('test dataset size: %d' % len(test_dataset))

        # Sequential order keeps feature rows aligned with dataset indices.
        train_sampler = data.SequentialSampler(train_dataset)
        test_sampler = data.SequentialSampler(test_dataset)

        train_loader = data.DataLoader(train_dataset,
                                       batch_size=1,
                                       sampler=train_sampler,
                                       shuffle=False,
                                       num_workers=args.workers,
                                       pin_memory=True)
        test_loader = data.DataLoader(test_dataset,
                                      batch_size=1,
                                      sampler=test_sampler,
                                      shuffle=False,
                                      num_workers=args.workers,
                                      pin_memory=True)

        dirname = 'feature' if args.dirname is None else args.dirname
        feature_dir = os.path.join(os.path.dirname(args.test), dirname)

        test_feature, test_label = _load_or_compute_retrieval_features(
            'test', test_dataset, test_loader, model, tr, device,
            feature_dir, args)
        train_feature, train_label = _load_or_compute_retrieval_features(
            'train', train_dataset, train_loader, model, tr, device,
            feature_dir, args)

        ks = [1, 5, 10, 20, 50]
        NN_acc = []

        # centering
        test_feature = test_feature - test_feature.mean(dim=0, keepdim=True)
        train_feature = train_feature - train_feature.mean(dim=0,
                                                           keepdim=True)

        # normalize
        test_feature = F.normalize(test_feature, p=2, dim=1)
        train_feature = F.normalize(train_feature, p=2, dim=1)

        # dot product of unit vectors: rows are test queries, columns are
        # train candidates
        sim = test_feature.matmul(train_feature.t())

        torch.save(
            sim, os.path.join(feature_dir, '%s_sim.pth.tar' % args.dataset))

        for k in ks:
            _, topkidx = torch.topk(sim, k, dim=1)
            # A query counts as correct if any of its k nearest training
            # videos carries the query's label.
            acc = torch.any(train_label[topkidx] == test_label.unsqueeze(1),
                            dim=1).float().mean().item()
            NN_acc.append(acc)
            print('%dNN acc = %.4f' % (k, acc))

        args.logger.log('NN-Retrieval on %s:' % args.dataset)
        for k, acc in zip(ks, NN_acc):
            args.logger.log('\t%dNN acc = %.4f' % (k, acc))

    sys.exit(0)
def test_10crop(dataset, model, criterion, transforms_cuda, device, epoch, args):
    """Evaluate classification with center / five / ten-crop test augmentation.

    For every (flip, crop) combination the dataset is re-read with the
    matching transform, class probabilities are averaged over dim 0 of the
    model output and accumulated per video name in ``prob_dict``. Results are
    summarised with ``summarize_probability`` and logged; in ten-crop mode the
    intermediate center-crop and five-crop results are logged as well. The
    process then exits via ``sys.exit(0)``.

    Args:
        dataset: Dataset whose ``transform``, ``return_path`` and
            ``return_label`` attributes are overwritten here.
        model: Network returning ``(logit, _)``; put into eval mode.
        criterion: Unused; kept for interface compatibility.
        transforms_cuda: Callable applied to the raw input batch.
        device: Device that inputs are moved to.
        epoch: Epoch number, used only in log messages.
        args: Namespace; reads ``center_crop``, ``five_crop``, ``ten_crop``,
            ``seq_len``, ``num_seq``, ``img_dim``, ``workers`` and ``logger``.

    Raises:
        ValueError: If none of the crop flags is set.
    """
    prob_dict = {}
    model.eval()

    # aug_list: 1,2,3,4,5 = topleft, topright, bottomleft, bottomright, center
    # flip_list: 0,1 = raw, flip
    aug_list = flip_list = title = None
    # The flags are checked in sequence, so a later flag overrides an earlier
    # one when several are set.
    if args.center_crop:
        print('Test using center crop')
        args.logger.log('Test using center_crop\n')
        aug_list = [5]
        flip_list = [0]
        title = 'center'
    if args.five_crop:
        print('Test using 5 crop')
        args.logger.log('Test using 5_crop\n')
        aug_list = [5, 1, 2, 3, 4]
        flip_list = [0]
        title = 'five'
    if args.ten_crop:
        print('Test using 10 crop')
        args.logger.log('Test using 10_crop\n')
        aug_list = [5, 1, 2, 3, 4]
        flip_list = [0, 1]
        title = 'ten'
    if title is None:
        # Previously this surfaced later as an UnboundLocalError.
        raise ValueError(
            'one of args.center_crop, args.five_crop or args.ten_crop '
            'must be set')

    def tr(x):
        # The loader is created with batch_size=1 below.
        assert x.size(0) == 1
        num_test_sample = x.size(2) // (args.seq_len * args.num_seq)
        return transforms_cuda(x)\
            .view(3, num_test_sample, args.num_seq, args.seq_len,
                  args.img_dim, args.img_dim)\
            .permute(1, 2, 0, 3, 4, 5)

    with torch.no_grad():
        # for loop through 10 types of augmentations, then average the probability
        for flip_idx in flip_list:
            for aug_idx in aug_list:
                print('Aug type: %d; flip: %d' % (aug_idx, flip_idx))
                # The two flip settings differ only in the flip command.
                flip_command = 'left' if flip_idx == 0 else 'right'
                transform = transforms.Compose([
                    A.RandomHorizontalFlip(command=flip_command),
                    A.FiveCrop(size=(224, 224), where=aug_idx),
                    A.Scale(size=(args.img_dim, args.img_dim)),
                    A.ColorJitter(0.2, 0.2, 0.2, 0.1, p=0.3, consistent=True),
                    A.ToTensor()
                ])

                dataset.transform = transform
                dataset.return_path = True
                dataset.return_label = True
                test_sampler = data.SequentialSampler(dataset)
                data_loader = data.DataLoader(dataset,
                                              batch_size=1,
                                              sampler=test_sampler,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

                for _, (input_seq, target) in tqdm(enumerate(data_loader),
                                                   total=len(data_loader)):
                    input_seq = tr(input_seq.to(device, non_blocking=True))
                    logit, _ = model(input_seq)

                    # average probability along the temporal window
                    prob = F.softmax(logit, dim=-1)
                    prob_mean = prob.mean(0, keepdim=True)
                    # NOTE(review): ``prob_last`` was referenced but never
                    # defined. It is taken here as the probabilities of the
                    # last window (same shape as prob_mean) -- confirm this
                    # matches what consumers of 'last_prob' expect.
                    prob_last = prob[-1:]

                    # With return_label and return_path enabled the loader
                    # target is unpacked as (label, video name).
                    _, vname = target
                    vname = vname[0]
                    if vname not in prob_dict:
                        prob_dict[vname] = {'mean_prob': [], 'last_prob': []}
                    prob_dict[vname]['mean_prob'].append(prob_mean)
                    prob_dict[vname]['last_prob'].append(prob_last)

                # In ten-crop mode the first pass (raw, center) doubles as
                # the center-crop result.
                if (title == 'ten') and (flip_idx == 0) and (aug_idx == 5):
                    print('center-crop result:')
                    acc_1 = summarize_probability(
                        prob_dict, data_loader.dataset.encode_action,
                        'center')
                    args.logger.log('center-crop:')
                    args.logger.log(
                        'test Epoch: [{0}]\t'
                        'Mean: Acc@1: {acc[0].avg:.4f} Acc@5: {acc[1].avg:.4f}'
                        .format(epoch, acc=acc_1))

            # After all un-flipped crops, the accumulated probabilities are
            # the five-crop result.
            if (title == 'ten') and (flip_idx == 0):
                print('five-crop result:')
                acc_5 = summarize_probability(
                    prob_dict, data_loader.dataset.encode_action, 'five')
                args.logger.log('five-crop:')
                args.logger.log(
                    'test Epoch: [{0}]\t'
                    'Mean: Acc@1: {acc[0].avg:.4f} Acc@5: {acc[1].avg:.4f}'.
                    format(epoch, acc=acc_5))

        print('%s-crop result:' % title)
        acc_final = summarize_probability(prob_dict,
                                          data_loader.dataset.encode_action,
                                          'ten')
        args.logger.log('%s-crop:' % title)
        args.logger.log(
            'test Epoch: [{0}]\t'
            'Mean: Acc@1: {acc[0].avg:.4f} Acc@5: {acc[1].avg:.4f}'.format(
                epoch, acc=acc_final))
    sys.exit(0)
def main(args):
    """Train the MemDPC model, optionally resuming from a checkpoint.

    Seeds the RNGs, builds the model / optimizer / LR schedule, creates the
    train and validation loaders, then runs ``train_one_epoch`` and
    ``validate`` for every epoch in ``[args.start_epoch, args.epochs)``,
    saving a checkpoint after each one. The process ends with
    ``sys.exit(0)``.

    Args:
        args: Parsed command-line namespace. It is also mutated here:
            ``batch_size`` is multiplied by the GPU count, and ``iteration``,
            ``img_path``, ``model_path``, ``logger``, ``writer_train`` and
            ``writer_val`` are set (``start_epoch`` too when resuming).

    Raises:
        NotImplementedError: If ``args.model`` is not ``'memdpc'``.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = torch.device('cuda')
    # args.batch_size is per GPU; scale it to the DataParallel batch size.
    num_gpu = len(str(args.gpu).split(','))
    args.batch_size = num_gpu * args.batch_size

    ### model ###
    if args.model == 'memdpc':
        model = MemDPC_BD(sample_size=args.img_dim,
                          num_seq=args.num_seq,
                          seq_len=args.seq_len,
                          network=args.net,
                          pred_step=args.pred_step,
                          mem_size=args.mem_size)
    else:
        raise NotImplementedError('wrong model!')

    model.to(device)
    model = nn.DataParallel(model)
    # Checkpoints are saved and loaded through the unwrapped module.
    model_without_dp = model.module

    ### optimizer ###
    params = model.parameters()
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    criterion = nn.CrossEntropyLoss()

    ### data ###
    transform = transforms.Compose([
        A.RandomSizedCrop(size=224, consistent=True,
                          p=1.0),  # crop from 256 to 224
        A.Scale(size=(args.img_dim, args.img_dim)),
        A.RandomHorizontalFlip(consistent=True),
        A.RandomGray(consistent=False, p=0.25),
        A.ColorJitter(0.5, 0.5, 0.5, 0.25, consistent=False, p=1.0),
        A.ToTensor(),
        A.Normalize()
    ])

    # NOTE(review): get_data is called with the transform as its first
    # argument -- confirm this matches the get_data that is in scope here.
    train_loader = get_data(transform, 'train')
    val_loader = get_data(transform, 'val')

    if 'ucf' in args.dataset:
        lr_milestones_eps = [300, 400]
    elif 'k400' in args.dataset:
        lr_milestones_eps = [120, 160]
    else:
        lr_milestones_eps = [1000]  # NEVER
    # Milestones are given in epochs; convert them to iteration counts.
    lr_milestones = [len(train_loader) * m for m in lr_milestones_eps]
    print('=> Use lr_scheduler: %s eps == %s iters' %
          (str(lr_milestones_eps), str(lr_milestones)))

    def lr_lambda(ep):
        return MultiStepLR_Restart_Multiplier(ep,
                                              gamma=0.1,
                                              step=lr_milestones,
                                              repeat=1)

    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    best_acc = 0
    args.iteration = 1

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            args.iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            model_without_dp.load_state_dict(checkpoint['state_dict'])
            # Restoring optimizer state is best-effort, but a bare ``except``
            # would also swallow KeyboardInterrupt / SystemExit.
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except Exception as err:
                print('[WARNING] Not loading optimizer states (%s)' % err)
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("[Warning] no checkpoint found at '{}'".format(args.resume))
            sys.exit(0)

    # logging tools
    args.img_path, args.model_path = set_path(args)
    args.logger = Logger(path=args.img_path)
    args.logger.log('args=\n\t\t' + '\n\t\t'.join(
        ['%s:%s' % (str(k), str(v)) for k, v in vars(args).items()]))

    args.writer_val = SummaryWriter(logdir=os.path.join(args.img_path, 'val'))
    args.writer_train = SummaryWriter(
        logdir=os.path.join(args.img_path, 'train'))

    torch.backends.cudnn.benchmark = True

    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        # Re-seed per epoch with the epoch number.
        np.random.seed(epoch)
        random.seed(epoch)

        train_loss, train_acc = train_one_epoch(train_loader, model,
                                                criterion, optimizer,
                                                lr_scheduler, device, epoch,
                                                args)
        val_loss, val_acc = validate(val_loader, model, criterion, device,
                                     epoch, args)

        # save check_point
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_dict = {
            'epoch': epoch,
            'state_dict': model_without_dp.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
            'iteration': args.iteration
        }
        save_checkpoint(save_dict,
                        is_best,
                        filename=os.path.join(args.model_path,
                                              'epoch%s.pth.tar' % str(epoch)),
                        keep_all=False)

    print('Training from ep %d to ep %d finished' %
          (args.start_epoch, args.epochs))
    sys.exit(0)
def get_data(args, mode='train', return_label=False, hierarchical_label=False, action_level_gt=False,
             num_workers=0, path_dataset=''):
    """Build the ``DataLoader`` for one dataset split.

    Args:
        args: Namespace; reads ``dataset``, ``img_dim``, ``seq_len``,
            ``num_seq``, ``batch_size``, ``viz`` and (depending on the
            dataset) ``ds``.
        mode: ``'train'`` selects the training augmentations; any other value
            selects the center-crop evaluation transform. The value is
            forwarded unchanged to the dataset constructor.
        return_label: Forwarded to the dataset constructor.
        hierarchical_label: Forwarded to the dataset; only accepted for
            ``'finegym'`` and ``'hollywood2'``.
        action_level_gt: Request action-level rather than subaction-level
            ground truth.
        num_workers: Number of loader worker processes.
        path_dataset: Dataset root, forwarded to the datasets that accept it.

    Returns:
        A ``DataLoader`` using random sampling when training or when
        ``args.viz`` is set, sequential sampling otherwise. Incomplete last
        batches are dropped unless ``mode == 'test'``.

    Raises:
        ValueError: If the label options are incompatible with
            ``args.dataset`` or the dataset is not supported.
        AssertionError: For the dataset-specific option checks below.
    """
    # ValueError is still caught by callers that catch the generic Exception
    # these checks raised before.
    if hierarchical_label and args.dataset not in ['finegym', 'hollywood2']:
        raise ValueError('Hierarchical information is only implemented in finegym and hollywood2 datasets')
    if return_label and not action_level_gt and args.dataset != 'finegym':
        raise ValueError('subaction labels are only available in finegym dataset')

    if mode == 'train':
        if args.dataset == 'ucf101':
            # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128
            transform = transforms.Compose([
                augmentation.RandomHorizontalFlip(consistent=True),
                augmentation.RandomCrop(size=224, consistent=True),
                augmentation.Scale(size=(args.img_dim, args.img_dim)),
                augmentation.RandomGray(consistent=False, p=0.5),
                augmentation.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0),
                augmentation.ToTensor(),
                augmentation.Normalize()
            ])
        else:
            # designed for kinetics400, short size=150, rand crop to 128x128
            transform = transforms.Compose([
                augmentation.RandomSizedCrop(size=args.img_dim, consistent=True, p=1.0),
                augmentation.RandomHorizontalFlip(consistent=True),
                augmentation.RandomGray(consistent=False, p=0.5),
                augmentation.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0),
                augmentation.ToTensor(),
                augmentation.Normalize()
            ])
    else:
        transform = transforms.Compose([
            augmentation.CenterCrop(size=args.img_dim, consistent=True),
            augmentation.ToTensor(),
            augmentation.Normalize()
        ])

    if args.dataset == 'k600':
        dataset = Kinetics600_full_3d(mode=mode,
                                      transform=transform,
                                      seq_len=args.seq_len,
                                      num_seq=args.num_seq,
                                      downsample=5,
                                      return_label=return_label,
                                      return_idx=args.viz,
                                      path_dataset=path_dataset)
    elif args.dataset == 'ucf101':
        dataset = UCF101_3d(mode=mode,
                            transform=transform,
                            seq_len=args.seq_len,
                            num_seq=args.num_seq,
                            downsample=args.ds,
                            return_label=return_label)
    elif args.dataset == 'hollywood2':
        if return_label:
            assert action_level_gt, 'hollywood2 does not have subaction labels'
        dataset = Hollywood2(mode=mode,
                             transform=transform,
                             seq_len=args.seq_len,
                             num_seq=args.num_seq,
                             downsample=args.ds,
                             return_label=return_label,
                             hierarchical_label=hierarchical_label)
    elif args.dataset == 'finegym':
        if hierarchical_label:
            assert not action_level_gt, 'finegym does not have hierarchical information at the action level'
        dataset = FineGym(mode=mode,
                          transform=transform,
                          seq_len=args.seq_len,
                          num_seq=args.num_seq,
                          fps=int(25/args.ds),  # approx
                          return_label=return_label,
                          hierarchical_label=hierarchical_label,
                          action_level_gt=action_level_gt,
                          path_dataset=path_dataset,
                          return_idx=args.viz)
    elif args.dataset == 'movienet':
        assert not return_label, 'Not yet implemented (actions not available online)'
        assert args.seq_len == 3, 'We only have 3 frames per subclip/scene, but always 3'
        dataset = MovieNet(mode=mode,
                           transform=transform,
                           num_seq=args.num_seq,
                           path_dataset=path_dataset)
    else:
        raise ValueError('dataset not supported')

    if mode == 'train' or args.viz:
        sampler = data.RandomSampler(dataset)
    else:
        sampler = data.SequentialSampler(dataset)

    data_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  sampler=sampler,
                                  shuffle=False,
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  drop_last=(mode != 'test')  # test always same examples independently of batch size
                                  )
    return data_loader
def main(args):
    """Train, fine-tune or test the ``LC`` classifier on UCF101 / HMDB51.

    Builds the model, the optimizer (with per-parameter settings that depend
    on ``args.train_what``) and the LR schedule. If ``args.test`` is set, the
    checkpoint is loaded, ``test`` is run and the process exits. Otherwise
    the model is optionally resumed or initialised from a pretrained
    checkpoint and trained for epochs ``[args.start_epoch, args.epochs)``,
    saving a checkpoint after each epoch. The process ends with
    ``sys.exit(0)``.

    Args:
        args: Parsed command-line namespace. It is also mutated here:
            ``batch_size`` is multiplied by the GPU count, and ``num_class``,
            ``old_lr``, ``iteration``, ``img_path``, ``model_path``,
            ``writer_train`` and ``writer_val`` are set (``logger`` in test
            mode, ``start_epoch`` when resuming).

    Raises:
        ValueError: If ``args.dataset`` or ``args.model`` is not supported.
    """
    # Fail fast: any other dataset previously left args.num_class, step and
    # lr_lambda undefined and crashed later with AttributeError / NameError.
    num_class_by_dataset = {'ucf101': 101, 'hmdb51': 51}
    default_schedule_by_dataset = {'ucf101': [300, 400], 'hmdb51': [150, 250]}
    if args.dataset not in num_class_by_dataset:
        raise ValueError(
            "dataset %r not supported; expected 'ucf101' or 'hmdb51'" %
            (args.dataset,))

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = torch.device('cuda')
    # args.batch_size is per GPU; scale it to the DataParallel batch size.
    num_gpu = len(str(args.gpu).split(','))
    args.batch_size = num_gpu * args.batch_size

    args.num_class = num_class_by_dataset[args.dataset]

    ### classifier model ###
    if args.model == 'lc':
        model = LC(sample_size=args.img_dim,
                   num_seq=args.num_seq,
                   seq_len=args.seq_len,
                   network=args.net,
                   num_class=args.num_class,
                   dropout=args.dropout,
                   train_what=args.train_what)
    else:
        raise ValueError('wrong model!')

    model.to(device)
    model = nn.DataParallel(model)
    # Checkpoints are saved and loaded through the unwrapped module.
    model_without_dp = model.module
    criterion = nn.CrossEntropyLoss()

    ### optimizer ###
    params = None
    if args.train_what == 'ft':
        print('=> finetune backbone with smaller lr')
        params = []
        for name, param in model.module.named_parameters():
            # Parameters matched by name get a 10x smaller learning rate.
            if ('resnet' in name) or ('rnn' in name):
                params.append({'params': param, 'lr': args.lr / 10})
            else:
                params.append({'params': param})
    elif args.train_what == 'last':
        print('=> train only last layer')
        params = []
        for name, param in model.named_parameters():
            # Freeze every parameter matched by name; optimise the rest.
            if ('bone' in name) or ('agg' in name) or ('mb' in name) or (
                    'network_pred' in name):
                param.requires_grad = False
            else:
                params.append({'params': param})
    else:
        pass  # train all layers

    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    print('=================================\n')

    if params is None:
        params = model.parameters()
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)

    ### scheduler ###
    # An empty args.schedule selects the per-dataset default milestones.
    step = args.schedule if args.schedule else default_schedule_by_dataset[
        args.dataset]

    def lr_lambda(ep):
        return MultiStepLR_Restart_Multiplier(ep,
                                              gamma=0.1,
                                              step=step,
                                              repeat=1)

    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    print('=> Using scheduler at {} epochs'.format(step))

    args.old_lr = None
    best_acc = 0
    args.iteration = 1

    ### if in test mode ###
    if args.test:
        if os.path.isfile(args.test):
            print("=> loading test checkpoint '{}'".format(args.test))
            checkpoint = torch.load(args.test,
                                    map_location=torch.device('cpu'))
            # Fall back to the lenient loader when the strict load fails;
            # ``except Exception`` no longer swallows KeyboardInterrupt.
            try:
                model_without_dp.load_state_dict(checkpoint['state_dict'])
            except Exception:
                print(
                    '=> [Warning]: weight structure is not equal to test model; Load anyway =='
                )
                model_without_dp = neq_load_customized(
                    model_without_dp, checkpoint['state_dict'])
            epoch = checkpoint['epoch']
            print("=> loaded testing checkpoint '{}' (epoch {})".format(
                args.test, checkpoint['epoch']))
        elif args.test == 'random':
            epoch = 0
            print("=> loaded random weights")
        else:
            print("=> no checkpoint found at '{}'".format(args.test))
            sys.exit(0)

        args.logger = Logger(path=os.path.dirname(args.test))
        _, test_dataset = get_data(None, 'test')
        test_loss, test_acc = test(test_dataset, model, criterion, device,
                                   epoch, args)
        sys.exit()

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            args.iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            model_without_dp.load_state_dict(checkpoint['state_dict'])
            # Restoring optimizer state is best-effort.
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except Exception as err:
                print('[WARNING] Not loading optimizer states (%s)' % err)
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            sys.exit(0)

    if (not args.resume) and args.pretrain:
        if args.pretrain == 'random':
            print('=> using random weights')
        elif os.path.isfile(args.pretrain):
            print("=> loading pretrained checkpoint '{}'".format(
                args.pretrain))
            checkpoint = torch.load(args.pretrain,
                                    map_location=torch.device('cpu'))
            model_without_dp = neq_load_customized(model_without_dp,
                                                   checkpoint['state_dict'])
            print("=> loaded pretrained checkpoint '{}' (epoch {})".format(
                args.pretrain, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrain))
            sys.exit(0)

    ### data ###
    transform = transforms.Compose([
        A.RandomSizedCrop(consistent=True, size=224, p=1.0),
        A.Scale(size=(args.img_dim, args.img_dim)),
        A.RandomHorizontalFlip(consistent=True),
        A.ColorJitter(brightness=0.5,
                      contrast=0.5,
                      saturation=0.5,
                      hue=0.25,
                      p=0.3,
                      consistent=True),
        A.ToTensor(),
        A.Normalize()
    ])
    val_transform = transforms.Compose([
        A.RandomSizedCrop(consistent=True, size=224, p=0.3),
        A.Scale(size=(args.img_dim, args.img_dim)),
        A.RandomHorizontalFlip(consistent=True),
        A.ColorJitter(brightness=0.2,
                      contrast=0.2,
                      saturation=0.2,
                      hue=0.1,
                      p=0.3,
                      consistent=True),
        A.ToTensor(),
        A.Normalize()
    ])

    # get_data is unpacked as a 2-tuple here; only the first element (the
    # loader) is used for training and validation.
    train_loader, _ = get_data(transform, 'train')
    val_loader, _ = get_data(val_transform, 'val')

    # setup tools
    args.img_path, args.model_path = set_path(args)
    args.writer_val = SummaryWriter(logdir=os.path.join(args.img_path, 'val'))
    args.writer_train = SummaryWriter(
        logdir=os.path.join(args.img_path, 'train'))
    torch.backends.cudnn.benchmark = True

    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_acc = train_one_epoch(train_loader, model,
                                                criterion, optimizer, device,
                                                epoch, args)
        val_loss, val_acc = validate(val_loader, model, criterion, device,
                                     epoch, args)
        lr_scheduler.step(epoch)

        # save check_point
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_dict = {
            'epoch': epoch,
            'backbone': args.net,
            'state_dict': model_without_dp.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
            'iteration': args.iteration
        }
        save_checkpoint(save_dict,
                        is_best,
                        filename=os.path.join(args.model_path,
                                              'epoch%s.pth.tar' % str(epoch)),
                        keep_all=False)

    print('Training from ep %d to ep %d finished' %
          (args.start_epoch, args.epochs))
    sys.exit(0)
def test(dataset, model, criterion, device, epoch, args):
    """Evaluate the classifier with center / five / ten-crop test augmentation.

    For every (flip, crop) combination the dataset is re-read with the
    matching transform; the softmax over dim 2 of the model output is
    averaged over dims 1 and 0 and appended to a per-video list in
    ``prob_dict``. Results are summarised with ``summarize_probability`` and
    logged; in ten-crop mode the intermediate center-crop and five-crop
    results are logged as well. The process then exits via ``sys.exit(0)``.

    Args:
        dataset: Dataset whose ``transform``, ``return_path`` and
            ``return_label`` attributes are overwritten here.
        model: Network returning ``(output, _)``; put into eval mode.
        criterion: Unused; kept for interface compatibility.
        device: Device that inputs are moved to.
        epoch: Epoch number, used only in log messages.
        args: Namespace; reads ``center_crop``, ``five_crop``, ``ten_crop``,
            ``img_dim`` and ``logger``.

    Raises:
        ValueError: If none of the crop flags is set.
    """
    # 10-crop then average the probability
    prob_dict = {}
    model.eval()

    # aug_list: 1,2,3,4,5 = top-left, top-right, bottom-left, bottom-right, center
    # flip_list: 0,1 = original, horizontal-flip
    aug_list = flip_list = title = None
    # The flags are checked in sequence, so a later flag overrides an earlier
    # one when several are set.
    if args.center_crop:
        print('Test using center crop')
        args.logger.log('Test using center_crop\n')
        aug_list = [5]
        flip_list = [0]
        title = 'center'
    if args.five_crop:
        print('Test using 5 crop')
        args.logger.log('Test using 5_crop\n')
        aug_list = [5, 1, 2, 3, 4]
        flip_list = [0]
        title = 'five'
    if args.ten_crop:
        print('Test using 10 crop')
        args.logger.log('Test using 10_crop\n')
        aug_list = [5, 1, 2, 3, 4]
        flip_list = [0, 1]
        title = 'ten'
    if title is None:
        # Previously this surfaced later as an UnboundLocalError.
        raise ValueError(
            'one of args.center_crop, args.five_crop or args.ten_crop '
            'must be set')

    with torch.no_grad():
        for flip_idx in flip_list:
            for aug_idx in aug_list:
                print('Aug type: %d; flip: %d' % (aug_idx, flip_idx))
                # The two flip settings differ only in the flip command.
                flip_command = 'left' if flip_idx == 0 else 'right'
                transform = transforms.Compose([
                    A.RandomHorizontalFlip(command=flip_command),
                    A.FiveCrop(size=(224, 224), where=aug_idx),
                    A.Scale(size=(args.img_dim, args.img_dim)),
                    A.ColorJitter(brightness=0.2,
                                  contrast=0.2,
                                  saturation=0.2,
                                  hue=0.1,
                                  p=0.3,
                                  consistent=True),
                    A.ToTensor(),
                ])

                dataset.transform = transform
                dataset.return_path = True
                dataset.return_label = True
                # Sample order does not affect the result: probabilities are
                # accumulated per video name.
                data_sampler = data.RandomSampler(dataset)
                data_loader = data.DataLoader(dataset,
                                              batch_size=1,
                                              sampler=data_sampler,
                                              shuffle=False,
                                              num_workers=16,
                                              pin_memory=True)

                for _, (input_seq, target) in tqdm(enumerate(data_loader),
                                                   total=len(data_loader)):
                    input_seq = input_seq.to(device)
                    # The label is not needed here; only the video name is.
                    _, vname = target
                    input_seq = input_seq.squeeze(
                        0)  # squeeze the '1' batch dim
                    output, _ = model(input_seq)

                    prob_mean = nn.functional.softmax(output, 2).mean(1).mean(
                        0, keepdim=True)

                    vname = vname[0]
                    if vname not in prob_dict:
                        prob_dict[vname] = []
                    prob_dict[vname].append(prob_mean)

                # show intermediate result
                if (title == 'ten') and (flip_idx == 0) and (aug_idx == 5):
                    print('center-crop result:')
                    acc_1 = summarize_probability(
                        prob_dict, data_loader.dataset.encode_action,
                        'center')
                    args.logger.log('center-crop:')
                    args.logger.log(
                        'test Epoch: [{0}]\t'
                        'Mean: Acc@1: {acc[0].avg:.4f} Acc@5: {acc[1].avg:.4f}'
                        .format(epoch, acc=acc_1))

            # show intermediate result
            if (title == 'ten') and (flip_idx == 0):
                print('five-crop result:')
                acc_5 = summarize_probability(
                    prob_dict, data_loader.dataset.encode_action, 'five')
                args.logger.log('five-crop:')
                args.logger.log(
                    'test Epoch: [{0}]\t'
                    'Mean: Acc@1: {acc[0].avg:.4f} Acc@5: {acc[1].avg:.4f}'.
                    format(epoch, acc=acc_5))

        # show final result
        print('%s-crop result:' % title)
        acc_final = summarize_probability(prob_dict,
                                          data_loader.dataset.encode_action,
                                          'ten')
        args.logger.log('%s-crop:' % title)
        args.logger.log(
            'test Epoch: [{0}]\t'
            'Mean: Acc@1: {acc[0].avg:.4f} Acc@5: {acc[1].avg:.4f}'.format(
                epoch, acc=acc_final))
    sys.exit(0)