num_class, this_test_segments if is_shift else 1, modality, base_model=this_arch, consensus_type=args.crop_fusion_type, img_feature_dim=args.img_feature_dim, pretrain=args.pretrain, is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, non_local='_nl' in this_weights, ) if 'tpool' in this_weights: from ops.temporal_shift import make_temporal_pool make_temporal_pool(net.base_model, this_test_segments) # since DataParallel checkpoint = torch.load(this_weights) checkpoint = checkpoint['state_dict'] # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())} base_dict = { '.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items()) } replace_dict = { 'base_model.classifier.weight': 'new_fc.weight', 'base_model.classifier.bias': 'new_fc.bias', } for k, v in replace_dict.items(): if k in base_dict:
def main():
    """TSM training entry point.

    Parses CLI args, builds the TSN model, optionally resumes/fine-tunes from a
    checkpoint, constructs train/val data loaders, then runs the epoch loop
    (train -> periodic validate -> checkpoint best model).

    Relies on module-level names: parser, dataset_config, TSN, TSNDataSet,
    train, validate, save_checkpoint, adjust_learning_rate, check_rootfolders,
    make_temporal_pool, the Group*/Stack/ToTorchFormatTensor transforms, and
    AverageMeter-style helpers.
    """
    global args, best_prec1
    args = parser.parse_args()
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)

    # Compose a unique experiment name from the main hyper-parameters.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type, 'segment%d' % args.num_segments,
        'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                # use a 5x fc learning rate only when NOT fine-tuning on the
                # same dataset the checkpoint was trained on
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    # Capture preprocessing stats before the model is wrapped in DataParallel.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    # Something-Something / Jester labels are direction-sensitive: no flipping.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:
            # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # FIX: previously formatted args.evaluate here, which printed the
            # wrong value (the --evaluate flag) instead of the resume path.
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.tune_from:
        print("=> fine-tuning from '{}'".format(args.tune_from))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        # Reconcile '.net' naming differences between checkpoint and model
        # (the shift wrapper inserts a '.net' level in parameter names).
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(sd.keys())
        keys2 = set(model_dict.keys())
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            # RGB conv1 has 3 input channels; Flow needs a different conv1.
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # FIX: previously data_length stayed unbound for any other modality
        # and crashed later with a confusing NameError.
        raise ValueError('Unsupported modality: {}'.format(args.modality))

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list,
                   num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       # BNInception/InceptionV3 expect BGR, 0-255 inputs.
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(
                           div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list,
                   num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(
                           div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult']))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
def main():
    """TSM evaluation entry point: test one or several checkpoints on the
    validation (or test) list and ensemble their predictions.

    For each comma-separated weights file it rebuilds the matching TSN model
    (architecture / shift settings are parsed out of the checkpoint filename),
    runs inference clip-by-clip, averages crops/clips, optionally weights each
    model by --coeff, and reports Prec@1/Prec@5 plus per-class accuracy from a
    confusion matrix. Optionally writes a submission CSV.
    """
    # options
    parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
    parser.add_argument('dataset', type=str)
    # may contain splits
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--test_segments', type=str, default=25)
    parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
    parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
    parser.add_argument('--full_res', default=False, action="store_true",
                        help='use full resolution 256x256 for test as in Non-local I3D')
    parser.add_argument('--test_crops', type=int, default=1)
    parser.add_argument('--coeff', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')

    # for true test
    parser.add_argument('--test_list', type=str, default=None)
    parser.add_argument('--csv_file', type=str, default=None)

    parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')

    parser.add_argument('--max_num', type=int, default=-1)
    parser.add_argument('--input_size', type=int, default=224)
    parser.add_argument('--crop_fusion_type', type=str, default='avg')
    parser.add_argument('--gpus', nargs='+', type=int, default=None)
    parser.add_argument('--img_feature_dim', type=int, default=256)
    parser.add_argument('--num_set_segments', type=int, default=1,
                        help='TODO: select multiply set of n-frames from a video')
    parser.add_argument('--pretrain', type=str, default='imagenet')

    args = parser.parse_args()

    def accuracy(output, target, topk=(1,)):
        """Computes the precision@k for the specified values of k"""
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    def parse_shift_option_from_log_name(log_name):
        # Recover (is_shift, shift_div, shift_place) from a checkpoint path
        # such as ..._shift8_blockres_... ; returns (False, None, None) when
        # the filename contains no 'shift' token.
        if 'shift' in log_name:
            strings = log_name.split('_')
            for i, s in enumerate(strings):
                if 'shift' in s:
                    break
            return True, int(strings[i].replace('shift', '')), strings[i + 1]
        else:
            return False, None, None

    weights_list = args.weights.split(',')
    test_segments_list = [int(s) for s in args.test_segments.split(',')]
    assert len(weights_list) == len(test_segments_list)
    if args.coeff is None:
        coeff_list = [1] * len(weights_list)
    else:
        coeff_list = [float(c) for c in args.coeff.split(',')]

    if args.test_list is not None:
        test_file_list = args.test_list.split(',')
    else:
        test_file_list = [None] * len(weights_list)

    data_iter_list = []
    net_list = []
    modality_list = []

    total_num = None
    # Build one (model, data iterator) pair per checkpoint in the ensemble.
    for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
        is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
        # Modality is inferred from the checkpoint filename.
        if 'RGB' in this_weights:
            modality = 'RGB'
        else:
            modality = 'Flow'
        # Architecture name is the third '_'-token after 'TSM_' in the path.
        this_arch = this_weights.split('TSM_')[1].split('_')[2]
        modality_list.append(modality)
        num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                modality)
        print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
        net = TSN(num_class, this_test_segments if is_shift else 1, modality,
                  base_model=this_arch,
                  consensus_type=args.crop_fusion_type,
                  img_feature_dim=args.img_feature_dim,
                  pretrain=args.pretrain,
                  is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
                  non_local='_nl' in this_weights,
                  )

        if 'tpool' in this_weights:
            from ops.temporal_shift import make_temporal_pool
            make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel

        checkpoint = torch.load(this_weights)
        checkpoint = checkpoint['state_dict']

        # Strip the leading 'module.' that DataParallel added when saving.
        # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias',
                        }
        for k, v in replace_dict.items():
            if k in base_dict:
                base_dict[v] = base_dict.pop(k)

        net.load_state_dict(base_dict)

        # --full_res evaluates at scale_size (e.g. 256) instead of crop size.
        input_size = net.scale_size if args.full_res else net.input_size
        if args.test_crops == 1:
            cropping = torchvision.transforms.Compose([
                GroupScale(net.scale_size),
                GroupCenterCrop(input_size),
            ])
        elif args.test_crops == 3:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupFullResSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 5:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 10:
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size)
            ])
        else:
            # NOTE(review): message omits the supported value 3 listed above.
            raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops))

        data_loader = torch.utils.data.DataLoader(
            TSNDataSet(root_path, test_file if test_file is not None else val_list,
                       num_segments=this_test_segments,
                       new_length=1 if modality == "RGB" else 5,
                       modality=modality,
                       image_tmpl=prefix,
                       test_mode=True,
                       remove_missing=len(weights_list) == 1,
                       transform=torchvision.transforms.Compose([
                           cropping,
                           Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                           GroupNormalize(net.input_mean, net.input_std),
                       ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True,
        )

        # NOTE(review): 'devices' is computed from args.workers (not the GPU
        # count) and then never used — DataParallel below gets no device_ids.
        if args.gpus is not None:
            devices = [args.gpus[i] for i in range(args.workers)]
        else:
            devices = list(range(args.workers))

        net = torch.nn.DataParallel(net.cuda())
        net.eval()

        data_gen = enumerate(data_loader)

        if total_num is None:
            total_num = len(data_loader.dataset)
        else:
            # all ensembled models must iterate datasets of equal length
            assert total_num == len(data_loader.dataset)

        data_iter_list.append(data_gen)
        net_list.append(net)

    output = []

    def eval_video(video_data, net, this_test_segments, modality):
        # Run one batch through `net` and return (index, scores, labels).
        # NOTE(review): reads `is_shift` and `num_class` from the enclosing
        # scope — these hold the values of the LAST checkpoint processed in
        # the loop above; correct only when all ensemble members agree.
        net.eval()
        with torch.no_grad():
            i, data, label = video_data
            batch_size = label.numel()
            num_crop = args.test_crops
            if args.dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample
            if args.twice_sample:
                num_crop *= 2
            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality " + modality)
            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                data_in = data_in.view(batch_size * num_crop, this_test_segments, length,
                                       data_in.size(2), data_in.size(3))
            rst = net(data_in)
            # average scores over crops/clips
            rst = rst.reshape(batch_size, num_crop, -1).mean(1)
            if args.softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)
            rst = rst.data.cpu().numpy().copy()
            if net.module.is_shift:
                rst = rst.reshape(batch_size, num_class)
            else:
                rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))
            return i, rst, label

    proc_start_time = time.time()
    max_num = args.max_num if args.max_num > 0 else total_num

    top1 = AverageMeter()
    top5 = AverageMeter()

    # Iterate all model data streams in lockstep and ensemble per batch.
    for i, data_label_pairs in enumerate(zip(*data_iter_list)):
        with torch.no_grad():
            if i >= max_num:
                break
            this_rst_list = []
            this_label = None
            for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs,
                                                                net_list, modality_list):
                rst = eval_video((i, data, label), net, n_seg, modality)
                this_rst_list.append(rst[1])
                this_label = label
            assert len(this_rst_list) == len(coeff_list)
            for i_coeff in range(len(this_rst_list)):
                this_rst_list[i_coeff] *= coeff_list[i_coeff]
            ensembled_predict = sum(this_rst_list) / len(this_rst_list)

            for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
                output.append([p[None, ...], g])
            cnt_time = time.time() - proc_start_time
            prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
            top1.update(prec1.item(), this_label.numel())
            top5.update(prec5.item(), this_label.numel())
            if i % 20 == 0:
                print('video {} done, total {}/{}, average {:.3f} sec/video, '
                      'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size,
                                                                  total_num,
                                                                  float(cnt_time) / (i + 1) / args.batch_size,
                                                                  top1.avg, top5.avg))

    video_pred = [np.argmax(x[0]) for x in output]
    video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]
    video_labels = [x[1] for x in output]

    if args.csv_file is not None:
        print('=> Writing result to csv file: {}'.format(args.csv_file))
        # The category file is assumed to sit next to the test list.
        with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f:
            categories = f.readlines()
        categories = [f.strip() for f in categories]
        with open(test_file_list[0]) as f:
            vid_names = f.readlines()
        vid_names = [n.split(' ')[0] for n in vid_names]
        print(vid_names)
        assert len(vid_names) == len(video_pred)
        if args.dataset != 'somethingv2':  # only output top1
            with open(args.csv_file, 'w') as f:
                for n, pred in zip(vid_names, video_pred):
                    f.write('{};{}\n'.format(n, categories[pred]))
        else:
            with open(args.csv_file, 'w') as f:
                for n, pred5 in zip(vid_names, video_pred_top5):
                    fill = [n]
                    for p in list(pred5):
                        fill.append(p)
                    f.write('{};{};{};{};{};{}\n'.format(*fill))

    cf = confusion_matrix(video_labels, video_pred).astype(float)

    np.save('cm.npy', cf)
    cls_cnt = cf.sum(axis=1)
    cls_hit = np.diag(cf)

    cls_acc = cls_hit / cls_cnt
    print(cls_acc)
    upper = np.mean(np.max(cf, axis=1) / cls_cnt)
    print('upper bound: {}'.format(upper))

    print('-----Evaluation is finished------')
    print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
    print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))
def main():
    """Training entry point for the modified TSM pipeline with a Transformer
    decoder head (YouCook recipe-recognition variant).

    Builds a (possibly pruned / concat-extended) TSN backbone plus a
    TransformerModel decoder, optionally resumes or fine-tunes, then trains
    with YouCookDataSetRcg data and checkpoints both networks every epoch.

    Relies on module-level names: parser, dataset_config, TSN,
    TransformerModel, YouCookDataSetRcg, train, validate, save_checkpoint,
    prune_conv1in_list / prune_conv1out_list, the adjust_para_shape_* helpers,
    and input_dim_L2distance.
    """
    global args, best_prec1
    global crop_size
    args = parser.parse_args()

    num_class, train_list, val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    # NOTE(review): num_class from the dataset is overridden to 1 here —
    # presumably the decoder produces the class output; confirm intent.
    num_class = 1
    if args.train_list == "":
        args.train_list = train_list
    if args.val_list == "":
        args.val_list = val_list

    # Compose the experiment/store name from the main hyper-parameters.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.concat != "":
        full_arch_name += '_concat{}'.format(args.concat)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    # NOTE(review): '+=' appends to whatever args.store_name already holds
    # (other mains use '='); verify the default is an empty string.
    args.store_name += '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type, 'lr%.5f' % args.lr,
        'dropout%.2f' % args.dropout, 'wd%.5f' % args.weight_decay,
        'batch%d' % args.batch_size, 'segment%d' % args.num_segments,
        'e{}'.format(args.epochs)
    ])
    if args.data_fuse:
        args.store_name += '_fuse'
    if args.extra_temporal_modeling:
        args.store_name += '_extra'
    if args.tune_from is not None:
        args.store_name += '_finetune'
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.clipnums:
        args.store_name += "_clip{}".format(args.clipnums)
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    # Pre-compute input-channel pruning ranking from the checkpoint weights.
    if args.prune in ['input', 'inout'] and args.tune_from:
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        sd = input_dim_L2distance(sd, args.shift_div)

    model = TSN(
        num_class,
        args.num_segments,
        args.modality,
        base_model=args.arch,
        new_length=2 if args.data_fuse else None,
        consensus_type=args.consensus_type,
        dropout=args.dropout,
        img_feature_dim=args.img_feature_dim,
        partial_bn=not args.no_partialbn,
        pretrain=args.pretrain,
        is_shift=args.shift,
        shift_div=args.shift_div,
        shift_place=args.shift_place,
        fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
        temporal_pool=args.temporal_pool,
        non_local=args.non_local,
        concat=args.concat,
        extra_temporal_modeling=args.extra_temporal_modeling,
        prune_list=[prune_conv1in_list, prune_conv1out_list],
        is_prune=args.prune,
    )

    print(model)
    #summary(model, torch.zeros((16, 24, 224, 224)))
    #exit(1)

    if args.dataset == 'ucf101':  # twice sample & full resolution
        twice_sample = True
        crop_size = model.scale_size  # 256 x 256
    else:
        twice_sample = False
        crop_size = model.crop_size  # 224 x 224
    # NOTE(review): this unconditionally overrides the branch above — the
    # effective crop size is always 256; confirm whether that is intended.
    crop_size = 256
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies(args.concat)
    #print(type(policies))
    #print(policies)
    #exit()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        or 'nvgesture' in args.dataset else True)

    model = torch.nn.DataParallel(model).cuda()

    if args.resume:
        if args.temporal_pool:
            # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # NOTE(review): BUG — `optimizer` is not defined until after the
            # fine-tune section below; resuming here raises NameError. The
            # optimizer construction likely needs to happen before this block.
            optimizer.load_state_dict(checkpoint['optimizer'])
            # NOTE(review): formats args.evaluate, not args.resume — prints
            # the wrong value in the resume message.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        tune_from_list = args.tune_from.split(',')
        sd = torch.load(tune_from_list[0])
        sd = sd['state_dict']
        model_dict = model.state_dict()
        # Reconcile '.net' / '.prune' naming differences between the
        # checkpoint and the current model before loading.
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.prune', '') in sd:
                print('=> Load after adding .prune: ', k)
                replace_dict.append((k.replace('.prune', ''), k))
        # Reshape checkpoint tensors to match pruned / concat-extended layers.
        if args.prune in ['input', 'inout']:
            sd = adjust_para_shape_prunein(sd, model_dict)
        if args.prune in ['output', 'inout']:
            sd = adjust_para_shape_pruneout(sd, model_dict)
        if args.concat != "" and "concat" not in tune_from_list[0]:
            sd = adjust_para_shape_concat(sd, model_dict)
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in tune_from_list[0]:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality != 'Flow' and 'Flow' in tune_from_list[0]:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        #print(sd.keys())
        #print("*"*50)
        #print(model_dict.keys())
        # Drop checkpoint entries the current model has no slot for.
        for k, v in list(sd.items()):
            if k not in model_dict:
                sd.pop(k)
        # NOTE(review): raises KeyError if the checkpoint lacks this key —
        # consider sd.pop(..., None) if the embedding is optional.
        sd.pop("module.base_model.embedding.weight")
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    # Transformer decoder head trained jointly with the backbone.
    decoder = TransformerModel().cuda()
    if args.decoder_resume:
        decoder_chkpoint = torch.load(args.decoder_resume)
        decoder.load_state_dict(decoder_chkpoint["state_dict"])
    print("decoder parameters = ", decoder.parameters())
    # Decoder params get a 5x learning-rate multiplier.
    policies.append({
        "params": decoder.parameters(),
        "lr_mult": 5,
        "decay_mult": 1,
        "name": "Attndecoder_weight"
    })

    cudnn.benchmark = True

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality in ['RGB']:
        data_length = 1
    elif args.modality in ['Depth']:
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    '''
    dataRoot = r"/home/share/YouCook/downloadVideo"
    for dirPath, dirnames, filenames in os.walk(dataRoot):
        for filename in filenames:
            print(os.path.join(dirPath,filename) +"is {}".format("exist" if os.path.isfile(os.path.join(dirPath,filename))else "NON"))
            train_data = torchvision.io.read_video(os.path.join(dirPath,filename),start_pts=0,end_pts=1001, )
            tmp = torchvision.io.read_video_timestamps(os.path.join(dirPath,filename),)
            print(tmp)
            print(len(tmp[0]))
            print(train_data[0].size())
            exit()
    exit()
    '''
    '''
    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample, data_fuse = args.data_fuse),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample, twice_sample=twice_sample, data_fuse = args.data_fuse),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    '''
    #global trainDataloader, valDataloader, train_loader, val_loader
    trainDataloader = YouCookDataSetRcg(args.root_path, args.train_list,train=True,inputsize=crop_size,hasPreprocess = False,\
                                        clipnums=args.clipnums, hasWordIndex = True,)
    valDataloader = YouCookDataSetRcg(args.root_path, args.val_list,val=True,inputsize=crop_size,hasPreprocess = False,\
                                      #clipnums=args.clipnums,
                                      hasWordIndex = True,)
    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #exit()
    train_loader = torch.utils.data.DataLoader(trainDataloader,
                                               #shuffle=True,
                                               )
    val_loader = torch.utils.data.DataLoader(valDataloader)
    # Mapping from vocabulary indices back to words for decoding output.
    index2wordDict = trainDataloader.getIndex2wordDict()
    #print(train_loader is val_loader)
    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #print(len(train_loader))
    #exit()

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.NLLLoss().cuda()
    elif args.loss_type == "MSELoss":
        criterion = torch.nn.MSELoss().cuda()
    elif args.loss_type == "BCELoss":
        #print("BCELoss")
        criterion = torch.nn.BCELoss().cuda()
    elif args.loss_type == "CrossEntropyLoss":
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    #print(os.path.join(args.root_log, args.store_name, 'args.txt'))
    #exit()
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
        #print("265")
        # train for one epoch
        ######
        #print(trainDataloader._getMode())
        #print(valDataloader._getMode())
        train(train_loader, model, decoder, criterion, optimizer, epoch,
              log_training, tf_writer, index2wordDict)
        ######
        #print("268")
        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, decoder, criterion, epoch,
                             log_training, tf_writer,
                             index2wordDict=index2wordDict)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            #print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            # Checkpoint backbone and decoder separately.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': decoder.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, filename="decoder")
        else:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, False)
            # NOTE(review): `is_best` is unbound here until the first epoch
            # that runs validation — this branch can raise NameError early on.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': decoder.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, filename="decoder")
        #break
    print("test pass")
def main():
    """Entry point for I3D training: build the model, data loaders and
    optimizer from CLI args, optionally resume/fine-tune, then train and
    periodically validate, checkpointing the best mAP."""
    global args, best_prec1
    args = parser.parse_args()
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)

    # Compose a unique experiment name out of the hyper-parameters.
    full_arch_name = args.arch
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'I3D', args.dataset, full_arch_name,
        'batch{}'.format(args.batch_size),
        'wd{}'.format(args.weight_decay),
        args.consensus_type,
        'segment%d' % args.num_segments,
        'e{}'.format(args.epochs),
        'dropout{}'.format(args.dropout),
        args.pretrain,
        'lr{}'.format(args.lr),
        '_warmup{}'.format(args.warmup)
    ])
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    else:
        step_str = [str(int(x)) for x in args.lr_steps]
        args.store_name += '_step' + '_'.join(step_str)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.spatial_dropout:
        sigmoid_layer_str = '_'.join(args.sigmoid_layer)
        args.store_name += '_spatial_drop3d_{}_group{}_layer{}'.format(
            args.sigmoid_thres, args.sigmoid_group, sigmoid_layer_str)
        # NOTE(review): the RandomSigmoid tag is assumed to apply only when
        # spatial dropout is enabled -- confirm against the original layout.
        if args.sigmoid_random:
            args.store_name += '_RandomSigmoid'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = i3d(num_class,
                args.num_segments,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                spatial_dropout=args.spatial_dropout,
                sigmoid_thres=args.sigmoid_thres,
                sigmoid_group=args.sigmoid_group,
                sigmoid_random=args.sigmoid_random,
                sigmoid_layer=args.sigmoid_layer,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    # Datasets with direction-sensitive labels must not be flipped.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        else True)

    model = torch.nn.DataParallel(model,
                                  device_ids=list(range(args.gpus))).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # BUGFIX: this message previously printed args.evaluate instead of
            # the checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        # Reconcile '.net' naming differences between checkpoint and model keys.
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            # RGB conv1 weights do not match the Flow input channel count.
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)
    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path,
                   args.train_list,
                   num_segments=args.num_segments,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       GroupScale((256, 340)),
                       train_augmentation,
                       Stack('3D'),
                       ToTorchFormatTensor(),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path,
                   args.val_list,
                   num_segments=args.num_segments,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack('3D'),
                       ToTorchFormatTensor(),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # define loss function (criterion) -- multi-label setting, so BCE-with-logits
    # is used even under the legacy 'nll' flag name.
    if args.loss_type == 'nll':
        criterion = torch.nn.BCEWithLogitsLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        # adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
        # train for one epoch; warm-up/lr schedule is handled inside train()
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer, args.warmup, args.lr_type, args.lr_steps)
        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_mAP_best', best_prec1, epoch)
            output_best = 'Best mAP: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
def main():
    """Entry point for lip-reading TSM: trains (with a 95/5 train/val split of
    the training data) or runs inference, depending on args.mode."""
    global args, best_prec1
    args = parser.parse_args()
    num_class = opts.num_class

    # Compose a unique experiment name out of the hyper-parameters.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type, 'segment%d' % args.num_segments,
        'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.non_local > 0:
        args.store_name += '_nl'
    print('storing name: ' + args.store_name)
    # check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_size=args.img_size,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=True,
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    model.apply(weights_init)

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            # NOTE(review): optimizer-state restore appears deliberately
            # disabled in the original -- confirm before re-enabling.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Label maps are stored as Python-literal dumps.  eval() is kept for
    # backward compatibility, but NOTE: it executes arbitrary code, so the
    # label files must be trusted (ast.literal_eval is the safe alternative).
    # BUGFIX: use context managers so the handles are closed even if eval()
    # raises (the original left files open on error).
    with open(opts.NUM_LABEL_R, 'r+') as fr_r:
        w2n = eval(fr_r.read())  # word -> numeric label
    with open(opts.NUM_LABEL, 'r+') as fr:
        n2w = eval(fr.read())  # numeric label -> word

    train_loader, val_loader, test_loader = None, None, None
    if args.mode != 'test':
        lip_dict, video_list = opts.file_deal(opts.TRAIN_DATA, w2n)
        train_num = int(len(video_list) * 0.95)  # 95/5 train/val split
        train_loader = torch.utils.data.DataLoader(
            TSNDataSet(opts.TRAIN_DATA, args.mode,
                       num_segments=args.num_segments,
                       img_size=args.img_size,
                       lip_dict=lip_dict,
                       video_list=video_list[:train_num],
                       transform=torchvision.transforms.Compose([
                           train_augmentation,
                           GroupMultiScaleCrop(args.img_size,
                                               [1, .875, .75, .66]),
                           Stack(roll=(args.arch in
                                       ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(
                               div=(args.arch not in
                                    ['BNInception', 'InceptionV3'])),
                           normalize,
                       ])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True)
        val_loader = torch.utils.data.DataLoader(
            TSNDataSet(opts.TRAIN_DATA, args.mode,
                       num_segments=args.num_segments,
                       img_size=args.img_size,
                       lip_dict=lip_dict,
                       video_list=video_list[train_num:],
                       transform=torchvision.transforms.Compose([
                           GroupScale(int(scale_size)),
                           GroupCenterCrop(crop_size),
                           Stack(roll=(args.arch in
                                       ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(
                               div=(args.arch not in
                                    ['BNInception', 'InceptionV3'])),
                           normalize,
                       ])),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)
    else:
        lip_dict, video_list = opts.file_deal(opts.TEST_DATA, w2n)
        test_loader = torch.utils.data.DataLoader(
            TSNDataSet_infer(opts.TEST_DATA,
                             num_segments=args.num_segments,
                             img_size=args.img_size,
                             lip_dict=lip_dict,
                             video_list=video_list,
                             transform=torchvision.transforms.Compose([
                                 GroupScale(int(scale_size)),
                                 GroupCenterCrop(crop_size),
                                 Stack(roll=(args.arch in
                                             ['BNInception', 'InceptionV3'])),
                                 ToTorchFormatTensor(
                                     div=(args.arch not in
                                          ['BNInception', 'InceptionV3'])),
                                 normalize,
                             ])),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.mode == 'test':
        if args.sub == 'sub':
            inference(test_loader, model, n2w)
        else:
            inferencefusion(test_loader, model, n2w)
        return

    # start training
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, n2w)
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
def v_train(train_loader, val_loader, model, num_class, vnet, criterion,
            valcriterion, optimizer, epoch, log, tf_writer):
    """Run one epoch of meta-weighted training.

    Per batch, a three-phase (meta-weight-net style) update is performed:
    a virtual copy of the model takes one weighted gradient step; the weight
    network (vnet) is updated on a validation batch through that virtual step;
    then the real model takes its weighted step using the updated vnet.
    NOTE(review): a fresh VNet copy, Adam optimizer and full v_TSN model are
    rebuilt on EVERY iteration -- presumably required by the meta-learning
    scheme, but very expensive; confirm this is intentional.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    val_loader_iter = iter(val_loader)
    if args.no_partialbn:
        model.module.partialBN(False)
    else:
        model.module.partialBN(True)
    # switch to train mode
    model.train()
    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        target = target.cuda()
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)
        # Fresh working copy of the weight net, synced from the running vnet.
        vnet_temp = VNet(1, 100, 1).cuda()
        optimizer_vnet_temp = torch.optim.Adam(vnet_temp.params(), 1e-3,
                                               weight_decay=1e-4)
        vnet_temp.load_state_dict(vnet.state_dict())
        # Virtual model: a clone of the main network used for the look-ahead
        # (inner) gradient step.
        v_model = v_TSN(
            num_class, args.num_segments, args.modality,
            base_model=args.arch,
            consensus_type=args.consensus_type,
            dropout=args.dropout,
            img_feature_dim=args.img_feature_dim,
            partial_bn=not args.no_partialbn,
            pretrain=args.pretrain,
            is_shift=args.shift,
            shift_div=args.shift_div,
            shift_place=args.shift_place,
            fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
            temporal_pool=args.temporal_pool,
            non_local=args.non_local,
            print_spec=False)
        v_model = torch.nn.DataParallel(v_model, device_ids=args.gpus).cuda()
        if args.temporal_pool and not args.resume:
            make_temporal_pool(v_model.module.base_model, args.num_segments)
        v_model.load_state_dict(model.state_dict())
        # compute output of the virtual model on the training batch
        output = v_model(input_var)
        # loss = criterion(output, target_var)
        cost = criterion(output, target_var)
        cost_v = torch.reshape(cost, (-1, 1))
        # Per-sample weights from the (detached) per-sample losses.
        v_lambda = vnet_temp(cost_v.data)
        l_f_v = torch.sum(cost_v * v_lambda) / len(cost_v)
        v_model.zero_grad()
        # create_graph=True keeps the inner step differentiable w.r.t. vnet_temp.
        grads = torch.autograd.grad(l_f_v, (v_model.module.params()),
                                    create_graph=True)
        # to be modified: hard-coded step schedule (decay at epochs 80 and 100)
        v_lr = args.lr * ((0.1**int(epoch >= 80)) * (0.1**int(epoch >= 100)))
        v_model.module.update_params(lr_inner=v_lr, source_params=grads)
        del grads
        # phase 2. pixel weights step
        try:
            # grab one validation batch
            inputs_val, targets_val = next(val_loader_iter)
        except StopIteration:
            # validation iterator exhausted: restart it
            val_loader_iter = iter(val_loader)
            inputs_val, targets_val = next(val_loader_iter)
        # inputs_val, targets_val = sample_val['image'], sample_val['label']
        inputs_val, targets_val = inputs_val.cuda(), targets_val.cuda()
        y_g_hat = v_model(inputs_val)
        l_g_meta = valcriterion(y_g_hat, targets_val)  # val loss
        optimizer_vnet_temp.zero_grad()
        l_g_meta.backward()
        optimizer_vnet_temp.step()
        # Copy the meta-updated weights back into the persistent vnet.
        vnet.load_state_dict(vnet_temp.state_dict())
        # phase 1. network weight step (w)
        output = model(input_var)
        cost = criterion(output, target)
        cost_v = torch.reshape(cost, (-1, 1))
        with torch.no_grad():
            # Weights are treated as constants for the real model's step.
            v_new = vnet(cost_v)
        loss = torch.sum(cost_v * v_new) / len(cost_v)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))
        # compute gradient and do SGD step
        # loss.backward()
        # NOTE(review): clipping happens AFTER optimizer.step() above, so it
        # has no effect on this update -- confirm whether this is intended.
        if args.clip_gradient is not None:
            total_norm = clip_grad_norm_(model.parameters(),
                                         args.clip_gradient)
        # optimizer.step()
        # optimizer.zero_grad()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch, i, len(train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          top1=top1,
                          top5=top5,
                          lr=optimizer.param_groups[-1]['lr'] * 0.1))  # TODO
            print(output, end=" ")
            for n, p in vnet.named_params(vnet):
                # only the first vnet parameter is reported, then we break out
                print("vnet param: ", n, p[0].item())
                break
            log.write(output + '\n')
            log.flush()
    tf_writer.add_scalar('loss/train', losses.avg, epoch)
    tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
    tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
    tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
def __init__(self, weightPath, segments, crops, fullSize=True):
    """Load a pretrained TSM/TSN network from a checkpoint and build the
    matching evaluation transform.

    Args:
        weightPath: checkpoint file; its name encodes modality, architecture
            and shift options (parsed by parse_shift_option_from_log_name).
        segments: number of temporal segments at test time.
        crops: number of spatial test crops (1, 3, 5 or 10).
        fullSize: if True, evaluate at scale_size instead of input_size.
    """
    self.weightPath = weightPath
    self.segments = segments
    self.crops = crops
    self.fullSize = fullSize
    # Shift options are encoded in the checkpoint file name.
    self.is_shift, shift_div, shift_place = parse_shift_option_from_log_name(
        self.weightPath)
    if 'RGB' in self.weightPath:
        self.modality = 'RGB'
    else:
        self.modality = 'Flow'
    this_arch = self.weightPath.split('TSM_')[1].split('_')[2]
    # 400-way classifier (presumably Kinetics-400 -- confirm).
    self.num_class = 400
    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(
        self.is_shift, shift_div, shift_place))
    self.net = TSN(self.num_class,
                   self.segments if self.is_shift else 1,
                   self.modality,
                   base_model=this_arch,
                   consensus_type="avg",
                   img_feature_dim=256,
                   pretrain="imagenet",
                   is_shift=self.is_shift,
                   shift_div=shift_div,
                   shift_place=shift_place,
                   non_local='_nl' in self.weightPath,
                   )
    if 'tpool' in self.weightPath:
        from ops.temporal_shift import make_temporal_pool
        make_temporal_pool(self.net.base_model, self.segments)
    # Checkpoint was saved from a DataParallel model: strip 'module.' prefix.
    checkpoint = torch.load(self.weightPath)
    checkpoint = checkpoint['state_dict']
    base_dict = {'.'.join(k.split('.')[1:]): v
                 for k, v in list(checkpoint.items())}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)
    self.net.load_state_dict(base_dict)
    input_size = self.net.scale_size if self.fullSize else self.net.input_size
    if self.crops == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(self.net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif self.crops == 3:  # do not flip, so only 3 full-resolution crops
        cropping = torchvision.transforms.Compose([
            GroupFullResSample(input_size, self.net.scale_size, flip=False)
        ])
    elif self.crops == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, self.net.scale_size, flip=False)
        ])
    elif self.crops == 10:
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, self.net.scale_size)
        ])
    else:
        # BUGFIX: the message omitted the supported 3-crop setting.
        raise ValueError(
            "Only 1, 3, 5, 10 crops are supported while we got {}".format(
                self.crops))
    self.transform = torchvision.transforms.Compose([
        cropping,
        Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(
            div=(this_arch not in ['BNInception', 'InceptionV3'])),
        GroupNormalize(self.net.input_mean, self.net.input_std)])
    # Number of input channels per frame group, by modality.
    if self.modality == 'RGB':
        self.length = 3
    elif self.modality == 'Flow':
        self.length = 10
    elif self.modality == 'RGBDiff':
        self.length = 18
    self.net = self.net.cuda()
    self.net.eval()
def main():
    """Entry point for multi-label MovieNet training (BCE-with-logits loss,
    mAP/wAP validation metrics)."""
    global args, best_prec1
    args = parser.parse_args()
    # num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
    #                                                                                                   args.modality)
    # Hard-coded MovieNet split; NOTE: these paths are machine-specific.
    num_class = 21
    args.train_list = "/home/jzwang/code/Video_3D/movienet/data/movie/movie_train.txt"
    args.val_list = "/home/jzwang/code/Video_3D/movienet/data/movie/movie_test.txt"
    args.root_path = ""
    prefix = "frame_{:04d}.jpg"

    # Compose a unique experiment name out of the hyper-parameters.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name,
         args.consensus_type, 'segment%d' % args.num_segments,
         'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)
    # check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # BUGFIX: this message previously printed args.evaluate instead of
            # the checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        # Reconcile '.net' naming differences between checkpoint and model keys.
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSetMovie("", args.train_list,
                        num_segments=args.num_segments,
                        new_length=data_length,
                        modality=args.modality,
                        image_tmpl="frame_{:04d}.jpg"
                        if args.modality in ["RGB", "RGBDiff"]
                        else args.flow_prefix + "{}_{:05d}.jpg",
                        transform=torchvision.transforms.Compose([
                            train_augmentation,
                            Stack(roll=args.arch == 'BNInception'),
                            ToTorchFormatTensor(
                                div=args.arch != 'BNInception'),
                            normalize,
                        ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        TSNDataSetMovie("", args.val_list,
                        num_segments=args.num_segments,
                        new_length=data_length,
                        modality=args.modality,
                        image_tmpl="frame_{:04d}.jpg"
                        if args.modality in ["RGB", "RGBDiff"]
                        else args.flow_prefix + "{}_{:05d}.jpg",
                        random_shift=False,
                        transform=torchvision.transforms.Compose([
                            GroupScale(int(scale_size)),
                            GroupCenterCrop(crop_size),
                            Stack(roll=args.arch == 'BNInception'),
                            ToTorchFormatTensor(
                                div=args.arch != 'BNInception'),
                            normalize,
                        ])),
        batch_size=int(args.batch_size / 2),
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # define loss function (criterion): multi-label -> BCE with logits
    criterion = torch.nn.BCEWithLogitsLoss().cuda()

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    # BUGFIX: a second SGD optimizer used to be constructed here with the same
    # settings, silently discarding the optimizer state restored by
    # optimizer.load_state_dict() in the resume branch above.

    zero_time = time.time()
    best_map = 0
    print('Start training...')
    for epoch in range(args.start_epoch, args.epochs):
        valloss, mAP, wAP, output_mtx = validate(val_loader, model, criterion)
        adjust_learning_rate(optimizer, epoch, args.lr_steps)
        np.save("testnew.npy", output_mtx)
        print("saving down")
        # train for one epoch
        start_time = time.time()
        trainloss = train(train_loader, model, criterion, optimizer, epoch)
        print('Traing loss %4f Epoch %d' % (trainloss, epoch))
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            valloss, mAP, wAP, output_mtx = validate(val_loader, model,
                                                     criterion)
        end_time = time.time()
        epoch_time = end_time - start_time
        total_time = end_time - zero_time
        print('Total time used: %s Epoch %d time uesd: %s' % (
            str(datetime.timedelta(seconds=int(total_time))), epoch,
            str(datetime.timedelta(seconds=int(epoch_time)))))
        print('Train loss: {0:.4f} val loss: {1:.4f} mAP: {2:.4f} wAP: {3:.4f}'.format(
            trainloss, valloss, mAP, wAP))
        # evaluate on validation set
        is_best = mAP > best_map
        # BUGFIX: best_map was never updated (the update was commented out),
        # so is_best always compared against 0 and every epoch with mAP > 0
        # overwrote the "best" checkpoint.
        best_map = max(mAP, best_map)
        # checkpoint_name = "%04d_%s" % (epoch+1, "checkpoint.pth.tar")
        checkpoint_name = "best_checkpoint.pth.tar"
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, is_best, epoch)
        np.save("testnew.npy", output_mtx)
        print("saving down")
        with open(args.record_path, 'a') as file:
            file.write('Epoch:[{0}]'
                       'Train loss: {1:.4f} val loss: {2:.4f} map: {3:.4f}\n'.format(
                           epoch + 1, trainloss, valloss, mAP))
    print('************ Done!... ************')
def main():
    """Build a TSN model on a single GPU, optionally restore a checkpoint, and
    run CAM extraction (cam_process) over the validation set."""
    global args, best_prec1
    args = parser.parse_args()
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                cca3d=args.cca3d)
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        else True)
    # Single-GPU here (no DataParallel): CAM extraction only.
    model = model.cuda()
    if args.resume:
        if args.temporal_pool:
            # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            # Checkpoint was saved from DataParallel: strip 'module.' prefix.
            parallel_state_dict = checkpoint['state_dict']
            cpu_state_dict = {}
            for k, v in parallel_state_dict.items():
                cpu_state_dict[k[len('module.'):]] = v
            model.load_state_dict(cpu_state_dict)
            # BUGFIX: this message previously printed args.evaluate instead of
            # the checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))
    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    preprocess = torchvision.transforms.Compose([
        GroupScale(int(scale_size)),
        GroupCenterCrop(crop_size),
        Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(
            div=(args.arch not in ['BNInception', 'InceptionV3'])),
        normalize,
    ])
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list,
                   num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=preprocess,
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False)
    norm_param = (input_mean, input_std)
    cam_process(val_loader, model, norm_param)
def load_model(weights):
    """Build a TSN network whose configuration is decoded from the checkpoint
    file name, load its weights, and return (is_shift, net, transform).

    Args:
        weights: path to a checkpoint; modality, architecture, shift, concat
            and extra-temporal-modeling options are all encoded in its name.
    """
    global num_class
    is_shift, shift_div, shift_place = parse_shift_option_from_log_name(
        weights)
    # Modality is encoded in the file name.
    if 'RGB' in weights:
        modality = 'RGB'
    elif 'Depth' in weights:
        modality = 'Depth'
    else:
        modality = 'Flow'
    if 'concatAll' in weights:
        concat = "All"
    elif "concatFirst" in weights:
        concat = "First"
    else:
        concat = ""
    # BUGFIX: this flag previously tested the undefined name `this_weights`
    # (NameError) and was only assigned inside the branch, which raised
    # UnboundLocalError below for checkpoints without 'extra' in the name.
    extra_temporal_modeling = 'extra' in weights
    args.prune = ""
    if 'conv1d' in weights:
        args.crop_fusion_type = "conv1d"
    else:
        args.crop_fusion_type = "avg"
    this_arch = weights.split('TSM_')[1].split('_')[2]
    modality_list.append(modality)
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(
        args.dataset, modality)
    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(
        is_shift, shift_div, shift_place))
    net = TSN(num_class,
              int(args.test_segments) if is_shift else 1,
              modality,
              base_model=this_arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              pretrain=args.pretrain,
              is_shift=is_shift,
              shift_div=shift_div,
              shift_place=shift_place,
              non_local='_nl' in weights,
              concat=concat,
              extra_temporal_modeling=extra_temporal_modeling,
              prune_list=[prune_conv1in_list, prune_conv1out_list],
              is_prune=args.prune)
    if 'tpool' in weights:
        from ops.temporal_shift import make_temporal_pool
        make_temporal_pool(net.base_model, args.test_segments)
    # Checkpoint was saved from a DataParallel model: strip 'module.' prefix.
    checkpoint = torch.load(weights)
    checkpoint = checkpoint['state_dict']
    # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v
    #              for k, v in list(checkpoint.items())}
    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in list(checkpoint.items())
    }
    replace_dict = {
        'base_model.classifier.weight': 'new_fc.weight',
        'base_model.classifier.bias': 'new_fc.bias',
    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)
    net.load_state_dict(base_dict)
    input_size = net.scale_size if args.full_res else net.input_size
    if args.test_crops == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif args.test_crops == 3:  # do not flip, so only 3 full-resolution crops
        cropping = torchvision.transforms.Compose(
            [GroupFullResSample(input_size, net.scale_size, flip=False)])
    elif args.test_crops == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose(
            [GroupOverSample(input_size, net.scale_size, flip=False)])
    elif args.test_crops == 10:
        cropping = torchvision.transforms.Compose(
            [GroupOverSample(input_size, net.scale_size)])
    else:
        # BUGFIX: the message omitted the supported 3-crop setting.
        raise ValueError(
            "Only 1, 3, 5, 10 crops are supported while we got {}".format(
                args.test_crops))
    transform = torchvision.transforms.Compose([
        cropping,
        Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(
            div=(this_arch not in ['BNInception', 'InceptionV3'])),
        GroupNormalize(net.input_mean, net.input_std),
    ])
    # NOTE(review): `devices` is computed but never used -- DataParallel below
    # runs on the default device set. Kept for parity with the original.
    if args.gpus is not None:
        devices = [args.gpus[i] for i in range(args.workers)]
    else:
        devices = list(range(args.workers))
    net = torch.nn.DataParallel(net.cuda())
    return is_shift, net, transform