def evaluate_model(num_class):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = TBN(num_class,
              1,
              args.modality,
              base_model=args.arch,
              consensus_type=args.crop_fusion_type,
              dropout=args.dropout,
              midfusion=args.midfusion)

    weights = '{weights_dir}/model_best.pth.tar'.format(
        weights_dir=args.weights_dir)
    checkpoint = torch.load(weights)
    print("model epoch {} best prec@1: {}".format(checkpoint['epoch'],
                                                  checkpoint['best_prec1']))

    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in list(checkpoint['state_dict'].items())
    }
    net.load_state_dict(base_dict)

    test_transform = {}
    image_tmpl = {}
    for m in args.modality:
        if m != 'Spec':
            if args.test_crops == 1:
                cropping = torchvision.transforms.Compose([
                    GroupScale(net.scale_size[m]),
                    GroupCenterCrop(net.input_size[m]),
                ])
            elif args.test_crops == 10:
                cropping = torchvision.transforms.Compose(
                    [GroupOverSample(net.input_size[m], net.scale_size[m])])
            else:
                raise ValueError(
                    "Only 1 and 10 crops are supported, "
                    "while we got {}".format(args.test_crops))

            test_transform[m] = torchvision.transforms.Compose([
                cropping,
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                GroupNormalize(net.input_mean[m], net.input_std[m]),
            ])

            # Prepare dictionaries containing image name templates
            # for each modality
            if m in ['RGB', 'RGBDiff']:
                image_tmpl[m] = "img_{:010d}.jpg"
            elif m == 'Flow':
                image_tmpl[m] = args.flow_prefix + "{}_{:010d}.jpg"
        else:
            test_transform[m] = torchvision.transforms.Compose([
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=False),
            ])

    data_length = net.new_length
    test_loader = torch.utils.data.DataLoader(
        TBNDataSet(args.dataset,
                   pd.read_pickle(args.test_list),
                   data_length,
                   args.modality,
                   image_tmpl,
                   visual_path=args.visual_path,
                   audio_path=args.audio_path,
                   num_segments=args.test_segments,
                   mode='test',
                   transform=test_transform,
                   resampling_rate=args.resampling_rate),
        batch_size=1,
        shuffle=False,
        num_workers=args.workers * 2)

    net = torch.nn.DataParallel(net, device_ids=args.gpus).to(device)
    with torch.no_grad():
        net.eval()

        results = []
        total_num = len(test_loader.dataset)
        proc_start_time = time.time()
        max_num = args.max_num if args.max_num > 0 else total_num

        for i, (data, label) in enumerate(test_loader):
            if i >= max_num:
                break
            rst = eval_video(data, net, num_class, device)
            if label != -10000:  # label exists
                if 'epic' not in args.dataset:
                    label_ = label.item()
                else:
                    label_ = {k: v.item() for k, v in label.items()}
                results.append((rst, label_))
            else:  # Test set (S1/S2)
                results.append((rst, ))
            cnt_time = time.time() - proc_start_time
            print('video {} done, total {}/{}, average {} sec/video'.format(
                i, i + 1, total_num, float(cnt_time) / (i + 1)))

    return results
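
# --- Illustrative sketch (added; not part of the original script) -----------
# The dict comprehension in evaluate_model drops the first dotted component of
# every checkpoint key because weights saved from a torch.nn.DataParallel-
# wrapped model carry a leading 'module.' prefix, e.g.
# 'module.rgb.conv1.weight' -> 'rgb.conv1.weight'. A reusable variant that
# only strips the prefix when it is actually present could look like this
# (hypothetical helper, not used anywhere in this script):


def strip_dataparallel_prefix(state_dict):
    """Return a copy of state_dict with any leading 'module.' removed."""
    return {(k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in state_dict.items()}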
def main():
    global args, best_prec1, train_list, experiment_dir, best_loss
    args = parser.parse_args()

    if args.dataset == 'ucf101':
        num_class = 101
    elif args.dataset == 'hmdb51':
        num_class = 51
    elif args.dataset == 'kinetics':
        num_class = 400
    elif args.dataset == 'epic':
        num_class = (125, 352)  # (verb classes, noun classes)
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = TBN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                midfusion=args.midfusion)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    data_length = model.new_length
    # policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    # Resume training from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # Strip the 'module.' prefix that DataParallel adds to saved keys
            state_dict_new = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                state_dict_new[k.split('.', 1)[1]] = v
            model.load_state_dict(state_dict_new)
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # Load pretrained weights for the Flow stream
    if args.pretrained_flow_weights:
        print('Initialize Flow stream from Kinetics')
        pretrained = os.path.join('pretrained/kinetics_tsn_flow.pth.tar')
        state_dict = torch.load(pretrained)
        for k, v in state_dict.items():
            state_dict[k] = torch.squeeze(v, dim=0)
        base_model = getattr(model, 'flow')
        base_model.load_state_dict(state_dict, strict=False)

    # Freeze stream weights (leaves only fusion and classification trainable)
    if args.freeze:
        model.freeze_fn('modalities')

    # Freeze batch normalisation layers except the first
    if args.partialbn:
        model.freeze_fn('partialbn_parameters')

    model = torch.nn.DataParallel(model, device_ids=args.gpus).to(device)
    cudnn.benchmark = True

    # Data loading code
    normalize = {}
    for m in args.modality:
        if m != 'Spec':
            if m != 'RGBDiff':
                normalize[m] = GroupNormalize(input_mean[m], input_std[m])
            else:
                normalize[m] = IdentityTransform()

    image_tmpl = {}
    train_transform = {}
    val_transform = {}
    for m in args.modality:
        if m != 'Spec':
            # Prepare dictionaries containing image name templates
            # for each modality
            if m in ['RGB', 'RGBDiff']:
                image_tmpl[m] = "img_{:010d}.jpg"
            elif m == 'Flow':
                image_tmpl[m] = args.flow_prefix + "{}_{:010d}.jpg"

            # Prepare train/val dictionaries containing the transformations
            # (augmentation + normalization) for each modality
            train_transform[m] = torchvision.transforms.Compose([
                train_augmentation[m],
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                normalize[m],
            ])

            val_transform[m] = torchvision.transforms.Compose([
                GroupScale(int(scale_size[m])),
                GroupCenterCrop(crop_size[m]),
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                normalize[m],
            ])
        else:
            # Spectrogram input needs no cropping or normalization;
            # only stack and convert to a tensor
            train_transform[m] = torchvision.transforms.Compose([
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=False),
            ])

            val_transform[m] = torchvision.transforms.Compose([
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=False),
            ])
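
    # Background note (added; not in the original script): BNInception was
    # pretrained on BGR frames with pixel values in [0, 255], so
    # Stack(roll=True) reverses the channel order from RGB to BGR and
    # ToTorchFormatTensor(div=False) leaves the values unscaled; other
    # backbones expect RGB in [0, 1], hence roll=False and div=True.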
    # If train_list is not provided, train on the default dataset,
    # i.e. the whole training set
    if args.train_list is None:
        train_labels = training_labels()
    else:
        train_labels = args.train_list
    train_loader = torch.utils.data.DataLoader(
        TBNDataSet(args.dataset,
                   train_labels,
                   data_length,
                   args.modality,
                   image_tmpl,
                   visual_path=args.visual_path,
                   audio_path=args.audio_path,
                   num_segments=args.num_segments,
                   transform=train_transform,
                   fps=args.fps,
                   resampling_rate=args.resampling_rate),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    if args.train_list is not None:
        # A validation loader only makes sense when training on a subset;
        # if the whole training set is used for training, nothing is held
        # out to validate on
        val_loader = torch.utils.data.DataLoader(
            TBNDataSet(args.dataset,
                       args.val_list,
                       data_length,
                       args.modality,
                       image_tmpl,
                       visual_path=args.visual_path,
                       audio_path=args.audio_path,
                       num_segments=args.num_segments,
                       mode='val',
                       transform=val_transform,
                       fps=args.fps,
                       resampling_rate=args.resampling_rate),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    # Define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    if len(args.modality) > 1:
        # One parameter group per stream, so the Flow stream can use a
        # lower learning rate than the rest of the network
        param_groups = [
            {'params': filter(lambda p: p.requires_grad,
                              model.module.rgb.parameters())},
            {'params': filter(lambda p: p.requires_grad,
                              model.module.flow.parameters()),
             'lr': 0.001},
            {'params': filter(lambda p: p.requires_grad,
                              model.module.spec.parameters())},
            {'params': filter(
                lambda p: p.requires_grad,
                model.module.fusion_classification_net.parameters())},
        ]
    else:
        param_groups = filter(lambda p: p.requires_grad, model.parameters())

    optimizer = torch.optim.SGD(param_groups,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer, args.lr_steps, gamma=0.1)
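
    # Background note (added; not in the original script): MultiStepLR
    # multiplies the learning rate of every param group by gamma (0.1 here)
    # whenever the epoch counter crosses one of args.lr_steps, so the Flow
    # stream's lower initial lr of 0.001 decays on the same schedule as the
    # other streams.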
    if args.evaluate:
        validate(val_loader, model, criterion, device)
        return

    if args.save_stats:
        if args.dataset != 'epic':
            stats_dict = {
                'train_loss': np.zeros((args.epochs, )),
                'val_loss': np.zeros((args.epochs, )),
                'train_acc': np.zeros((args.epochs, )),
                'val_acc': np.zeros((args.epochs, ))
            }
        else:
            if args.train_list is not None:
                stats_dict = {
                    'train_loss': np.zeros((args.epochs, )),
                    'train_verb_loss': np.zeros((args.epochs, )),
                    'train_noun_loss': np.zeros((args.epochs, )),
                    'train_acc': np.zeros((args.epochs, )),
                    'train_verb_acc': np.zeros((args.epochs, )),
                    'train_noun_acc': np.zeros((args.epochs, )),
                    'val_loss': np.zeros((args.epochs, )),
                    'val_verb_loss': np.zeros((args.epochs, )),
                    'val_noun_loss': np.zeros((args.epochs, )),
                    'val_acc': np.zeros((args.epochs, )),
                    'val_verb_acc': np.zeros((args.epochs, )),
                    'val_noun_acc': np.zeros((args.epochs, ))
                }
            else:
                # No validation set, so only track training statistics
                stats_dict = {
                    'train_loss': np.zeros((args.epochs, )),
                    'train_verb_loss': np.zeros((args.epochs, )),
                    'train_noun_loss': np.zeros((args.epochs, )),
                    'train_acc': np.zeros((args.epochs, )),
                    'train_verb_acc': np.zeros((args.epochs, )),
                    'train_noun_acc': np.zeros((args.epochs, ))
                }

    for epoch in range(args.start_epoch, args.epochs):
        # NOTE: since PyTorch 1.1 the recommended order is to call
        # scheduler.step() *after* the epoch's optimizer updates; calling it
        # first, as here, follows the pre-1.1 convention
        scheduler.step()

        # train for one epoch
        training_metrics = train(train_loader, model, criterion, optimizer,
                                 epoch, device)
        if args.save_stats:
            for k, v in training_metrics.items():
                stats_dict[k][epoch] = v

        # evaluate on validation set
        if args.train_list is not None:
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                test_metrics = validate(val_loader, model, criterion, device)
                if args.save_stats:
                    for k, v in test_metrics.items():
                        stats_dict[k][epoch] = v
                prec1 = test_metrics['val_acc']

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best)
        else:
            # No validation set
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': training_metrics['train_acc'],
                }, False)

    summaryWriter.close()

    if args.save_stats:
        save_stats_dir = os.path.join('stats', experiment_dir)
        if not os.path.exists(save_stats_dir):
            os.makedirs(save_stats_dir)
        with open(os.path.join(save_stats_dir, 'training_stats.npz'),
                  'wb') as f:
            np.savez(f, **stats_dict)
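
# --- Usage sketch (added; not part of the original script) ------------------
# The .npz file written above can be loaded back for inspection or plotting,
# e.g. (the experiment directory shown is hypothetical):
#
#   import numpy as np
#   stats = np.load('stats/my_experiment/training_stats.npz')
#   print(stats['train_loss'])  # one value per epoch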