def get_data_loaders(model, prefix, args):
    """Construct the training and validation data loaders.

    The training dataset receives a (flip, no-flip) pair of augmentation
    pipelines; the validation dataset receives the same scale/center-crop
    pipeline twice.

    Returns:
        (train_loader, val_loader)
    """
    # Stacking/scaling flags for BN-Inception-style backbones.
    bn_inception = "BNInc" in args.arch

    def build_pipeline(*head_transforms):
        # Shared tail: stack frames, tensor conversion, normalization.
        return torchvision.transforms.Compose(list(head_transforms) + [
            Stack(roll=bn_inception),
            ToTorchFormatTensor(div=not bn_inception),
            GroupNormalize(model.module.input_mean, model.module.input_std),
        ])

    flip_pipeline = build_pipeline(model.module.get_augmentation(flip=True))
    noflip_pipeline = build_pipeline(model.module.get_augmentation(flip=False))
    eval_pipeline = build_pipeline(
        GroupScale(int(model.module.scale_size)),
        GroupCenterCrop(model.module.crop_size),
    )

    train_dataset = TSNDataSet(
        args.root_path, args.train_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        transform=(flip_pipeline, noflip_pipeline),
        dense_sample=args.dense_sample,
        dataset=args.dataset,
        filelist_suffix=args.filelist_suffix,
        folder_suffix=args.folder_suffix,
        save_meta=args.save_meta,
        always_flip=args.always_flip,
        conditional_flip=args.conditional_flip,
        adaptive_flip=args.adaptive_flip)
    val_dataset = TSNDataSet(
        args.root_path, args.val_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        random_shift=False,
        transform=(eval_pipeline, eval_pipeline),
        dense_sample=args.dense_sample,
        dataset=args.dataset,
        filelist_suffix=args.filelist_suffix,
        folder_suffix=args.folder_suffix,
        save_meta=args.save_meta)

    train_loader = build_dataflow(train_dataset, True, args.batch_size,
                                  args.workers, args.not_pin_memory)
    val_loader = build_dataflow(val_dataset, False, args.batch_size,
                                args.workers, args.not_pin_memory)
    return train_loader, val_loader
def get_val_loader(model):
    """Build a loader over the Something-Something validation split.

    Paths and hyper-parameters are hard-coded. Despite iterating the
    validation list, sampling is training-style: shuffled, with the last
    partial batch dropped. The `model` argument is currently unused.
    """
    root_path = '/home/mbc2004/datasets/Something-Something/frames/'
    train_list = '/home/mbc2004/datasets/Something-Something/annotations/val_videofolder.txt'
    num_segments = 8
    modality = 'RGB'
    dense_sample = False
    batch_size = 8  # 64
    workers = 16
    arch = 'resnet50'
    prefix = '{:06d}.jpg'
    print('#' * 20, 'NO FLIP!!!')

    # resnet50 is not an Inception variant, so roll=False and div=True here.
    is_inception = arch in ['BNInception', 'InceptionV3']
    pipeline = torchvision.transforms.Compose([
        GroupMultiScaleCrop(224, [1, .875, .75, .66]),
        Stack(roll=is_inception),
        ToTorchFormatTensor(div=not is_inception),
        IdentityTransform(),
    ])

    dataset = TSNDataSet(root_path, train_list,
                         num_segments=num_segments,
                         new_length=1,
                         modality=modality,
                         image_tmpl=prefix,
                         transform=pipeline,
                         dense_sample=dense_sample)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU
# Fall back to the validation list when no explicit test file was given.
test_file = test_file if test_file is not None else val_list
# Evaluation loader: deterministic order (no shuffling); num_workers is
# deliberately left commented out (uses the DataLoader default).
data_loader = torch.utils.data.DataLoader(
    TSNDataSet(
        root_path, test_file, num_segments=this_test_segments,
        # 1 frame per segment for RGB-style modalities, 5 otherwise.
        new_length=1 if modality in ['RGB', 'RGB-flo', 'RGB-seg'] else 5,
        modality=modality, image_tmpl=prefix, test_mode=True,
        # Only prune missing videos when evaluating a single model.
        remove_missing=len(weights_list) == 1,
        transform=torchvision.transforms.Compose([
            cropping,
            # roll=True for BNInception/InceptionV3 backbones; masked
            # stacking is enabled for the RGB+flow / RGB+seg variants.
            Stack(roll=(this_arch in ['BNInception', 'InceptionV3']),
                  mask=(modality in ['RGB-flo', 'RGB-seg'])),
            # Same backbones skip the /255 scaling (div=False).
            ToTorchFormatTensor(
                div=(this_arch not in ['BNInception', 'InceptionV3'])),
            GroupNormalize(net.input_mean, net.input_std),
        ]),
        dense_sample=args.dense_sample, twice_sample=args.twice_sample,
        dense_window=args.dense_window, full_sample=args.full_sample,
        ipn=args.dataset == 'ipn', ipn_no_class=ipn_no_class),
    batch_size=args.batch_size,
    shuffle=False,
    # num_workers=args.workers,
    pin_memory=True,
)
def main():
    """Offline test entry point: build the model, optionally resume a
    checkpoint, assemble the test-time data pipeline and run `test`."""
    global args, best_prec1, TRAIN_SAMPLES
    args = parser.parse_args()
    num_class, args.train_list, args.val_list, args.test_list, args.root_path, prefix = \
        dataset_config.return_dataset(args.dataset, args.modality)
    # Prefer the held-out test split whenever it exists on disk.
    if os.path.exists(args.test_list):
        args.val_list = args.test_list
    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                tin=args.tin)
    # ImageNet statistics for input normalization.
    input_mean = [0.485, 0.456, 0.406]
    input_std = [0.229, 0.224, 0.225]
    print(args.gpus)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if os.path.isfile(args.resume_path):
        print(("=> loading checkpoint '{}'".format(args.resume_path)))
        checkpoint = torch.load(args.resume_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        # BUG FIX: message previously reported args.evaluate rather than
        # the checkpoint path that was actually loaded.
        print(("=> loaded checkpoint '{}' (epoch {})"
               .format(args.resume_path, checkpoint['epoch'])))
    else:
        print(("=> no checkpoint found at '{}'".format(args.resume_path)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    # Test-time cropping strategy, keyed on the number of random crops.
    if args.random_crops == 1:
        crop_aug = GroupCenterCrop(args.crop_size)
    elif args.random_crops == 3:
        crop_aug = GroupFullResSample(args.crop_size, args.scale_size, flip=False)
    elif args.random_crops == 5:
        crop_aug = GroupOverSample(args.crop_size, args.scale_size, flip=False)
    else:
        # BUG FIX: a stray trailing comma previously made crop_aug a 1-tuple,
        # which is not callable inside transforms.Compose.
        crop_aug = MultiGroupRandomCrop(args.crop_size, args.random_crops)

    test_dataset = TSNDataSet(
        args.root_path, args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        multi_class=args.multi_class,
        transform=torchvision.transforms.Compose([
            GroupScale(int(args.scale_size)),
            crop_aug,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample,
        test_mode=True,
        temporal_clips=args.temporal_clips)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    test(test_loader, model, args.start_epoch)
def main():
    """Train/evaluate an adaptive-resolution TSN (AR-Net style) model.

    Relies on module-level state: the globals declared below, plus
    `test_mode`, the parsed `args`, and helpers (train/validate/
    save_checkpoint/load_to_sd/...) defined elsewhere in this file.
    """
    t_start = time.time()
    global args, best_prec1, num_class, use_ada_framework  # , model

    # Experiment tracking.
    wandb.init(
        project="arnet-reproduce",
        name=args.exp_header,
        entity="video_channel"
    )
    wandb.config.update(args)

    set_random_seed(args.random_seed)
    # The adaptive framework is active only with resolution skipping on and
    # none of the offline/sampler baselines selected.
    use_ada_framework = args.ada_reso_skip and args.offline_lstm_last == False \
        and args.offline_lstm_all == False and args.real_scsampler == False

    if args.ablation:
        logger = None
    else:
        if not test_mode:
            logger = Logger()
            sys.stdout = logger  # tee stdout into the log file
        else:
            logger = None

    num_class, args.train_list, args.val_list, args.root_path, prefix = \
        dataset_config.return_dataset(args.dataset, args.data_dir)
    #===
    #args.val_list = args.train_list
    #===

    if args.ada_reso_skip:
        if len(args.ada_crop_list) == 0:
            args.ada_crop_list = [1 for _ in args.reso_list]
    if use_ada_framework:
        init_gflops_table()

    model = TSN_Ada(num_class, args.num_segments,
                    base_model=args.arch,
                    consensus_type=args.consensus_type,
                    dropout=args.dropout,
                    partial_bn=not args.no_partialbn,
                    pretrain=args.pretrain,
                    fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                    args=args)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    # Something-Something / Jester are direction-sensitive: no horizontal flip.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # TODO(yue) freeze some params in the policy + lstm layers
    if args.freeze_policy:
        for name, param in model.module.named_parameters():
            if "lite_fc" in name or "lite_backbone" in name or "rnn" in name or "linear" in name:
                param.requires_grad = False

    if args.freeze_backbone:
        for name, param in model.module.named_parameters():
            if "base_model" in name:
                param.requires_grad = False

    if len(args.frozen_list) > 0:
        # Freeze by keyword pattern: "*x*" substring, "*x" suffix,
        # "x*" prefix, otherwise exact name match.
        for name, param in model.module.named_parameters():
            for keyword in args.frozen_list:
                if keyword[0] == "*":
                    if keyword[-1] == "*":  # TODO middle
                        if keyword[1:-1] in name:
                            param.requires_grad = False
                            print(keyword, "->", name, "frozen")
                    else:  # TODO suffix
                        if name.endswith(keyword[1:]):
                            param.requires_grad = False
                            print(keyword, "->", name, "frozen")
                elif keyword[-1] == "*":  # TODO prefix
                    if name.startswith(keyword[:-1]):
                        param.requires_grad = False
                        print(keyword, "->", name, "frozen")
                else:  # TODO exact word
                    if name == keyword:
                        param.requires_grad = False
                        print(keyword, "->", name, "frozen")
        # BUG FIX: this requires_grad summary was accidentally printed
        # twice; print it once.
        print("=" * 80)
        for name, param in model.module.named_parameters():
            print(param.requires_grad, "\t", name)
        print("=" * 80)

    optimizer = torch.optim.SGD(policies, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # BUG FIX: report the path actually resumed from (was
            # args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile ".net"-wrapped parameter names in either direction.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    # TODO(yue) ada_model loading process
    if args.ada_reso_skip:
        if test_mode:
            print("Test mode load from pretrained model")
            the_model_path = args.test_from
            if ".pth.tar" not in the_model_path:
                the_model_path = ospj(the_model_path, "models", "ckpt.best.pth.tar")
            model_dict = model.state_dict()
            sd = load_to_sd(model_dict, the_model_path, "foo", "bar", -1, apple_to_apple=True)
            model_dict.update(sd)
            model.load_state_dict(model_dict)
        elif args.base_pretrained_from != "":
            print("Adaptively load from pretrained whole")
            model_dict = model.state_dict()
            sd = load_to_sd(model_dict, args.base_pretrained_from, "foo", "bar", -1, apple_to_apple=True)
            model_dict.update(sd)
            model.load_state_dict(model_dict)
        elif len(args.model_paths) != 0:
            print("Adaptively load from model_path_list")
            model_dict = model.state_dict()
            # TODO(yue) policy net
            sd = load_to_sd(model_dict, args.policy_path, "lite_backbone", "lite_fc",
                            args.reso_list[args.policy_input_offset])
            model_dict.update(sd)
            # TODO(yue) backbones
            for i, tmp_path in enumerate(args.model_paths):
                base_model_index = i
                new_i = i
                sd = load_to_sd(model_dict, tmp_path,
                                "base_model_list.%d" % base_model_index,
                                "new_fc_list.%d" % new_i, args.reso_list[i])
                model_dict.update(sd)
            model.load_state_dict(model_dict)
    else:
        if test_mode:
            the_model_path = args.test_from
            if ".pth.tar" not in the_model_path:
                the_model_path = ospj(the_model_path, "models", "ckpt.best.pth.tar")
            model_dict = model.state_dict()
            sd = load_to_sd(model_dict, the_model_path, "foo", "bar", -1, apple_to_apple=True)
            model_dict.update(sd)
            model.load_state_dict(model_dict)

    if args.ada_reso_skip == False and args.base_pretrained_from != "":
        print("Baseline: load from pretrained model")
        model_dict = model.state_dict()
        sd = load_to_sd(model_dict, args.base_pretrained_from, "base_model", "new_fc", 224)
        if args.ignore_new_fc_weight:
            print("@ IGNORE NEW FC WEIGHT !!!")
            del sd["module.new_fc.weight"]
            del sd["module.new_fc.bias"]
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)
    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list,
                   num_segments=args.num_segments,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=False),
                       ToTorchFormatTensor(div=True),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample,
                   dataset=args.dataset,
                   partial_fcvid_eval=args.partial_fcvid_eval,
                   partial_ratio=args.partial_ratio,
                   ada_reso_skip=args.ada_reso_skip,
                   reso_list=args.reso_list,
                   random_crop=args.random_crop,
                   center_crop=args.center_crop,
                   ada_crop_list=args.ada_crop_list,
                   rescale_to=args.rescale_to,
                   policy_input_offset=args.policy_input_offset,
                   save_meta=args.save_meta),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list,
                   num_segments=args.num_segments,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=False),
                       ToTorchFormatTensor(div=True),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample,
                   dataset=args.dataset,
                   partial_fcvid_eval=args.partial_fcvid_eval,
                   partial_ratio=args.partial_ratio,
                   ada_reso_skip=args.ada_reso_skip,
                   reso_list=args.reso_list,
                   random_crop=args.random_crop,
                   center_crop=args.center_crop,
                   ada_crop_list=args.ada_crop_list,
                   rescale_to=args.rescale_to,
                   policy_input_offset=args.policy_input_offset,
                   save_meta=args.save_meta),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    if not test_mode:
        exp_full_path = setup_log_directory(logger, args.log_dir, args.exp_header)
    else:
        exp_full_path = None

    if not args.ablation:
        if not test_mode:
            with open(os.path.join(exp_full_path, 'args.txt'), 'w') as f:
                f.write(str(args))
            tf_writer = SummaryWriter(log_dir=exp_full_path)
        else:
            tf_writer = None
    else:
        tf_writer = None

    # TODO(yue)
    map_record = Recorder()
    mmap_record = Recorder()
    prec_record = Recorder()
    best_train_usage_str = None
    best_val_usage_str = None
    wandb.watch(model)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        if not args.skip_training:
            set_random_seed(args.random_seed + epoch)
            adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
            train_usage_str = train(train_loader, model, criterion, optimizer,
                                    epoch, logger, exp_full_path, tf_writer)
        else:
            train_usage_str = "No training usage stats (Eval Mode)"

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            set_random_seed(args.random_seed)
            mAP, mmAP, prec1, val_usage_str, val_gflops = validate(
                val_loader, model, criterion, epoch, logger, exp_full_path, tf_writer)

            # remember best prec@1 and save checkpoint
            map_record.update(mAP)
            mmap_record.update(mmAP)
            prec_record.update(prec1)

            if mmap_record.is_current_best():
                best_train_usage_str = train_usage_str
                best_val_usage_str = val_usage_str

            print('Best mAP: %.3f (epoch=%d)\t\tBest mmAP: %.3f(epoch=%d)\t\tBest Prec@1: %.3f (epoch=%d)' % (
                map_record.best_val, map_record.best_at,
                mmap_record.best_val, mmap_record.best_at,
                prec_record.best_val, prec_record.best_at))

            if args.skip_training:
                break

            if (not args.ablation) and (not test_mode):
                tf_writer.add_scalar('acc/test_top1_best', prec_record.best_val, epoch)
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': prec_record.best_val,
                }, mmap_record.is_current_best(), exp_full_path)

    if use_ada_framework and not test_mode:
        print("Best train usage:")
        print(best_train_usage_str)
        print()
        print("Best val usage:")
        print(best_val_usage_str)

    print("Finished in %.4f seconds\n" % (time.time() - t_start))
def main():
    """Train/evaluate a TSN model with dataset-specific frame templates.

    Uses module-level `args`, `parser`, `best_prec1` plus helpers defined
    elsewhere in this file (train/validate/save_checkpoint/
    save_validation_score/adjust_learning_rate).
    """
    global args, best_prec1
    args = parser.parse_args()

    print("------------------------------------")
    print("Environment Versions:")
    print("- Python: {}".format(sys.version))
    print("- PyTorch: {}".format(torch.__version__))
    print("- TorchVison: {}".format(torchvision.__version__))
    args_dict = args.__dict__
    print("------------------------------------")
    print(args.arch + " Configurations:")
    for key in args_dict.keys():
        print("- {}: {}".format(key, args_dict[key]))
    print("------------------------------------")
    print(args.mode)

    # Per-dataset class count and frame-filename template.
    if args.dataset == 'ucf101':
        num_class = 101
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'hmdb51':
        num_class = 51
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'kinetics':
        num_class = 400
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'something':
        num_class = 174
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'somethingv2':
        num_class = 174
        rgb_read_format = "img_{:05d}.jpg"
    elif args.dataset == 'NTU_RGBD':
        num_class = 120
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'tinykinetics':
        num_class = 150
        rgb_read_format = "{:05d}.jpg"
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std

    # Optimizers also support per-parameter options: pass an iterable of
    # dicts, each defining a parameter group with its own options.
    policies = model.get_optim_policies(args.dataset)
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # Per-backbone frame stacking / tensor scaling flags.
    # BUG FIX: an unrecognized arch previously left div/roll unbound and
    # crashed later with NameError — fail fast instead. (Dead state-dict
    # filtering code that was never used has been removed.)
    if args.arch == "resnet50":
        div = False
        roll = True
    elif args.arch == "resnet34":
        div = False
        roll = True
    elif args.arch[:3] == "TCM":
        div = True
        roll = False
    else:
        raise ValueError('Unsupported architecture ' + args.arch)

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        # NOTE(review): other scripts in this codebase use new_length=5 for
        # Flow/RGBDiff; here both branches yield 1 — confirm intended.
        data_length = 1

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "", args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            mode=args.mode,
            image_tmpl=args.rgb_prefix + rgb_read_format
            if args.modality in ["RGB", "RGBDiff"]
            else args.flow_prefix + rgb_read_format,
            img_start_idx=args.img_start_idx,
            transform=torchvision.transforms.Compose([
                GroupScale((240, 320)),
                # GroupScale(int(scale_size)),
                train_augmentation,
                Stack(roll=roll),
                ToTorchFormatTensor(div=div),
                normalize,
            ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "", args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            mode=args.mode,
            image_tmpl=args.rgb_prefix + rgb_read_format
            if args.modality in ["RGB", "RGBDiff"]
            else args.flow_prefix + rgb_read_format,
            img_start_idx=args.img_start_idx,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale((240, 320)),
                # GroupScale((224)),
                # GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=roll),
                ToTorchFormatTensor(div=div),
                normalize,
            ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']),
            group['lr_mult'], group['decay_mult'])))

    optimizer = torch.optim.SGD(policies, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)

    output_list = []
    if args.evaluate:
        prec1, score_tensor = validate(val_loader, model, criterion, temperature=100)
        output_list.append(score_tensor)
        save_validation_score(output_list, filename='score.pt')
        print("validation score saved in {}".format('/'.join(
            (args.val_output_folder, 'score_inf5.pt'))))
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        temperature = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1, score_tensor = validate(val_loader, model, criterion,
                                           temperature=temperature)
            output_list.append(score_tensor)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)

    # save validation score
    save_validation_score(output_list)
    print("validation score saved in {}".format('/'.join(
        (args.val_output_folder, 'score.pt'))))
# Reached when args.test_crops was not one of the supported values above.
raise ValueError(
    "Only 1, 5, 10 crops are supported while we got {}".format(
        args.test_crops))
data_loader = torch.utils.data.DataLoader(
    TSNDataSet(
        root_path, [test_file] if test_file is not None else val_list,
        num_segments=this_test_segments,
        # NOTE(review): new_length/modality are wrapped in single-element
        # lists here (multi-modality style API), unlike the other loaders
        # in this file — confirm TSNDataSet variant expects lists.
        new_length=[new_length],
        modality=[modality],
        image_tmpl=prefix, test_mode=True,
        # Only prune missing videos when evaluating a single model.
        remove_missing=len(weights_list) == 1,
        transform=torchvision.transforms.Compose([
            cropping,
            # roll=True / div=False for BNInception/InceptionV3 backbones.
            Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(this_arch not in ['BNInception', 'InceptionV3'])),
            # RGBDiff skips mean/std normalization.
            GroupNormalize(net.input_mean, net.input_std)
            if modality != 'RGBDiff' else IdentityTransform(),
        ]),
        dense_sample=args.dense_sample, dense_length=args.dense_length,
        dense_number=args.dense_number, twice_sample=args.twice_sample,
        random_sample=args.random_sample),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True,
)
def main():
    """TSM testing on the full validation set, with optional multi-model
    ensembling (comma-separated --weights / --test_segments / --coeff)."""
    # options
    parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
    parser.add_argument('dataset', type=str)

    # may contain splits
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--test_segments', type=str, default=25)
    parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
    parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
    parser.add_argument('--full_res', default=False, action="store_true",
                        help='use full resolution 256x256 for test as in Non-local I3D')
    parser.add_argument('--test_crops', type=int, default=1)
    parser.add_argument('--coeff', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')

    # for true test
    parser.add_argument('--test_list', type=str, default=None)
    parser.add_argument('--csv_file', type=str, default=None)
    parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')
    parser.add_argument('--max_num', type=int, default=-1)
    parser.add_argument('--input_size', type=int, default=224)
    parser.add_argument('--crop_fusion_type', type=str, default='avg')
    parser.add_argument('--gpus', nargs='+', type=int, default=None)
    parser.add_argument('--img_feature_dim', type=int, default=256)
    parser.add_argument('--num_set_segments', type=int, default=1,
                        help='TODO: select multiply set of n-frames from a video')
    parser.add_argument('--pretrain', type=str, default='imagenet')
    args = parser.parse_args()

    def accuracy(output, target, topk=(1,)):
        """Computes the precision@k for the specified values of k"""
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # BUG FIX: use reshape, not view — correct[:k] is a slice of a
            # transposed (non-contiguous) tensor, so view(-1) raises in
            # modern PyTorch.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    def parse_shift_option_from_log_name(log_name):
        # Recover (is_shift, shift_div, shift_place) from a log/weights name
        # such as "..._shift8_blockres_...".
        if 'shift' in log_name:
            strings = log_name.split('_')
            for i, s in enumerate(strings):
                if 'shift' in s:
                    break
            return True, int(strings[i].replace('shift', '')), strings[i + 1]
        else:
            return False, None, None

    weights_list = args.weights.split(',')
    test_segments_list = [int(s) for s in args.test_segments.split(',')]
    assert len(weights_list) == len(test_segments_list)
    if args.coeff is None:
        coeff_list = [1] * len(weights_list)
    else:
        coeff_list = [float(c) for c in args.coeff.split(',')]

    if args.test_list is not None:
        test_file_list = args.test_list.split(',')
    else:
        test_file_list = [None] * len(weights_list)

    data_iter_list = []
    net_list = []
    modality_list = []
    total_num = None
    for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
        is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
        if 'RGB' in this_weights:
            modality = 'RGB'
        else:
            modality = 'Flow'
        this_arch = this_weights.split('TSM_')[1].split('_')[2]
        modality_list.append(modality)
        num_class, args.train_list, val_list, root_path, prefix = \
            dataset_config.return_dataset(args.dataset, modality)
        print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
        net = TSN(num_class, this_test_segments if is_shift else 1, modality,
                  base_model=this_arch,
                  consensus_type=args.crop_fusion_type,
                  img_feature_dim=args.img_feature_dim,
                  pretrain=args.pretrain,
                  is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
                  non_local='_nl' in this_weights,
                  )

        if 'tpool' in this_weights:
            from ops.temporal_shift import make_temporal_pool
            make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel

        checkpoint = torch.load(this_weights)
        checkpoint = checkpoint['state_dict']

        # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
        # Strip the leading "module." added by DataParallel at save time.
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias',
                        }
        for k, v in replace_dict.items():
            if k in base_dict:
                base_dict[v] = base_dict.pop(k)
        net.load_state_dict(base_dict)

        input_size = net.scale_size if args.full_res else net.input_size
        if args.test_crops == 1:
            cropping = torchvision.transforms.Compose([
                GroupScale(net.scale_size),
                GroupCenterCrop(input_size),
            ])
        elif args.test_crops == 3:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupFullResSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 5:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 10:
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size)
            ])
        else:
            raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops))

        data_loader = torch.utils.data.DataLoader(
            TSNDataSet(root_path, test_file if test_file is not None else val_list,
                       num_segments=this_test_segments,
                       new_length=1 if modality == "RGB" else 5,
                       modality=modality, image_tmpl=prefix,
                       test_mode=True, remove_missing=len(weights_list) == 1,
                       transform=torchvision.transforms.Compose([
                           cropping,
                           Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                           GroupNormalize(net.input_mean, net.input_std),
                       ]),
                       dense_sample=args.dense_sample, twice_sample=args.twice_sample),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True,
        )

        # BUG FIX (dead code removed): an unused `devices` list was built
        # from args.gpus indexed by range(args.workers), which could
        # IndexError and was never passed to DataParallel anyway.
        net = torch.nn.DataParallel(net.cuda())
        net.eval()

        data_gen = enumerate(data_loader)
        if total_num is None:
            total_num = len(data_loader.dataset)
        else:
            assert total_num == len(data_loader.dataset)

        data_iter_list.append(data_gen)
        net_list.append(net)

    output = []

    def eval_video(video_data, net, this_test_segments, modality):
        # NOTE(review): `is_shift` is captured from the enclosing loop and
        # therefore reflects the LAST loaded model — confirm this is
        # intended when ensembling shift and non-shift models.
        net.eval()
        with torch.no_grad():
            i, data, label = video_data
            batch_size = label.numel()
            num_crop = args.test_crops
            if args.dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample
            if args.twice_sample:
                num_crop *= 2
            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality " + modality)
            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                data_in = data_in.view(batch_size * num_crop, this_test_segments,
                                       length, data_in.size(2), data_in.size(3))
            rst = net(data_in)
            # Average over crops/clips.
            rst = rst.reshape(batch_size, num_crop, -1).mean(1)
            if args.softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)
            rst = rst.data.cpu().numpy().copy()
            if net.module.is_shift:
                rst = rst.reshape(batch_size, num_class)
            else:
                rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))
            return i, rst, label

    proc_start_time = time.time()
    max_num = args.max_num if args.max_num > 0 else total_num
    top1 = AverageMeter()
    top5 = AverageMeter()
    for i, data_label_pairs in enumerate(zip(*data_iter_list)):
        with torch.no_grad():
            if i >= max_num:
                break
            this_rst_list = []
            this_label = None
            for n_seg, (_, (data, label)), net, modality in zip(
                    test_segments_list, data_label_pairs, net_list, modality_list):
                rst = eval_video((i, data, label), net, n_seg, modality)
                this_rst_list.append(rst[1])
                this_label = label
            assert len(this_rst_list) == len(coeff_list)
            # Weighted ensemble of the per-model scores.
            for i_coeff in range(len(this_rst_list)):
                this_rst_list[i_coeff] *= coeff_list[i_coeff]
            ensembled_predict = sum(this_rst_list) / len(this_rst_list)

            for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
                output.append([p[None, ...], g])
            cnt_time = time.time() - proc_start_time
            prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
            top1.update(prec1.item(), this_label.numel())
            top5.update(prec5.item(), this_label.numel())
            if i % 20 == 0:
                print('video {} done, total {}/{}, average {:.3f} sec/video, '
                      'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(
                          i * args.batch_size, i * args.batch_size, total_num,
                          float(cnt_time) / (i + 1) / args.batch_size, top1.avg, top5.avg))

    video_pred = [np.argmax(x[0]) for x in output]
    video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]
    video_labels = [x[1] for x in output]

    if args.csv_file is not None:
        print('=> Writing result to csv file: {}'.format(args.csv_file))
        with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f:
            categories = f.readlines()
        categories = [f.strip() for f in categories]
        with open(test_file_list[0]) as f:
            vid_names = f.readlines()
        vid_names = [n.split(' ')[0] for n in vid_names]
        print(vid_names)
        assert len(vid_names) == len(video_pred)
        if args.dataset != 'somethingv2':  # only output top1
            with open(args.csv_file, 'w') as f:
                for n, pred in zip(vid_names, video_pred):
                    f.write('{};{}\n'.format(n, categories[pred]))
        else:
            with open(args.csv_file, 'w') as f:
                for n, pred5 in zip(vid_names, video_pred_top5):
                    fill = [n]
                    for p in list(pred5):
                        fill.append(p)
                    f.write('{};{};{};{};{};{}\n'.format(*fill))

    cf = confusion_matrix(video_labels, video_pred).astype(float)
    np.save('cm.npy', cf)
    cls_cnt = cf.sum(axis=1)
    cls_hit = np.diag(cf)
    cls_acc = cls_hit / cls_cnt
    print(cls_acc)
    upper = np.mean(np.max(cf, axis=1) / cls_cnt)
    print('upper bound: {}'.format(upper))

    print('-----Evaluation is finished------')
    print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
    print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))
def main():
    """CAM-visualization entry point.

    Parses CLI args, builds a TSN model, optionally restores weights from a
    DataParallel checkpoint (stripping the ``module.`` prefix), builds a
    center-crop validation loader, and hands everything to ``cam_process``.

    Side effects: mutates the module-level ``args`` and reads ``parser``,
    ``dataset_config`` and the transform helpers from module scope.
    """
    global args, best_prec1
    args = parser.parse_args()
    num_class, args.train_list, args.val_list, args.root_path, prefix = \
        dataset_config.return_dataset(args.dataset, args.modality)
    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                cca3d=args.cca3d)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    # Datasets like Something-Something / Jester are direction-sensitive,
    # so horizontal flip must be disabled for them.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = model.cuda()

    if args.resume:
        if args.temporal_pool:
            # Apply early temporal pooling so the checkpoint's state_dict
            # matches the module layout.
            # FIX: the model is NOT wrapped in DataParallel in this function
            # (see the 'module.' prefix stripping below), so there is no
            # ``.module`` attribute — access base_model directly.
            make_temporal_pool(model.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            # The checkpoint was saved from a DataParallel model; strip the
            # 'module.' prefix so the keys match this unwrapped model.
            parallel_state_dict = checkpoint['state_dict']
            cpu_state_dict = {}
            for k, v in parallel_state_dict.items():
                cpu_state_dict[k[len('module.'):]] = v
            model.load_state_dict(cpu_state_dict)
            # FIX: report the checkpoint actually loaded (was args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # FIX: previously fell through and raised a confusing NameError on
        # data_length below; fail fast with a clear message instead.
        raise ValueError("Unknown modality: {}".format(args.modality))

    preprocess = torchvision.transforms.Compose([
        GroupScale(int(scale_size)),
        GroupCenterCrop(crop_size),
        Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
        normalize,
    ])
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list,
                   num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=preprocess,
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)

    norm_param = (input_mean, input_std)
    cam_process(val_loader, model, norm_param)
def main_worker(gpu, ngpus_per_node, args):
    """Per-process training worker (one process per GPU under DDP).

    Builds the TSN model, wraps it for (distributed) data parallelism,
    optionally resumes from a checkpoint, constructs the data loaders and
    runs the train/validate/checkpoint loop.

    Args:
        gpu: local GPU index for this process (or None for DataParallel mode).
        ngpus_per_node: number of GPUs on this node; used to derive the
            global rank and to split batch size / workers per process.
        args: parsed CLI namespace; several fields are mutated in place
            (gpu, rank, store_name, batch_size, workers, train/val lists).
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            rank = args.rank * ngpus_per_node + gpu
        # NOTE(review): if args.distributed is set without
        # args.multiprocessing_distributed, `rank` is unbound here — confirm
        # that combination is never used with this script.
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=rank)
    else:
        rank = 0
    # create model
    num_class, args.train_list, args.val_list, args.root_path, prefix = \
        dataset_config.return_dataset(args.dataset, args.modality)

    # Build a descriptive run name that encodes the main hyper-parameters;
    # it is used as the checkpoint / log directory name.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type,
         'segment%d' % args.num_segments, 'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    args.store_name += '_lr{}'.format(args.lr)
    args.store_name += '_wd{:.1e}'.format(args.weight_decay)
    args.store_name += '_do{}'.format(args.dropout)
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)
    check_rootfolders(args, rank)

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)
    # first synchronization of initial weights
    # sync_initial_weights(model, args.rank, args.world_size)

    # Convert BatchNorm layers to SyncBatchNorm so statistics are shared
    # across DDP processes.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if rank == 0:
        print(model)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    # Per-layer-group learning-rate/weight-decay policy (BN, bias, fc, ...).
    policies = model.get_optim_policies()
    # Direction-sensitive datasets must not be horizontally flipped.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have on a node
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            # Only auto-advance the start epoch when the user did not pass an
            # explicit --start_epoch (default 1).
            if args.start_epoch == 1:
                args.start_epoch = checkpoint['epoch'] + 1
            best_acc1 = checkpoint['best_acc1']
            # if args.gpu is not None:
            #     # best_acc1 may be from a checkpoint from a different GPU
            #     best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # NOTE(review): data_length is unbound for modalities other than
    # RGB/Flow/RGBDiff — confirm no other modality reaches this script.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_dataset = TSNDataSet(
        args.dataset, args.root_path, args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            # BNInception / InceptionV3 expect BGR input in 0-255, hence
            # channel roll and no division.
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize]),
        dense_sample=args.dense_sample)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, drop_last=True,
        sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.dataset, args.root_path, args.val_list,
                   num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    log_training = open(os.path.join(args.root_model, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_model, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))

    tf_writer = SummaryWriter(log_dir=os.path.join(args.root_model, args.store_name))

    for epoch in range(args.start_epoch, args.epochs+1):
        if args.distributed:
            # Reseed the sampler so each epoch sees a different shuffle.
            train_sampler.set_epoch(epoch)
        # adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer, args, rank)

        # Only the first process per node writes checkpoints.
        if rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, False, args, rank)
        # Additionally keep a snapshot every 5 epochs.
        if epoch % 5 == 0 and rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, False, args, rank, e=epoch)

        # evaluate on validation set
        is_best = False
        if epoch % args.eval_freq == 0 or epoch == args.epochs:
            acc1 = validate(val_loader, model, criterion, epoch, args, rank,
                            log_training, tf_writer)

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                        and rank % ngpus_per_node == 0):
                save_checkpoint({
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer' : optimizer.state_dict(),
                }, is_best, args, rank)
[GroupOverSample(input_size, net.scale_size)]) else: raise ValueError( "Only 1, 5, 10 crops are supported while we got {}".format( args.test_crops)) data_loader = torch.utils.data.DataLoader( TSNDataSet( val_list, num_segments=args.test_segments, new_length=1 if args.modality == "RGB" else 5, modality=args.modality, image_tmpl=prefix, test_mode=True, remove_missing=True, multi_clip_test=args.multi_clip_test, transform=torchvision.transforms.Compose([ cropping, Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])), ToTorchFormatTensor( div=(args.arch not in ['BNInception', 'InceptionV3'])), GroupNormalize(net.input_mean, net.input_std), ]), dense_sample=args.dense_sample, ), batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, ) if args.gpus is not None:
def main():
    """TSA training/evaluation entry point.

    Parses CLI args, configures file logging, builds a TSA-enabled TSN model,
    sets up the optimizer (optionally with a MultiStepLR scheduler), handles
    resume / fine-tune weight loading, and runs the train/validate loop,
    checkpointing on best (lowest) validation loss.

    Side effects: mutates module-level ``args`` / ``least_loss`` and writes
    logs and checkpoints under ``args.root_log``.
    """
    global args, best_prec1, least_loss
    least_loss = 1000
    args = parser.parse_args()

    # Start each run with a fresh error log (basicConfig appends).
    if os.path.exists(os.path.join(args.root_log, "error.log")):
        os.remove(os.path.join(args.root_log, "error.log"))
    logging.basicConfig(
        level=logging.DEBUG,
        filename=os.path.join(args.root_log, "error.log"),
        filemode='a',
        format=
        '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    # log_handler = open(os.path.join(args.root_log,"error.log"),"w")
    # sys.stdout = log_handler

    if args.root_path:
        num_class, args.train_list, args.val_list, _, prefix = dataset_config.return_dataset(
            args.dataset, args.modality)
        # NOTE(review): overrides train_list and sets test_list (not val_list)
        # with hard-coded KF1 annotation files — confirm this is intended.
        args.train_list = os.path.join(args.root_log, "kf1_train_anno_lijun_iod.json")
        args.test_list = os.path.join(args.root_log, "kf1_test_anno_lijun_iod.json")
    else:
        num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
            args.dataset, args.modality)

    # Build a descriptive run name used as the log/checkpoint directory.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSA', args.dataset, args.modality, full_arch_name, args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    # if args.pretrain != 'imagenet':
    #     args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                is_TSA=args.tsa,
                is_sTSA=args.stsa,
                is_tTSA=args.ttsa,
                shift_diff=args.shift_diff,
                shift_groups=args.shift_groups,
                is_ME=args.me,
                is_3D=args.is_3D,
                cfg_file=args.cfg_file)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    # FIX: this line was commented out, but `policies` is used below by the
    # SGD branch without a scheduler, causing a NameError. It must be taken
    # from the raw model, before the DataParallel wrap.
    policies = model.get_optim_policies()
    # Direction-sensitive datasets must not be horizontally flipped.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.optimizer == "sgd":
        if args.lr_scheduler:
            optimizer = torch.optim.SGD(model.parameters(),
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        else:
            # Per-layer-group lr/weight-decay policy.
            optimizer = torch.optim.SGD(policies,
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.optimizer == "adam":
        params = get_vmz_fine_tuning_parameters(model, args.vmz_tune_last_k_layer)
        optimizer = torch.optim.Adam(params,
                                     args.lr,
                                     weight_decay=args.weight_decay)
    else:
        raise RuntimeError("not supported optimizer")

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.lr_steps, args.lr_scheduler_gamma)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # if args.lr_scheduler:
            #     scheduler.load_state_dict(checkpoint["lr_scheduler"])
            # FIX: report the checkpoint actually loaded (was args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
            logging.info(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))
            logging.error(
                ("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        # Reconcile '.net' naming differences between checkpoint and model.
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        sd = {k: v for k, v in sd.items() if k in keys2}
        if args.dataset not in args.tune_from:  # new dataset
            # Classifier head does not transfer across datasets.
            print('=> New dataset, do not load fc weights')
            sd = {
                k: v
                for k, v in sd.items()
                if 'fc' not in k and "projection" not in k
            }
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            # RGB conv1 weights are incompatible with Flow input channels.
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # NOTE(review): data_length is unbound for modalities outside these two
    # groups — confirm no other modality reaches this script.
    if args.modality in ['RGB', "PoseAction"]:
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    if not args.shuffle:
        # Static file lists: build the loaders once up front.
        train_loader = torch.utils.data.DataLoader(
            TSNDataSet(args.root_path, args.train_list,
                       num_segments=args.num_segments,
                       new_length=data_length,
                       modality=args.modality,
                       image_tmpl=prefix,
                       transform=torchvision.transforms.Compose([
                           train_augmentation,
                           Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                                 inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                           ToTorchFormatTensor(
                               div=(args.arch not in ['BNInception', 'InceptionV3']),
                               inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                           normalize,
                       ]),
                       dense_sample=args.dense_sample,
                       all_sample=args.all_sample),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)  # prevent something not % n_GPU

        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path, args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(scale_size),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3']),
                    inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                normalize,
            ]),
            dense_sample=args.dense_sample,
            all_sample=args.all_sample),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    # for group in policies:
    #     print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
    #         group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

    if args.evaluate:
        # Evaluation-only path: build criterion + analysis loader and exit.
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        elif args.loss_type == "bce":
            criterion = torch.nn.BCEWithLogitsLoss().cuda()
        elif args.loss_type == "wbce":
            class_weight, pos_weight = prep_weight(args.train_list)
            criterion = WeightedBCEWithLogitsLoss(class_weight, pos_weight)
        else:
            raise ValueError("Unknown loss type")
        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path, args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(scale_size),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3']),
                    inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                normalize,
            ]),
            dense_sample=args.dense_sample,
            all_sample=args.all_sample,
            analysis=True),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)
        test(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    print(model)
    logging.info(model)

    for epoch in range(args.start_epoch, args.epochs):
        logging.info("Train Epoch {}/{} starts, estimated time 5832s".format(
            str(epoch), str(args.epochs)))

        # update data_loader
        if args.shuffle:
            # Regenerate the proposal labels and rebuild loaders each epoch.
            gen_label(args.prop_path, args.label_path, args.trn_name,
                      args.train_list, args.neg_rate, STR=False)
            gen_label(args.prop_path, args.label_path, args.tst_name,
                      args.val_list, args.test_rate, STR=False)
            train_loader = torch.utils.data.DataLoader(
                TSNDataSet(args.root_path, args.train_list,
                           num_segments=args.num_segments,
                           new_length=data_length,
                           modality=args.modality,
                           image_tmpl=prefix,
                           transform=torchvision.transforms.Compose([
                               train_augmentation,
                               Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                                     inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               ToTorchFormatTensor(
                                   div=(args.arch not in ['BNInception', 'InceptionV3']),
                                   inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               normalize,
                           ]),
                           dense_sample=args.dense_sample,
                           all_sample=args.all_sample),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
            val_loader = torch.utils.data.DataLoader(
                TSNDataSet(args.root_path, args.val_list,
                           num_segments=args.num_segments,
                           new_length=data_length,
                           modality=args.modality,
                           image_tmpl=prefix,
                           random_shift=False,
                           transform=torchvision.transforms.Compose([
                               GroupScale(scale_size),
                               GroupCenterCrop(crop_size),
                               Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                                     inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               ToTorchFormatTensor(
                                   div=(args.arch not in ['BNInception', 'InceptionV3']),
                                   inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               normalize,
                           ]),
                           dense_sample=args.dense_sample,
                           all_sample=args.all_sample),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
        print(train_loader)

        # define loss function (criterion) and optimizer
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        elif args.loss_type == "bce":
            criterion = torch.nn.BCEWithLogitsLoss().cuda()
        elif args.loss_type == "wbce":
            class_weight, pos_weight = prep_weight(args.train_list)
            criterion = WeightedBCEWithLogitsLoss(class_weight, pos_weight)
        else:
            raise ValueError("Unknown loss type")

        # train for one epoch
        if not args.lr_scheduler:
            adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)
        else:
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)
            scheduler.step()

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            logging.info(
                "Test Epoch {}/{} starts, estimated time 13874s".format(
                    str(epoch // args.eval_freq),
                    str(args.epochs / args.eval_freq)))
            if args.loss_type == "wbce":
                # class_weight,pos_weight = prep_weight(args.val_list)
                # Validation uses the unweighted BCE for comparability.
                criterion = torch.nn.BCEWithLogitsLoss().cuda()
            lossm = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            # "Best" here means lowest validation loss.
            is_best = lossm < least_loss
            least_loss = min(lossm, least_loss)
            tf_writer.add_scalar('lss/test_top1_best', least_loss, epoch)

            output_best = 'Best Loss: %.3f\n' % (lossm)
            logging.info(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            if args.lr_scheduler:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': least_loss,
                        'lr_scheduler': scheduler,
                    }, is_best, epoch)
            else:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': least_loss,
                    }, is_best, epoch)
def main():
    """I3D training/evaluation entry point.

    Parses CLI args, builds the i3d model wrapped in DataParallel, handles
    resume / fine-tune weight loading, builds the 3D-stack data loaders,
    and runs the train/validate loop, tracking the best validation mAP.
    """
    global args, best_prec1
    args = parser.parse_args()
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)

    # Build a descriptive run name used as the log/checkpoint directory.
    full_arch_name = args.arch
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'I3D', args.dataset, full_arch_name,
        'batch{}'.format(args.batch_size),
        'wd{}'.format(args.weight_decay), args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs),
        'dropout{}'.format(args.dropout), args.pretrain,
        'lr{}'.format(args.lr), '_warmup{}'.format(args.warmup)
    ])
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    else:
        step_str = [str(int(x)) for x in args.lr_steps]
        args.store_name += '_step' + '_'.join(step_str)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.spatial_dropout:
        sigmoid_layer_str = '_'.join(args.sigmoid_layer)
        args.store_name += '_spatial_drop3d_{}_group{}_layer{}'.format(
            args.sigmoid_thres, args.sigmoid_group, sigmoid_layer_str)
        if args.sigmoid_random:
            args.store_name += '_RandomSigmoid'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = i3d(num_class, args.num_segments,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                spatial_dropout=args.spatial_dropout,
                sigmoid_thres=args.sigmoid_thres,
                sigmoid_group=args.sigmoid_group,
                sigmoid_random=args.sigmoid_random,
                sigmoid_layer=args.sigmoid_layer,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    # Direction-sensitive datasets must not be horizontally flipped.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        else True)

    model = torch.nn.DataParallel(model,
                                  device_ids=list(range(args.gpus))).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # FIX: report the checkpoint actually loaded (was args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        # Reconcile '.net' naming differences between checkpoint and model.
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            # Classifier head does not transfer across datasets.
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            # RGB conv1 weights are incompatible with Flow input channels.
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)
    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list,
                   num_segments=args.num_segments,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       GroupScale((256, 340)),
                       train_augmentation,
                       Stack('3D'),
                       ToTorchFormatTensor(),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path, args.val_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack('3D'),
            ToTorchFormatTensor(),
            normalize,
        ]),
        dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # define loss function (criterion) and optimizer
    # NOTE(review): 'nll' maps to BCEWithLogitsLoss here — this script tracks
    # mAP (multi-label), so BCE appears intentional despite the option name.
    if args.loss_type == 'nll':
        criterion = torch.nn.BCEWithLogitsLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        # adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer, args.warmup, args.lr_type, args.lr_steps)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_mAP_best', best_prec1, epoch)

            output_best = 'Best mAP: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
[GroupOverSample(input_size, net.scale_size)]) else: raise ValueError( "Only 1, 5, 10 crops are supported while we got {}".format( args.test_crops)) data_loader = torch.utils.data.DataLoader( TSNDataSet( root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments, new_length=1 if modality == "RGB" else 5, modality=modality, image_tmpl=prefix, test_mode=True, remove_missing=len(weights_list) == 1, transform=torchvision.transforms.Compose([ cropping, Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])), ToTorchFormatTensor( div=(this_arch not in ['BNInception', 'InceptionV3'])), GroupNormalize(net.input_mean, net.input_std), ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample), batch_size=args.batch_size, shuffle=False, # num_workers=args.workers, pin_memory=True, ) if args.gpus is not None: devices = [args.gpus[i] for i in range(args.workers)]
def main():
    """Entry point for distributed (DDP/NCCL) TDN training and evaluation.

    Parses CLI args, builds the TSN model and DistributedSampler-backed data
    loaders, optionally resumes or fine-tunes from a checkpoint, then either
    runs a single evaluation pass (``--evaluate``) or the full train/val
    loop, logging to TensorBoard and checkpointing on rank 0 only.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Bind this process to its local GPU and join the NCCL process group
    # (launched via torch.distributed, init from environment variables).
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)

    # Build the experiment name used for log/checkpoint folders.
    full_arch_name = args.arch
    args.store_name = '_'.join([
        'TDN_', args.dataset, args.modality, full_arch_name,
        args.consensus_type, 'segment%d' % args.num_segments,
        'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)

    # Only rank 0 creates output folders, avoiding filesystem races.
    if dist.get_rank() == 0:
        check_rootfolders()

    logger = setup_logger(output=os.path.join(args.root_log, args.store_name),
                          distributed_rank=dist.get_rank(),
                          name=f'TDN')
    logger.info('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                fc_lr5=(args.tune_from and args.dataset in args.tune_from))

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    for group in policies:
        logger.info(
            '[TDN-{}]group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
                args.arch, group['name'], len(group['params']),
                group['lr_mult'], group['decay_mult']))

    # Something-Something labels are direction-sensitive, so disable
    # horizontal flipping for that dataset.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset else True)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)
    train_dataset = TSNDataSet(
        args.dataset,
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_dataset = TSNDataSet(
        args.dataset,
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             drop_last=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = get_scheduler(optimizer, len(train_loader), args)
    model = DistributedDataParallel(model.cuda(),
                                    device_ids=[args.local_rank],
                                    broadcast_buffers=True,
                                    find_unused_parameters=True)

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            # FIX: report the checkpoint actually loaded (was args.evaluate).
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    if args.tune_from:
        logger.info("=> fine-tuning from '{}'".format(args.tune_from))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile '.net' naming differences between the checkpoint keys
        # and the current model's state_dict keys.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                # FIX: logging takes %-style lazy args, not print-style
                # varargs; the old call never rendered the key name.
                logger.info('=> Load after remove .net: {}'.format(k))
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                logger.info('=> Load after adding .net: {}'.format(k))
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        logger.info(
            '#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            logger.info('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    if args.evaluate:
        # Single evaluation pass; rank 0 records the result and returns.
        logger.info("===========evaluate===========")
        val_loader.sampler.set_epoch(args.start_epoch)
        prec1, prec5, val_loss = validate(val_loader, model, criterion,
                                          logger)
        if dist.get_rank() == 0:
            is_best = prec1 > best_prec1
            best_prec1 = prec1
            logger.info("Best Prec@1: '{}'".format(best_prec1))
            save_epoch = args.start_epoch + 1
            save_checkpoint(
                {
                    'epoch': args.start_epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'prec1': prec1,
                    'best_prec1': best_prec1,
                }, save_epoch, is_best)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # Reshuffle the distributed sampler each epoch.
        train_loader.sampler.set_epoch(epoch)
        train_loss, train_top1, train_top5 = train(train_loader,
                                                   model,
                                                   criterion,
                                                   optimizer,
                                                   epoch=epoch,
                                                   logger=logger,
                                                   scheduler=scheduler)
        if dist.get_rank() == 0:
            tf_writer.add_scalar('loss/train', train_loss, epoch)
            tf_writer.add_scalar('acc/train_top1', train_top1, epoch)
            tf_writer.add_scalar('acc/train_top5', train_top5, epoch)
            tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'],
                                 epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            val_loader.sampler.set_epoch(epoch)
            prec1, prec5, val_loss = validate(val_loader, model, criterion,
                                              epoch, logger)
            if dist.get_rank() == 0:
                tf_writer.add_scalar('loss/test', val_loss, epoch)
                tf_writer.add_scalar('acc/test_top1', prec1, epoch)
                tf_writer.add_scalar('acc/test_top5', prec5, epoch)
                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)
                logger.info("Best Prec@1: '{}'".format(best_prec1))
                tf_writer.flush()
                save_epoch = epoch + 1
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'prec1': prec1,
                        'best_prec1': best_prec1,
                    }, save_epoch, is_best)
def main():
    """Entry point for single-node TSM training and evaluation (DataParallel).

    Parses CLI args, builds the (optionally temporally-shifted) TSN model
    and data loaders, supports resuming and fine-tuning from checkpoints,
    then runs the train/val loop with step-wise LR adjustment, logging to
    a CSV file and TensorBoard and checkpointing the best top-1 accuracy.
    """
    global args, best_prec1
    args = parser.parse_args()
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)

    # Build the experiment name used for log/checkpoint folders.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type, 'segment%d' % args.num_segments,
        'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    # Direction-sensitive datasets must not be horizontally flipped.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset
        or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # FIX: report the checkpoint actually loaded (was args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile '.net' naming differences between the checkpoint keys
        # and the current model's state_dict keys.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            # RGB conv1 weights don't match Flow's input channel count.
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch,
                             log_training, tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
def main():
    """Entry point for multi-label TSM runs on a hard-coded MovieNet split.

    Builds a (optionally shifted) TSN model for 21 movie-genre classes with a
    BCEWithLogitsLoss multi-label objective, optionally resumes/fine-tunes
    from a checkpoint, then loops over epochs running validation, recording
    mAP, and checkpointing whenever mAP improves.

    NOTE(review): despite the 'Start training...' message, the epoch loop
    below never calls train() — it only validates. Confirm whether the
    training step was intentionally removed.
    """
    global args, best_prec1
    args = parser.parse_args()

    # dataset_config is bypassed: the class count and file lists are
    # hard-coded for the movie dataset.
    num_class = 21
    args.train_list = "/home/jzwang/code/Video_3D/movienet/data/movie/movie_val_new.txt"
    args.val_list = "/home/jzwang/code/RGB-FLOW/movie_new/data/datanew/movie_val_new.txt"
    args.root_path = ""
    prefix = "frame_{:04d}.jpg"

    # Build the experiment name used for logging.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type, 'segment%d' % args.num_segments,
        'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset
        or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # FIX: the optimizer was constructed twice with identical arguments;
    # the redundant second construction has been removed.
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # best_prec1 deliberately not restored from the checkpoint here.
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # FIX: report the checkpoint actually loaded (was args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile '.net' naming differences between the checkpoint keys
        # and the current model's state_dict keys.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            # RGB conv1 weights don't match Flow's input channel count.
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "",
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl="frame_{:04d}.jpg"
            if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix +
            "{}_{:05d}.jpg",
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "",
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl="frame_{:04d}.jpg"
            if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix +
            "{}_{:05d}.jpg",
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                normalize,
            ])),
        batch_size=int(16),
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # Multi-label objective: one sigmoid/BCE output per class.
    criterion = torch.nn.BCEWithLogitsLoss().cuda()

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    zero_time = time.time()
    best_map = 0
    # FIX: trainloss was referenced in the reporting code below but never
    # assigned (NameError); initialized here since no train() call exists.
    trainloss = 0.0
    print('Start training...')
    for epoch in range(args.start_epoch, args.epochs):
        # FIX: start_time was never assigned, crashing the timing code.
        start_time = time.time()
        valloss, mAP, output_mtx = validate(val_loader, model, criterion)
        end_time = time.time()
        epoch_time = end_time - start_time
        total_time = end_time - zero_time
        print('Total time used: %s Epoch %d time used: %s' %
              (str(datetime.timedelta(seconds=int(total_time))), epoch,
               str(datetime.timedelta(seconds=int(epoch_time)))))
        print('Train loss: {0:.4f} val loss: {1:.4f} mAP: {2:.4f}'.format(
            trainloss, valloss, mAP))

        # Remember the best mAP so far and checkpoint when it improves.
        is_best = mAP > best_map
        if is_best:
            # FIX: best_map was never updated (update was commented out),
            # so every epoch was flagged as "best".
            best_map = mAP
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, is_best, epoch)
        np.save("val1.npy", output_mtx)
        with open(args.record_path, 'a') as file:
            file.write(
                'Epoch:[{0}]'
                'Train loss: {1:.4f} val loss: {2:.4f} map: {3:.4f}\n'.format(
                    epoch + 1, trainloss, valloss, mAP))
    print('************ Done!... ************')