def __init__(self, modality, checkpoint, arena_mask_path):
    """Build a TSM (TSN with temporal shift) action classifier with an arena mask.

    Args:
        modality: input modality string forwarded to TSN (e.g. 'RGB').
        checkpoint: path to a DataParallel checkpoint holding a 'state_dict' key.
        arena_mask_path: path to an image defining the arena region.

    Raises:
        FileNotFoundError: if the arena mask image cannot be read.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TSN(2, 8, modality,
                base_model='resnet18',
                consensus_type='avg',
                dropout=0.5,
                img_feature_dim=256,
                partial_bn=False,
                pretrain='imagenet',
                is_shift=True, shift_div=8, shift_place='blockres',
                fc_lr5=False,
                temporal_pool=False,
                non_local=False)

    # Report model complexity (MACs / parameter count).
    macs, params = get_model_complexity_info(
        model, (24, 224, 224), as_strings=True,
        print_per_layer_stat=False, verbose=True)
    print('---{:<30} {:<8}'.format('Computational complexity: ', macs))
    print('{:<30} {:<8}'.format('Number of parameters: ', params))

    # Build the evaluation transform from the model's expected input
    # geometry and normalization statistics.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    self.transform = torchvision.transforms.Compose([
        GroupScale(int(scale_size)),
        GroupCenterCrop(crop_size),
        Stack(roll=False),
        ToTorchFormatTensor(div=True),
        GroupNormalize(input_mean, input_std),
    ])

    # Wrap in DataParallel before loading: the checkpoint keys carry the
    # 'module.' prefix.
    # BUG FIX: the original passed device_ids=1 (an int); DataParallel
    # requires a list of device indices and raises a TypeError on CUDA
    # machines when given an int, so let it default instead.
    model = torch.nn.DataParallel(model).to(device)
    model.load_state_dict(
        torch.load(checkpoint, map_location=device)['state_dict'])
    self.model = model
    self.model.eval()

    # Frame indices sampled from each clip for the RGB stream.
    self.action_names = ['explore', 'investigate']
    self.rgb_sample = [2, 6, 9, 13, 17, 20, 24, 28]  # [4, 12, 19, 26, 34, 41, 48, 56]

    self.arena_mask = cv2.imread(arena_mask_path)
    if self.arena_mask is None:
        # BUG FIX: the original printed and exit(0)-ed, reporting success
        # on a fatal failure; fail loudly instead.
        raise FileNotFoundError("Arena Mask not loaded: %s" % arena_mask_path)
def __init__(self, modality, checkpoint):
    """Build a TSM (TSN with temporal shift) action classifier for inference.

    Args:
        modality: input modality string forwarded to TSN (e.g. 'RGB').
        checkpoint: path to a DataParallel checkpoint holding a 'state_dict' key.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TSN(2, 8, modality,
                base_model='resnet18',
                consensus_type='avg',
                dropout=0.5,
                img_feature_dim=256,
                partial_bn=False,
                pretrain='imagenet',
                is_shift=True, shift_div=8, shift_place='blockres',
                fc_lr5=False,
                temporal_pool=False,
                non_local=False)

    # Define transforms from the model's expected input geometry and
    # normalization statistics.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    self.transform = torchvision.transforms.Compose([
        GroupScale(int(scale_size)),
        GroupCenterCrop(crop_size),
        Stack(roll=False),
        ToTorchFormatTensor(div=True),
        GroupNormalize(input_mean, input_std),
    ])

    # Load TSM model. Wrap in DataParallel first because the checkpoint
    # keys carry the 'module.' prefix.
    # BUG FIX: the original passed device_ids=1 (an int); DataParallel
    # requires a list and raises a TypeError on CUDA machines, so let it
    # default instead.
    model = torch.nn.DataParallel(model).to(device)
    # BUG FIX: map_location allows a GPU-trained checkpoint to load on
    # CPU-only hosts (the original torch.load had no map_location).
    model.load_state_dict(
        torch.load(checkpoint, map_location=device)['state_dict'])
    self.model = model
    self.model.eval()

    # Frame indices sampled from each clip for the RGB stream.
    self.action_names = ['explore', 'investigate']
    self.rgb_sample = [4, 12, 19, 26, 34, 41, 48, 56]
def load_src_model(modelpath='/nfs/volume-95-7/temporal-shift-module/checkpoint/TSM_videos_1218_RGB_resnet50_shift8_blockres_avg_segment8_e120_pr8_ext0.1/ckpt.best.pth.tar'):
    """Load the trained TSM (resnet50) checkpoint and smoke-test it.

    Args:
        modelpath: checkpoint path; the original hard-coded location is
            kept as the default for backward compatibility.

    Returns:
        The TSN model with weights loaded, in eval() mode. (The original
        returned None despite its name; callers ignoring the return value
        are unaffected.)
    """
    model = TSN(2, 8, 'RGB',
                base_model='resnet50',
                consensus_type='avg',
                img_feature_dim=256,
                pretrain='imagenet',
                is_shift=True, shift_div=8, shift_place='blockres',
                non_local=False,
                )
    checkpoint = torch.load(modelpath)['state_dict']
    # Strip the 'module.' (DataParallel) prefix from every key and remap
    # the base model's classifier weights onto TSN's new_fc head.
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in checkpoint.items()}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)
    model.load_state_dict(base_dict)
    model.eval()

    # Deterministic smoke test: an identity pattern broadcast to one clip.
    example = torch.eye(224).expand((1, 8, 3, 224, 224))
    y = model(example)
    print("src_model output: ", y)
    return model
def cvt_model(modelpath='/nfs/volume-95-7/temporal-shift-module/checkpoint/TSM_videos_1218_RGB_resnet50_shift8_blockres_avg_segment8_e120_pr8_ext0.1/ckpt.best.pth.tar',
              output_path="tsm_with_1218.pt"):
    """Convert a trained TSM checkpoint into a serialized TorchScript module.

    Args:
        modelpath: checkpoint to convert; the original hard-coded location
            is kept as the default for backward compatibility.
        output_path: file the traced TorchScript module is saved to.
    """
    print("===> Loading model")
    model = TSN(2, 8, 'RGB',
                base_model='resnet50',
                consensus_type='avg',
                img_feature_dim=256,
                pretrain='imagenet',
                is_shift=True, shift_div=8, shift_place='blockres',
                non_local=False,
                )
    checkpoint = torch.load(modelpath)['state_dict']
    # Strip the 'module.' (DataParallel) prefix from every key and remap
    # the base model's classifier weights onto TSN's new_fc head.
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in checkpoint.items()}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)
    model.load_state_dict(base_dict)

    # Model conversion to TorchScript via tracing.
    model.cuda()
    model.eval()
    example = torch.rand(1, 8, 3, 224, 224).cuda()
    y = model(example)
    print(y.shape)
    traced_script_module = torch.jit.trace(model, example)
    print(traced_script_module.code)
    traced_script_module.save(output_path)
    print("Export of model.pt complete!")
def load_model(weights):
    """Build a TSN network and its eval transform from a checkpoint path.

    Architecture options (shift, modality, concat mode, extra temporal
    modeling, fusion type) are parsed from the checkpoint filename itself.

    Returns:
        (is_shift, net, transform): shift flag, the DataParallel-wrapped
        network on GPU, and the torchvision transform pipeline.

    Raises:
        ValueError: if args.test_crops is not one of 1, 3, 5, 10.
    """
    global num_class
    is_shift, shift_div, shift_place = parse_shift_option_from_log_name(
        weights)

    # Modality and concat mode are encoded in the checkpoint filename.
    if 'RGB' in weights:
        modality = 'RGB'
    elif 'Depth' in weights:
        modality = 'Depth'
    else:
        modality = 'Flow'
    if 'concatAll' in weights:
        concat = "All"
    elif "concatFirst" in weights:
        concat = "First"
    else:
        concat = ""

    # BUG FIX: the original tested `this_weights`, a name that does not
    # exist in this function (NameError), and left extra_temporal_modeling
    # undefined whenever 'extra' was absent from the filename.
    extra_temporal_modeling = False
    if 'extra' in weights:
        extra_temporal_modeling = True
        args.prune = ""
    if 'conv1d' in weights:
        args.crop_fusion_type = "conv1d"
    else:
        args.crop_fusion_type = "avg"

    this_arch = weights.split('TSM_')[1].split('_')[2]
    modality_list.append(modality)
    num_class, args.train_list, val_list, root_path, prefix = \
        dataset_config.return_dataset(args.dataset, modality)
    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(
        is_shift, shift_div, shift_place))

    net = TSN(num_class,
              int(args.test_segments) if is_shift else 1,
              modality,
              base_model=this_arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              pretrain=args.pretrain,
              is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
              non_local='_nl' in weights,
              concat=concat,
              extra_temporal_modeling=extra_temporal_modeling,
              prune_list=[prune_conv1in_list, prune_conv1out_list],
              is_prune=args.prune)

    if 'tpool' in weights:
        from ops.temporal_shift import make_temporal_pool
        make_temporal_pool(net.base_model, args.test_segments)  # since DataParallel

    # Checkpoints were saved from a DataParallel model: strip the leading
    # 'module.' prefix and remap the classifier weights onto new_fc.
    checkpoint = torch.load(weights)
    checkpoint = checkpoint['state_dict']
    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in list(checkpoint.items())
    }
    replace_dict = {
        'base_model.classifier.weight': 'new_fc.weight',
        'base_model.classifier.bias': 'new_fc.bias',
    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)
    net.load_state_dict(base_dict)

    input_size = net.scale_size if args.full_res else net.input_size
    if args.test_crops == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif args.test_crops == 3:  # do not flip, so only 3 crops
        cropping = torchvision.transforms.Compose(
            [GroupFullResSample(input_size, net.scale_size, flip=False)])
    elif args.test_crops == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose(
            [GroupOverSample(input_size, net.scale_size, flip=False)])
    elif args.test_crops == 10:
        cropping = torchvision.transforms.Compose(
            [GroupOverSample(input_size, net.scale_size)])
    else:
        # BUG FIX: 3 crops are handled above; include it in the message.
        raise ValueError(
            "Only 1, 3, 5, 10 crops are supported while we got {}".format(
                args.test_crops))

    transform = torchvision.transforms.Compose([
        cropping,
        Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(
            div=(this_arch not in ['BNInception', 'InceptionV3'])),
        GroupNormalize(net.input_mean, net.input_std),
    ])

    # (The original computed a `devices` list from args.gpus here but never
    # used it; removed.)
    net = torch.nn.DataParallel(net.cuda())
    return is_shift, net, transform
def main():
    """Evaluate a trained TSN/TSM model on the test (or validation) split.

    Parses CLI args, builds the model, optionally restores a checkpoint
    from args.resume_path, then runs test() over a DataLoader.
    """
    global args, best_prec1, TRAIN_SAMPLES
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.test_list, args.root_path, prefix = \
        dataset_config.return_dataset(args.dataset, args.modality)
    # Prefer the dedicated test list when it exists on disk.
    if os.path.exists(args.test_list):
        args.val_list = args.test_list

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                tin=args.tin)

    # ImageNet normalization statistics.
    input_mean = [0.485, 0.456, 0.406]
    input_std = [0.229, 0.224, 0.225]

    print(args.gpus)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if os.path.isfile(args.resume_path):
        print(("=> loading checkpoint '{}'".format(args.resume_path)))
        checkpoint = torch.load(args.resume_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        # BUG FIX: the original reported args.evaluate here instead of the
        # checkpoint path that was actually loaded.
        print(("=> loaded checkpoint '{}' (epoch {})"
               .format(args.resume_path, checkpoint['epoch'])))
    else:
        print(("=> no checkpoint found at '{}'".format(args.resume_path)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    if args.random_crops == 1:
        crop_aug = GroupCenterCrop(args.crop_size)
    elif args.random_crops == 3:
        crop_aug = GroupFullResSample(args.crop_size, args.scale_size,
                                      flip=False)
    elif args.random_crops == 5:
        crop_aug = GroupOverSample(args.crop_size, args.scale_size,
                                   flip=False)
    else:
        # BUG FIX: the original had a trailing comma here, which bound
        # crop_aug to a 1-tuple and broke the transforms.Compose below.
        crop_aug = MultiGroupRandomCrop(args.crop_size, args.random_crops)

    test_dataset = TSNDataSet(
        args.root_path, args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        multi_class=args.multi_class,
        transform=torchvision.transforms.Compose([
            GroupScale(int(args.scale_size)),
            crop_aug,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample,
        test_mode=True,
        temporal_clips=args.temporal_clips)

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    test(test_loader, model, args.start_epoch)
def main():
    """Distributed TDN training/evaluation entry point.

    Sets up the NCCL process group, builds the model and data loaders,
    then either runs one evaluation pass (--evaluate) or the full
    train/validate loop with checkpointing and TensorBoard logging.
    """
    global args, best_prec1
    args = parser.parse_args()

    # One process per GPU; rank/world size come from the environment.
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    num_class, args.train_list, args.val_list, args.root_path, prefix = \
        dataset_config.return_dataset(args.dataset, args.modality)

    # Compose a unique experiment name from the main hyper-parameters.
    full_arch_name = args.arch
    args.store_name = '_'.join([
        'TDN_', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)

    # Only rank 0 creates the output folders.
    if dist.get_rank() == 0:
        check_rootfolders()

    logger = setup_logger(output=os.path.join(args.root_log, args.store_name),
                          distributed_rank=dist.get_rank(),
                          name='TDN')
    logger.info('storing name: ' + args.store_name)

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                fc_lr5=(args.tune_from and args.dataset in args.tune_from))

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    for group in policies:
        logger.info(
            ('[TDN-{}]group: {} has {} params, lr_mult: {}, decay_mult: {}'.
             format(args.arch, group['name'], len(group['params']),
                    group['lr_mult'], group['decay_mult'])))

    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset else True)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)
    train_dataset = TSNDataSet(
        args.dataset, args.root_path, args.train_list,
        num_segments=args.num_segments,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_dataset = TSNDataSet(
        args.dataset, args.root_path, args.val_list,
        num_segments=args.num_segments,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             drop_last=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = get_scheduler(optimizer, len(train_loader), args)
    model = DistributedDataParallel(model.cuda(),
                                    device_ids=[args.local_rank],
                                    broadcast_buffers=True,
                                    find_unused_parameters=True)

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            # BUG FIX: the original logged args.evaluate here instead of
            # the checkpoint path that was actually loaded.
            logger.info(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            logger.info(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        logger.info(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Keys may differ by a '.net' infix between checkpoints; collect
        # renames in both directions before loading.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                # BUG FIX: logger.info uses lazy %-formatting; the original
                # passed k as an argument with no placeholder in the
                # message, producing a logging format error.
                logger.info('=> Load after remove .net: %s', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                logger.info('=> Load after adding .net: %s', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        logger.info(
            '#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            logger.info('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    if args.evaluate:
        logger.info(("===========evaluate==========="))
        val_loader.sampler.set_epoch(args.start_epoch)
        # NOTE(review): this call passes (criterion, logger) while the
        # training loop below passes (criterion, epoch, logger) — confirm
        # validate()'s signature accepts both.
        prec1, prec5, val_loss = validate(val_loader, model, criterion,
                                          logger)
        if dist.get_rank() == 0:
            is_best = prec1 > best_prec1
            # BUG FIX: the original overwrote best_prec1 unconditionally,
            # even when prec1 was lower; keep the max as the training loop
            # does so the saved checkpoint stays consistent.
            best_prec1 = max(prec1, best_prec1)
            logger.info(("Best Prec@1: '{}'".format(best_prec1)))
            save_epoch = args.start_epoch + 1
            save_checkpoint(
                {
                    'epoch': args.start_epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'prec1': prec1,
                    'best_prec1': best_prec1,
                }, save_epoch, is_best)
        return

    for epoch in range(args.start_epoch, args.epochs):
        train_loader.sampler.set_epoch(epoch)
        train_loss, train_top1, train_top5 = train(train_loader,
                                                   model,
                                                   criterion,
                                                   optimizer,
                                                   epoch=epoch,
                                                   logger=logger,
                                                   scheduler=scheduler)
        if dist.get_rank() == 0:
            tf_writer.add_scalar('loss/train', train_loss, epoch)
            tf_writer.add_scalar('acc/train_top1', train_top1, epoch)
            tf_writer.add_scalar('acc/train_top5', train_top5, epoch)
            tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'],
                                 epoch)

        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            val_loader.sampler.set_epoch(epoch)
            prec1, prec5, val_loss = validate(val_loader, model, criterion,
                                              epoch, logger)
            if dist.get_rank() == 0:
                tf_writer.add_scalar('loss/test', val_loss, epoch)
                tf_writer.add_scalar('acc/test_top1', prec1, epoch)
                tf_writer.add_scalar('acc/test_top5', prec5, epoch)
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)
                logger.info(("Best Prec@1: '{}'".format(best_prec1)))
                tf_writer.flush()
                save_epoch = epoch + 1
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'prec1': prec1,
                        'best_prec1': best_prec1,
                    }, save_epoch, is_best)
def main():
    # End-to-end TSM evaluation: parse options, build one network per
    # checkpoint in the (comma-separated) ensemble, run them over the test
    # set and report Prec@1/Prec@5 plus per-class accuracy.
    # options
    parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
    parser.add_argument('dataset', type=str)

    # may contain splits
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--test_segments', type=str, default=25)
    parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
    parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
    parser.add_argument('--full_res', default=False, action="store_true",
                        help='use full resolution 256x256 for test as in Non-local I3D')
    parser.add_argument('--test_crops', type=int, default=1)
    parser.add_argument('--coeff', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')

    # for true test
    parser.add_argument('--test_list', type=str, default=None)
    parser.add_argument('--csv_file', type=str, default=None)
    parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')
    parser.add_argument('--max_num', type=int, default=-1)
    parser.add_argument('--input_size', type=int, default=224)
    parser.add_argument('--crop_fusion_type', type=str, default='avg')
    parser.add_argument('--gpus', nargs='+', type=int, default=None)
    parser.add_argument('--img_feature_dim', type=int, default=256)
    parser.add_argument('--num_set_segments', type=int, default=1,
                        help='TODO: select multiply set of n-frames from a video')
    parser.add_argument('--pretrain', type=str, default='imagenet')

    args = parser.parse_args()

    def accuracy(output, target, topk=(1,)):
        """Computes the precision@k for the specified values of k"""
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    def parse_shift_option_from_log_name(log_name):
        # Recover (is_shift, shift_div, shift_place) from the checkpoint
        # filename, e.g. '..._shift8_blockres_...' -> (True, 8, 'blockres').
        if 'shift' in log_name:
            strings = log_name.split('_')
            for i, s in enumerate(strings):
                if 'shift' in s:
                    break
            return True, int(strings[i].replace('shift', '')), strings[i + 1]
        else:
            return False, None, None

    # One entry per ensembled checkpoint; all of these lists stay
    # index-aligned with weights_list.
    weights_list = args.weights.split(',')
    test_segments_list = [int(s) for s in args.test_segments.split(',')]
    assert len(weights_list) == len(test_segments_list)
    if args.coeff is None:
        coeff_list = [1] * len(weights_list)
    else:
        coeff_list = [float(c) for c in args.coeff.split(',')]
    if args.test_list is not None:
        test_file_list = args.test_list.split(',')
    else:
        test_file_list = [None] * len(weights_list)

    data_iter_list = []
    net_list = []
    modality_list = []

    total_num = None
    for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
        # Architecture options are parsed from the checkpoint filename.
        is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
        if 'RGB' in this_weights:
            modality = 'RGB'
        else:
            modality = 'Flow'
        this_arch = this_weights.split('TSM_')[1].split('_')[2]
        modality_list.append(modality)
        num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset, modality)
        print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
        net = TSN(num_class, this_test_segments if is_shift else 1, modality,
                  base_model=this_arch,
                  consensus_type=args.crop_fusion_type,
                  img_feature_dim=args.img_feature_dim,
                  pretrain=args.pretrain,
                  is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
                  non_local='_nl' in this_weights,
                  )
        if 'tpool' in this_weights:
            from ops.temporal_shift import make_temporal_pool
            make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel

        checkpoint = torch.load(this_weights)
        checkpoint = checkpoint['state_dict']

        # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
        # Strip the 'module.' prefix left by DataParallel and remap the
        # base model's classifier weights onto TSN's new_fc head.
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias',
                        }
        for k, v in replace_dict.items():
            if k in base_dict:
                base_dict[v] = base_dict.pop(k)
        net.load_state_dict(base_dict)

        input_size = net.scale_size if args.full_res else net.input_size
        if args.test_crops == 1:
            cropping = torchvision.transforms.Compose([
                GroupScale(net.scale_size),
                GroupCenterCrop(input_size),
            ])
        elif args.test_crops == 3:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupFullResSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 5:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 10:
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size)
            ])
        else:
            raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops))

        data_loader = torch.utils.data.DataLoader(
                TSNDataSet(root_path, test_file if test_file is not None else val_list,
                           num_segments=this_test_segments,
                           new_length=1 if modality == "RGB" else 5,
                           modality=modality,
                           image_tmpl=prefix,
                           test_mode=True,
                           remove_missing=len(weights_list) == 1,
                           transform=torchvision.transforms.Compose([
                               cropping,
                               Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                               ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                               GroupNormalize(net.input_mean, net.input_std),
                           ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True,
        )

        # NOTE(review): `devices` is computed but never used below —
        # DataParallel is called without device_ids.
        if args.gpus is not None:
            devices = [args.gpus[i] for i in range(args.workers)]
        else:
            devices = list(range(args.workers))

        net = torch.nn.DataParallel(net.cuda())
        net.eval()

        data_gen = enumerate(data_loader)

        if total_num is None:
            total_num = len(data_loader.dataset)
        else:
            assert total_num == len(data_loader.dataset)

        data_iter_list.append(data_gen)
        net_list.append(net)

    output = []

    def eval_video(video_data, net, this_test_segments, modality):
        # Run one (index, data, label) batch through `net` and return
        # (index, per-class scores as a numpy array, label).
        # NOTE(review): reads `is_shift` from the enclosing loop's last
        # iteration — all ensembled checkpoints are assumed to agree on it.
        net.eval()
        with torch.no_grad():
            i, data, label = video_data
            batch_size = label.numel()
            num_crop = args.test_crops
            if args.dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample
            if args.twice_sample:
                num_crop *= 2
            # Channels per frame sample for each modality.
            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality " + modality)
            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3))
            rst = net(data_in)
            # Average the predictions over crops/clips.
            rst = rst.reshape(batch_size, num_crop, -1).mean(1)
            if args.softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)
            rst = rst.data.cpu().numpy().copy()
            if net.module.is_shift:
                rst = rst.reshape(batch_size, num_class)
            else:
                rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))
            return i, rst, label

    proc_start_time = time.time()
    max_num = args.max_num if args.max_num > 0 else total_num

    top1 = AverageMeter()
    top5 = AverageMeter()

    # Walk all per-checkpoint data iterators in lockstep and average the
    # (coefficient-weighted) predictions.
    for i, data_label_pairs in enumerate(zip(*data_iter_list)):
        with torch.no_grad():
            if i >= max_num:
                break
            this_rst_list = []
            this_label = None
            for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list):
                rst = eval_video((i, data, label), net, n_seg, modality)
                this_rst_list.append(rst[1])
                this_label = label
            assert len(this_rst_list) == len(coeff_list)
            for i_coeff in range(len(this_rst_list)):
                this_rst_list[i_coeff] *= coeff_list[i_coeff]
            ensembled_predict = sum(this_rst_list) / len(this_rst_list)

            for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
                output.append([p[None, ...], g])
            cnt_time = time.time() - proc_start_time
            prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
            top1.update(prec1.item(), this_label.numel())
            top5.update(prec5.item(), this_label.numel())
            if i % 20 == 0:
                print('video {} done, total {}/{}, average {:.3f} sec/video, '
                      'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num,
                                                                  float(cnt_time) / (i + 1) / args.batch_size, top1.avg, top5.avg))

    video_pred = [np.argmax(x[0]) for x in output]
    video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]
    video_labels = [x[1] for x in output]

    if args.csv_file is not None:
        print('=> Writing result to csv file: {}'.format(args.csv_file))
        # The category file is assumed to live next to the test list.
        with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f:
            categories = f.readlines()
        categories = [f.strip() for f in categories]
        with open(test_file_list[0]) as f:
            vid_names = f.readlines()
        vid_names = [n.split(' ')[0] for n in vid_names]
        print(vid_names)
        assert len(vid_names) == len(video_pred)
        if args.dataset != 'somethingv2':  # only output top1
            with open(args.csv_file, 'w') as f:
                for n, pred in zip(vid_names, video_pred):
                    f.write('{};{}\n'.format(n, categories[pred]))
        else:
            with open(args.csv_file, 'w') as f:
                for n, pred5 in zip(vid_names, video_pred_top5):
                    fill = [n]
                    for p in list(pred5):
                        fill.append(p)
                    f.write('{};{};{};{};{};{}\n'.format(*fill))

    # Per-class accuracy from the confusion matrix.
    cf = confusion_matrix(video_labels, video_pred).astype(float)
    np.save('cm.npy', cf)
    cls_cnt = cf.sum(axis=1)
    cls_hit = np.diag(cf)
    cls_acc = cls_hit / cls_cnt
    print(cls_acc)
    upper = np.mean(np.max(cf, axis=1) / cls_cnt)
    print('upper bound: {}'.format(upper))

    print('-----Evaluation is finished------')
    print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
    print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))
def get_tsm(num_classes=3, pretrain_set='kinetics'):
    """Build a pre-trained TSN/TSM network with a fresh classification head.

    Args:
        num_classes: output size of the new classification head.
        pretrain_set: 'kinetics' selects the resnet50 + non-local Kinetics
            checkpoint; any other value selects the resnet101
            Something-Something V2 checkpoint.

    Returns:
        The TSN network with `new_fc` re-initialized to `num_classes` outputs.
    """
    if pretrain_set == 'kinetics':
        base_model = "resnet50"
        this_weights = "pretrained/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e100_dense_nl.pth"
        original_num_classes = 400
        non_local = True
        print("Using kinetics")
    else:
        base_model = "resnet101"
        this_weights = "pretrained/TSM_somethingv2_RGB_resnet101_shift8_blockres_avg_segment8_e45.pth"
        original_num_classes = 174
        non_local = False

    modality = "RGB"
    segments = 8
    consensus_type = "avg"
    img_feature_dim = 256
    pretrain = True
    is_shift = True
    shift_div = 8
    shift_place = "blockres"

    net = TSN(
        original_num_classes, segments, modality,
        base_model=base_model,
        consensus_type=consensus_type,
        img_feature_dim=img_feature_dim,
        pretrain=pretrain,
        is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
        non_local=non_local,
    )

    # Checkpoints were saved from a DataParallel model: strip the leading
    # 'module.' prefix and remap the classifier weights onto new_fc.
    checkpoint = torch.load(this_weights)['state_dict']
    base_dict = {
        '.'.join(k.split('.')[1:]): v for k, v in checkpoint.items()
    }
    replace_dict = {
        'base_model.classifier.weight': 'new_fc.weight',
        'base_model.classifier.bias': 'new_fc.bias',
    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)
    net.load_state_dict(base_dict)

    # (To fine-tune only the last stage, freeze net.parameters() and
    # unfreeze net.base_model.layer4 — previously commented-out code.)

    # Replace the head. BUG FIX: derive the input width from the loaded
    # head instead of hard-coding 2048, so non-2048-dim backbones keep
    # working.
    net.new_fc = torch.nn.Linear(net.new_fc.in_features, num_classes)
    return net
def main_worker(gpu, ngpus_per_node, args):
    """Per-process training worker: builds the TSN/TSM model, wraps it for
    (optionally distributed) GPU execution, constructs data loaders, and runs
    the train/validate/checkpoint loop.

    Args:
        gpu: GPU index assigned to this worker (or None).
        ngpus_per_node: number of GPUs on this node.
        args: parsed command-line namespace; several fields are mutated here
            (gpu, train_list, val_list, root_path, store_name, batch_size,
            workers, start_epoch).
    """
    global best_acc1
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            rank = args.rank * ngpus_per_node + gpu
        # NOTE(review): if args.distributed is set without
        # args.multiprocessing_distributed, local `rank` looks unbound here —
        # confirm against the launcher before relying on that path.
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=rank)
    else:
        rank = 0
    # create model
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset, args.modality)
    # Build a descriptive experiment name encoding arch + hyper-parameters;
    # it is used below as the checkpoint/log directory name.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name,
         args.consensus_type, 'segment%d' % args.num_segments,
         'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    args.store_name += '_lr{}'.format(args.lr)
    args.store_name += '_wd{:.1e}'.format(args.weight_decay)
    args.store_name += '_do{}'.format(args.dropout)
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)
    check_rootfolders(args, rank)
    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)
    # first synchronization of initial weights
    # sync_initial_weights(model, args.rank, args.world_size)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if rank == 0:
        print(model)
    # Capture preprocessing parameters and optimizer policies from the model
    # BEFORE it is wrapped in (Distributed)DataParallel.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)
    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have on a node
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            # Only override start_epoch if the user left it at the default (1).
            if args.start_epoch == 1:
                args.start_epoch = checkpoint['epoch'] + 1
            best_acc1 = checkpoint['best_acc1']
            # if args.gpu is not None:
            #     # best_acc1 may be from a checkpoint from a different GPU
            #     best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True
    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()
    # Number of consecutive frames per segment: 1 for RGB, 5 stacked for
    # Flow / RGBDiff.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    train_dataset = TSNDataSet(args.dataset, args.root_path, args.train_list,
                               num_segments=args.num_segments,
                               new_length=data_length,
                               modality=args.modality,
                               image_tmpl=prefix,
                               transform=torchvision.transforms.Compose([
                                   train_augmentation,
                                   Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                                   ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                                   normalize]),
                               dense_sample=args.dense_sample)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, drop_last=True,
        sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.dataset, args.root_path, args.val_list,
                   num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return
    # NOTE(review): log_training is opened without a context manager and is
    # never explicitly closed — confirm whether relying on process exit is
    # intended here.
    log_training = open(os.path.join(args.root_model, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_model, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(log_dir=os.path.join(args.root_model, args.store_name))
    for epoch in range(args.start_epoch, args.epochs+1):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # adjust_learning_rate(optimizer, epoch, args.lr_steps, args)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer, args, rank)
        # Save a rolling (non-best) checkpoint every epoch on the node-local
        # rank-0 process only.
        if rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, False, args, rank)
        # Additionally keep a numbered snapshot every 5 epochs.
        if epoch % 5 == 0 and rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, False, args, rank, e=epoch)
        # evaluate on validation set
        is_best = False
        if epoch % args.eval_freq == 0 or epoch == args.epochs:
            acc1 = validate(val_loader, model, criterion, epoch, args, rank, log_training, tf_writer)
            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            if not args.multiprocessing_distributed or (args.multiprocessing_distributed and rank % ngpus_per_node == 0):
                save_checkpoint({
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer' : optimizer.state_dict(),
                }, is_best, args, rank)
has_tam = parse_shift_option_from_log_name(this_weights) if 'RGB' in this_weights: modality = 'RGB' else: modality = 'Flow' this_arch = this_weights.split('/')[-2].split('_')[2] modality_list.append(modality) num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset( args.dataset, modality) print('=> TAM : {}, {} sampling'.format(has_tam, args.sample)) net = TSN( num_class, this_test_segments if has_tam else 1, modality, base_model=this_arch, consensus_type=args.crop_fusion_type, img_feature_dim=args.img_feature_dim, pretrain=args.pretrain, tam=True, non_local='_nl' in this_weights, ) sf_ckpt = torch.load(this_weights, map_location=torch.device('cpu')) sf_weights = sf_ckpt['state_dict'] tam_ckpt = net.state_dict() # print(base_dict.keys()) # exit() base_dict = {} for k, v in sf_weights.items(): if 'self_conv.conv_f' in k: k = k.replace('self_conv.conv_f', 'tam.G')
def main():
    """Entry point: build the TSM model, set up loaders/optimizer, optionally
    train with periodic validation and best-checkpoint saving, then evaluate
    the best checkpoint on the test split.

    Uses module-level globals `args` (parsed CLI namespace) and `best_prec1`.
    """
    # settings
    global args, best_prec1
    args = parser.parse_args()
    n_class, args.train_list, args.val_list, args.test_list, prefix = dataset_config.dataset(
    )
    # Experiment name encodes arch / shift / segment count and is used as the
    # log and checkpoint directory name.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}'.format(args.shift_div)
    args.store_name = '_'.join(
        ['tsm', full_arch_name, 'segment%d' % args.num_segments])
    print('storing name: ' + args.store_name)
    check_rootfolders(args.root_log, args.root_model, args.store_name)
    # tsn model added temporal shift module
    model = TSN(n_class,
                args.num_segments,
                base_model=args.arch,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn,
                is_shift=args.shift,
                shift_div=args.shift_div)
    # preprocessing for input — read from the model before DataParallel
    # wrapping hides these attributes behind `.module`.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False)
    # optimizer
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # cuda and cudnn
    # NOTE(review): bare `except:` silently falls back to a plain .cuda() —
    # this hides the real failure cause; confirm which exception the
    # DataParallel path is expected to raise.
    try:
        model = nn.DataParallel(model).cuda()
    except:
        model = model.cuda()
    cudnn.benchmark = True
    # data loader
    normalize = GroupNormalize(input_mean, input_std)
    train_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.train_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=False),
            ToTorchFormatTensor(div=True), normalize
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               drop_last=True)
    # Validation/test loaders use deterministic sampling (random_shift=False)
    # and center-crop preprocessing instead of the training augmentation.
    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.val_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=False),
            ToTorchFormatTensor(div=True), normalize
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=False)
    test_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.test_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        random_shift=False,
        test_mode=True,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=False),
            ToTorchFormatTensor(div=True), normalize
        ])),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=False)
    # loss function
    criterion = nn.CrossEntropyLoss().cuda()
    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))
    # tensorboard
    time_stamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
    # train
    if args.mode == 'train':
        log_training = open(
            os.path.join(args.root_log, args.store_name, time_stamp,
                         'log.csv'), 'w')
        tf_writer = SummaryWriter(
            '{}/{}/'.format(args.root_log, args.store_name) + time_stamp)
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr_steps, args.lr,
                                 args.weight_decay)
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)
            # evaluate on validation set
            if (epoch + 1) % args.eval_freq == 0:
                prec1 = validate(val_loader, model, criterion, epoch,
                                 log_training, tf_writer)
                # remember best precision and save checkpoint
                is_best = prec1 >= best_prec1
                best_prec1 = max(prec1, best_prec1)
                output_best = 'Best Prec@1: %.2f\n' % (best_prec1)
                print(output_best)
                log_training.write(output_best + '\n')
                log_training.flush()
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best, args.root_model, args.store_name)
        tf_writer.close()
    # test — always evaluates the best saved checkpoint, even when training
    # was skipped (args.mode != 'train').
    checkpoint = '%s/%s/ckpt.best.pth.tar' % (args.root_model,
                                              args.store_name)
    test(test_loader, model, checkpoint, time_stamp)
extra_temporal_modeling = extra_temporal_modeling, prune_list = [prune_conv1in_list, prune_conv1out_list], is_prune = args.prune, ) ''' net = TSN( num_class, this_test_segments if is_shift else 1, modality, base_model=this_arch, new_length=2 if args.data_fuse else None, consensus_type=args.crop_fusion_type, #dropout=args.dropout, img_feature_dim=args.img_feature_dim, #partial_bn=not args.no_partialbn, pretrain=args.pretrain, is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place, #fc_lr5=not (args.tune_from and args.dataset in args.tune_from), #temporal_pool=args.temporal_pool, non_local='_nl' in this_weights, concat=concat, extra_temporal_modeling=extra_temporal_modeling, prune_list=[prune_conv1in_list, prune_conv1out_list], is_prune=args.prune, ) print(net) #print(args.shift) #exit() if 'tpool' in this_weights: from ops.temporal_shift import make_temporal_pool
thickness) return image if __name__ == '__main__': args = parser.parse_args() arch = 'resnet50' tsn = TSN(len(action_to_idx), args.num_segments, 'RGB', base_model=arch, consensus_type='avg', dropout=0.5, img_feature_dim=256, partial_bn=False, pretrain='imagenet', is_shift=True, shift_div=8, shift_place='blockres', fc_lr5=False, temporal_pool=False, non_local=False).to(args.device) model = torch.nn.DataParallel(tsn, device_ids=None).to(args.device) sd = torch.load(args.model, map_location=torch.device(args.device))['state_dict'] model.load_state_dict(sd) model.eval() meta = pd.DataFrame(columns=['action', 'time_start', 'time_end'])
def main():
    """TSA training entry point: configures logging, builds the TSN/TSA
    model, selects optimizer/scheduler, optionally resumes or fine-tunes from
    a checkpoint, then runs the train/validate/checkpoint loop (tracking the
    lowest validation loss rather than accuracy).

    Uses module-level globals `args`, `best_prec1`, and `least_loss`.
    """
    global args, best_prec1, least_loss
    least_loss = 1000
    args = parser.parse_args()
    # Start each run with a fresh error log.
    if os.path.exists(os.path.join(args.root_log, "error.log")):
        os.remove(os.path.join(args.root_log, "error.log"))
    logging.basicConfig(
        level=logging.DEBUG,
        filename=os.path.join(args.root_log, "error.log"),
        filemode='a',
        format=
        '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    # log_handler = open(os.path.join(args.root_log,"error.log"),"w")
    # sys.stdout = log_handler
    # When a root path is provided on the CLI, the train/test annotation
    # lists are overridden with fixed JSON files under root_log.
    if args.root_path:
        num_class, args.train_list, args.val_list, _, prefix = dataset_config.return_dataset(
            args.dataset, args.modality)
        args.train_list = os.path.join(args.root_log,
                                       "kf1_train_anno_lijun_iod.json")
        args.test_list = os.path.join(args.root_log,
                                      "kf1_test_anno_lijun_iod.json")
    else:
        num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
            args.dataset, args.modality)
    # Experiment name encodes arch / shift / segment / epoch settings; used
    # as the log and checkpoint directory name.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSA', args.dataset, args.modality, full_arch_name,
        args.consensus_type, 'segment%d' % args.num_segments,
        'e{}'.format(args.epochs)
    ])
    # if args.pretrain != 'imagenet':
    #     args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)
    check_rootfolders()
    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                is_TSA=args.tsa,
                is_sTSA=args.stsa,
                is_tTSA=args.ttsa,
                shift_diff=args.shift_diff,
                shift_groups=args.shift_groups,
                is_ME=args.me,
                is_3D=args.is_3D,
                cfg_file=args.cfg_file)
    # Read preprocessing attributes before DataParallel wrapping.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    # policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        else True)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    # NOTE(review): the SGD-without-scheduler branch uses `policies`, which
    # is commented out above — that path looks like it would raise
    # NameError; confirm it is never taken with this configuration.
    if args.optimizer == "sgd":
        if args.lr_scheduler:
            optimizer = torch.optim.SGD(model.parameters(),
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.SGD(policies,
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.optimizer == "adam":
        params = get_vmz_fine_tuning_parameters(model,
                                                args.vmz_tune_last_k_layer)
        optimizer = torch.optim.Adam(params,
                                     args.lr,
                                     weight_decay=args.weight_decay)
    else:
        raise RuntimeError("not supported optimizer")
    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.lr_steps, args.lr_scheduler_gamma)
    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # if args.lr_scheduler:
            #     scheduler.load_state_dict(checkpoint["lr_scheduler"])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
            logging.info(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))
            logging.error(
                ("=> no checkpoint found at '{}'".format(args.resume)))
    if args.tune_from:
        # Fine-tune: load a source checkpoint, reconciling '.net' naming
        # differences between the checkpoint and the current model, and drop
        # incompatible head/conv1 weights where the task or modality differs.
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        # sd = {k:v for k, v in sd.items() if k in keys2}
        sd = {k: v for k, v in sd.items() if k in keys2}
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {
                k: v
                for k, v in sd.items()
                if 'fc' not in k and "projection" not in k
            }
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)
    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)
    cudnn.benchmark = True
    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()
    # Frames per sample position: 1 for RGB/PoseAction, 5 stacked for
    # Flow/RGBDiff.
    if args.modality in ['RGB', "PoseAction"]:
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    # Static loaders are built once here; with args.shuffle the loaders are
    # instead rebuilt every epoch inside the training loop below.
    if not args.shuffle:
        train_loader = torch.utils.data.DataLoader(
            TSNDataSet(args.root_path,
                       args.train_list,
                       num_segments=args.num_segments,
                       new_length=data_length,
                       modality=args.modality,
                       image_tmpl=prefix,
                       transform=torchvision.transforms.Compose([
                           train_augmentation,
                           Stack(roll=(args.arch
                                       in ['BNInception', 'InceptionV3']),
                                 inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                           ToTorchFormatTensor(
                               div=(args.arch
                                    not in ['BNInception', 'InceptionV3']),
                               inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                           normalize,
                       ]),
                       dense_sample=args.dense_sample,
                       all_sample=args.all_sample),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)  # prevent something not % n_GPU
        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(scale_size),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                ToTorchFormatTensor(div=(args.arch
                                         not in ['BNInception', 'InceptionV3']),
                                    inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                normalize,
            ]),
            dense_sample=args.dense_sample,
            all_sample=args.all_sample),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
    # for group in policies:
    #     print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
    #         group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))
    # Evaluation-only path: build the criterion and a val loader (with
    # analysis=True) and exit after a single test pass.
    if args.evaluate:
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        elif args.loss_type == "bce":
            criterion = torch.nn.BCEWithLogitsLoss().cuda()
        elif args.loss_type == "wbce":
            class_weight, pos_weight = prep_weight(args.train_list)
            criterion = WeightedBCEWithLogitsLoss(class_weight, pos_weight)
        else:
            raise ValueError("Unknown loss type")
        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(scale_size),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                ToTorchFormatTensor(div=(args.arch
                                         not in ['BNInception', 'InceptionV3']),
                                    inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                normalize,
            ]),
            dense_sample=args.dense_sample,
            all_sample=args.all_sample,
            analysis=True),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        test(val_loader, model, criterion, 0)
        return
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    print(model)
    logging.info(model)
    for epoch in range(args.start_epoch, args.epochs):
        logging.info("Train Epoch {}/{} starts, estimated time 5832s".format(
            str(epoch), str(args.epochs)))
        # update data_loader — with args.shuffle, labels are regenerated and
        # both loaders are rebuilt each epoch.
        if args.shuffle:
            gen_label(args.prop_path,
                      args.label_path,
                      args.trn_name,
                      args.train_list,
                      args.neg_rate,
                      STR=False)
            gen_label(args.prop_path,
                      args.label_path,
                      args.tst_name,
                      args.val_list,
                      args.test_rate,
                      STR=False)
            train_loader = torch.utils.data.DataLoader(
                TSNDataSet(args.root_path,
                           args.train_list,
                           num_segments=args.num_segments,
                           new_length=data_length,
                           modality=args.modality,
                           image_tmpl=prefix,
                           transform=torchvision.transforms.Compose([
                               train_augmentation,
                               Stack(roll=(args.arch
                                           in ['BNInception', 'InceptionV3']),
                                     inc_dim=(args.arch
                                              in ["R2plus1D", "X3D"])),
                               ToTorchFormatTensor(
                                   div=(args.arch
                                        not in ['BNInception', 'InceptionV3']),
                                   inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               normalize,
                           ]),
                           dense_sample=args.dense_sample,
                           all_sample=args.all_sample),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
            val_loader = torch.utils.data.DataLoader(
                TSNDataSet(args.root_path,
                           args.val_list,
                           num_segments=args.num_segments,
                           new_length=data_length,
                           modality=args.modality,
                           image_tmpl=prefix,
                           random_shift=False,
                           transform=torchvision.transforms.Compose([
                               GroupScale(scale_size),
                               GroupCenterCrop(crop_size),
                               Stack(roll=(args.arch
                                           in ['BNInception', 'InceptionV3']),
                                     inc_dim=(args.arch
                                              in ["R2plus1D", "X3D"])),
                               ToTorchFormatTensor(
                                   div=(args.arch
                                        not in ['BNInception', 'InceptionV3']),
                                   inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               normalize,
                           ]),
                           dense_sample=args.dense_sample,
                           all_sample=args.all_sample),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
        print(train_loader)
        # define loss function (criterion) and optimizer
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        elif args.loss_type == "bce":
            criterion = torch.nn.BCEWithLogitsLoss().cuda()
        elif args.loss_type == "wbce":
            class_weight, pos_weight = prep_weight(args.train_list)
            criterion = WeightedBCEWithLogitsLoss(class_weight, pos_weight)
        else:
            raise ValueError("Unknown loss type")
        # Either manual LR decay (step) or the MultiStepLR scheduler drives
        # the learning rate, never both.
        if not args.lr_scheduler:
            adjust_learning_rate(optimizer, epoch, args.lr_type,
                                 args.lr_steps)
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)
        else:
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)
            scheduler.step()
        # train for one epoch
        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            logging.info(
                "Test Epoch {}/{} starts, estimated time 13874s".format(
                    str(epoch // args.eval_freq),
                    str(args.epochs / args.eval_freq)))
            if args.loss_type == "wbce":
                # class_weight,pos_weight = prep_weight(args.val_list)
                criterion = torch.nn.BCEWithLogitsLoss().cuda()
            lossm = validate(val_loader, model, criterion, epoch,
                             log_training, tf_writer)
            # remember best prec@1 and save checkpoint — "best" here means
            # lowest validation loss.
            is_best = lossm < least_loss
            least_loss = min(lossm, least_loss)
            tf_writer.add_scalar('lss/test_top1_best', least_loss, epoch)
            output_best = 'Best Loss: %.3f\n' % (lossm)
            logging.info(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()
            if args.lr_scheduler:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': least_loss,
                        'lr_scheduler': scheduler,
                    }, is_best, epoch)
            else:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': least_loss,
                    }, is_best, epoch)
for i, s in enumerate(strings): if 'shift' in s: break return True, int(strings[i].replace('shift', '')), strings[i + 1] else: return False, None, None is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights) print(is_shift, shift_div, shift_place) with torch.cuda.device(0): net = TSN(2, 1, 'RGB', base_model=this_arch, consensus_type='avg', img_feature_dim='225', #pretrain=args.pretrain, is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, non_local='_nl' in this_weights, ) macs, params = get_model_complexity_info(net, (1,3, 224, 224), as_strings=True,print_per_layer_stat=False, verbose=False) print("Using ptflops") print('{:<30} {:<8}'.format('Computational complexity: ', macs)) print('{:<30} {:<8}'.format('Number of parameters: ', params)) from thop import profile model = net = TSN(2, 1, 'RGB', base_model=this_arch, consensus_type='avg', img_feature_dim='225', #pretrain=args.pretrain,
def doInferecing(cap, args, GPU_FLAG):
    """Run real-time abnormal-activity classification on frames from `cap`,
    overlaying predictions/FPS on screen and saving abnormal-event images and
    a clip to disk; prints summary statistics when the stream ends.

    Args:
        cap: an opened cv2.VideoCapture source.
        args: dict-like options; only args.get("arch") is read here.
        GPU_FLAG: "y" to run on CUDA, anything else for CPU.

    Relies on module-level state: parse_shift_option_from_log_name, TSN,
    summary, maxAbnormalProb, FpsList, getStatsOfAbnormalActivity, startime.
    """
    # switch between archs based on selected arch
    if args.get("arch") == "mobilenetv2":
        this_weights = "checkpoint/TSM_ucfcrime_RGB_mobilenetv2_shift8_blockres_avg_segment8_e25/ckpt.best.pth.tar"
    else:
        this_weights = "checkpoint/TSM_ucfcrime_RGB_resnet50_shift8_blockres_avg_segment8_e25/ckpt.best.pth.tar"
    is_shift, shift_div, shift_place = parse_shift_option_from_log_name(
        this_weights)
    modality = "RGB"
    if "RGB" in this_weights:
        modality = "RGB"
    # Get dataset categories.
    categories = ["Normal Activity", "Abnormal Activity"]
    num_class = len(categories)
    this_arch = args.get("arch")
    print("[INFO] >> Model loading weights from disk!!")
    net = TSN(
        num_class,
        1,
        modality,
        base_model=this_arch,
        consensus_type="avg",
        img_feature_dim="225",
        # pretrain=args.pretrain,
        is_shift=is_shift,
        shift_div=shift_div,
        shift_place=shift_place,
        non_local="_nl" in this_weights,
    )
    # See GPU_FLAG to check where to load the weights on CPU or GPU
    if GPU_FLAG == "y":
        checkpoint = torch.load(this_weights)
    else:
        checkpoint = torch.load(this_weights,
                                map_location=torch.device("cpu"))
    checkpoint = checkpoint["state_dict"]
    # Drop the first dotted key component (presumably the DataParallel
    # "module." prefix) and remap old classifier names onto the TSN head.
    base_dict = {
        ".".join(k.split(".")[1:]): v
        for k, v in list(checkpoint.items())
    }
    replace_dict = {
        "base_model.classifier.weight": "new_fc.weight",
        "base_model.classifier.bias": "new_fc.bias",
    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)
    net.load_state_dict(base_dict)
    print("\n[INFO] >> Model loading Successfull")
    # On GPU infer every 2nd frame; on CPU only every 4th, to keep up.
    if GPU_FLAG == "y":
        net.cuda().eval()
        skip_frames = 2
        summary(net, (1, 3, 224, 224))
    else:
        net.eval()
        skip_frames = 4
    transform = torchvision.transforms.Compose([
        Stack(roll=(this_arch in ["BNInception", "InceptionV3"])),
        ToTorchFormatTensor(
            div=(this_arch not in ["BNInception", "InceptionV3"])),
        GroupNormalize(net.input_mean, net.input_std),
    ])
    WINDOW_NAME = "Real-Time Video Action Recognition"
    # set a lower resolution for speed up
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 320)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 240)
    # env variables
    full_screen = False
    cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(WINDOW_NAME, 640, 480)
    cv2.moveWindow(WINDOW_NAME, 0, 0)
    cv2.setWindowTitle(WINDOW_NAME, WINDOW_NAME)
    t = None
    i_frame = -1
    count = 0
    imageName = 0
    # variable to hold writer object
    writer = None
    c = 0
    print("Ready!")
    while cap.isOpened():
        i_frame += 1
        hasFrame, img = cap.read()  # (480, 640, 3) 0 ~ 255
        if hasFrame:
            img_tran = transform([Image.fromarray(img).convert("RGB")])
            if (i_frame % skip_frames == 0
                ):  # skip every other frame to obtain a suitable frame rate
                t1 = time.time()
                if GPU_FLAG == "y":
                    input1 = (img_tran.view(
                        -1, 3, img_tran.size(1),
                        img_tran.size(2)).unsqueeze(0).cuda())
                else:
                    input1 = img_tran.view(-1, 3, img_tran.size(1),
                                           img_tran.size(2)).unsqueeze(0)
                input = input1
                with torch.no_grad():
                    logits = net(input)
                    h_x = torch.mean(F.softmax(logits, 1), dim=0).data
                print(
                    "<<< [INFO] >>> PROB - | Normal: {:.2f}".format(
                        h_x[0]),
                    "| Abnormal: {:.2f} |".format(h_x[1]),
                    "Frames Rendered-",
                    count,
                )
                pr, li = h_x.sort(0, True)
                probs = pr.tolist()
                idx = li.tolist()
                # print(probs)
                t2 = time.time()
                print(
                    "<<< [INFO] >>>",
                    "EVENT - |",
                    categories[idx[0]],
                    " Prob: {:.2f}| ".format(probs[0]),
                    "\n",
                )
                current_time = t2 - t1
            # NOTE(review): idx/probs/current_time are assigned only inside
            # the skip_frames branch above but read on every frame below —
            # looks like the very first frames rely on i_frame % skip_frames
            # hitting 0 first; confirm.
            img = cv2.resize(img, (640, 480))
            height, width, _ = img.shape
            if categories[idx[0]] == "Abnormal Activity":
                R = 255
                G = 0
                Abnormality = True
                tempThres = probs[0]
                c += 1
                maxAbnormalProb.append(float(probs[0]))
            else:
                R = 0
                G = 255
                Abnormality = False
            cv2.putText(
                img,
                "EVENT: " + categories[idx[0]],
                (20, int(height / 16)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, int(G), int(R)),
                2,
            )
            cv2.putText(
                img,
                "Confidence: {0:.2f}%".format(probs[0] * 100, "%"),
                (20, int(height - 420)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, int(G), int(R)),
                2,
            )
            fps = 1 / current_time
            # if args.get('f',True):
            FpsList.append(float(fps))
            maxFps = max(FpsList)
            estFps = sum(FpsList) / len(FpsList)
            # else:
            #     maxFps=-1
            #     estFps=-1
            cv2.putText(
                img,
                "FPS: {0:.1f}".format(fps),
                (width - 150, int(height / 16)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, 255, 255),
                2,
            )
            # Lazily create the output clip writer on the first frame.
            if writer is None:
                fourcc = cv2.VideoWriter_fourcc(*"MJPG")
                (H, W) = img.shape[:2]
                path = "./appData/Anoamly_Clips/"
                name = len(glob.glob(path + "*.avi"))
                getVidName = path + "Abnormal_Event_{0}.avi".format(name + 1)
                writer = cv2.VideoWriter(getVidName, fourcc, 30.0, (W, H),
                                         True)
            # Saving Anaomlous Event Image and Clip
            if Abnormality:
                writer.write(img)
                # record stat every two seconds if exists
                if c % 60 == 0:
                    getStatsOfAbnormalActivity()
                # if tempThres > 0.75:
                path = "./appData/Anoamly_Images/"
                index = len(glob.glob(path + "*.jpg"))
                # imageName = getFileName(path+'.jpg')
                imageName = path + "Abnormal_Event_{0}.jpg".format(index + 1)
                cv2.imwrite(imageName, img)
            cv2.imshow(WINDOW_NAME, img)
            key = cv2.waitKey(1)
            if key & 0xFF == ord("q") or key == 27:  # exit
                break
            elif key == ord("F") or key == ord("f"):  # full screen
                print("Changing full screen option!")
                full_screen = not full_screen
                if full_screen:
                    print("Setting FS!!!")
                    cv2.setWindowProperty(WINDOW_NAME,
                                          cv2.WND_PROP_FULLSCREEN,
                                          cv2.WINDOW_FULLSCREEN)
                else:
                    cv2.setWindowProperty(WINDOW_NAME,
                                          cv2.WND_PROP_FULLSCREEN,
                                          cv2.WINDOW_NORMAL)
            # resetting time for next frame
            if t is None:
                t = time.time()
            else:
                nt = time.time()
                count += 1
                t = nt
        else:
            # End of stream: release resources and print run statistics.
            # Releasing `cap` makes cap.isOpened() false, ending the loop.
            # Uncomment below lines to run code unfinitely and comment cap.release and writer,release
            # i_frame = 0
            # cap.set(cv2.CAP_PROP_POS_FRAMES,0)
            cap.release()
            writer.release()
            cv2.destroyAllWindows()
            # Clearing Variables for re-running
            # estFps=None
            # maxAbnormalProb.clear()
            # maxFps=None
            # Calculating total execution time
            execTime = time.time() - startime
            print()
            # Display Results
            print("<<< [INFO] >>> Total Abnormal Probs : ",
                  len(maxAbnormalProb))
            print("<<< [INFO] >>> Max Abnormality Prob : {:.2f}".format(
                max(maxAbnormalProb)))
            print("<<< [INFO] >>> Avg Abnormality Prob : {:.2f}".format(
                sum(maxAbnormalProb) / len(maxAbnormalProb)))
            print("<<< [INFO] >>> Max FPS achieved : {:.1f}".format(maxFps))
            print("<<< [INFO] >>> Averge Estimated FPS : {:.1f}".format(
                estFps))
            print("<<< [INFO] >>> Total Infernece Time : {:.2f} seconds".format(
                execTime))
def main():
    """Train a TSM model on the YouCook recognition task.

    Driven entirely by the global argparse ``args``; updates the globals
    ``best_prec1`` and ``crop_size``.  Side effects: creates result folders,
    writes ``log.csv``/``args.txt``/TensorBoard events under
    ``args.root_log/args.store_name`` and saves a checkpoint every epoch.

    NOTE(review): validation is currently disabled (``if False:`` below), so
    every checkpoint is saved with ``is_best=False`` and ``best_prec1`` never
    changes — confirm whether this debug state is intentional.
    """
    global args, best_prec1
    global crop_size
    args = parser.parse_args()
    num_class, train_list, val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    # Dataset lists come from the config, but the class count is forced to 1
    # (single regression/score output) — overriding the config value above.
    num_class = 1
    if args.train_list == "":
        args.train_list = train_list
    if args.val_list == "":
        args.val_list = val_list

    # ---- Build a unique experiment/store name from the hyper-parameters ----
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.concat != "":
        full_arch_name += '_concat{}'.format(args.concat)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type, 'lr%.5f' % args.lr,
        'dropout%.2f' % args.dropout, 'wd%.5f' % args.weight_decay,
        'batch%d' % args.batch_size, 'segment%d' % args.num_segments,
        'e{}'.format(args.epochs)
    ])
    if args.data_fuse:
        args.store_name += '_fuse'
    if args.extra_temporal_modeling:
        args.store_name += '_extra'
    if args.tune_from is not None:
        args.store_name += '_finetune'
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.clipnums:
        #pass
        args.store_name += "_clip{}".format(args.clipnums)
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    # When pruning input channels, pre-rank them by L2 distance on the
    # fine-tune checkpoint before the model is constructed.
    if args.prune in ['input', 'inout'] and args.tune_from:
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        sd = input_dim_L2distance(sd, args.shift_div)

    # ---- Model construction ----
    model = TSN(
        num_class,
        args.num_segments,
        args.modality,
        base_model=args.arch,
        new_length=2 if args.data_fuse else None,
        consensus_type=args.consensus_type,
        dropout=args.dropout,
        img_feature_dim=args.img_feature_dim,
        partial_bn=not args.no_partialbn,
        pretrain=args.pretrain,
        is_shift=args.shift,
        shift_div=args.shift_div,
        shift_place=args.shift_place,
        fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
        temporal_pool=args.temporal_pool,
        non_local=args.non_local,
        concat=args.concat,
        extra_temporal_modeling=args.extra_temporal_modeling,
        prune_list=[prune_conv1in_list, prune_conv1out_list],
        is_prune=args.prune,
    )
    #model = torch.load("/home/ubuntu/backup_kevin/myownTSM_git/checkpoint/TSM_youcook_RGB_resnet50_shift8_blockres_concatAll_conv1d_lr0.00025_dropout0.70_wd0.00050_batch16_segment8_e20_finetune_slice_v1_clipnum500/ckpt_"+str(1)+".pth.tar")
    print(model)
    #summary(model, torch.zeros((16, 24, 224, 224)))
    #exit(1)

    if args.dataset == 'ucf101':
        #twice sample & full resolution
        twice_sample = True
        crop_size = model.scale_size  #256 x 256
    else:
        twice_sample = False
        crop_size = model.crop_size  #224 x 224
    # NOTE(review): this unconditionally overrides the branch above and pins
    # the input crop to 256 — confirm this is the intended resolution.
    crop_size = 256
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies(args.concat)
    # Horizontal flip augmentation is disabled for direction-sensitive datasets.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        or 'nvgesture' in args.dataset else True)

    model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # ---- Optionally resume training from a full checkpoint ----
    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # NOTE(review): logs args.evaluate here; args.resume was probably
            # intended — verify before relying on this message.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # ---- Optionally fine-tune from (weights-only) checkpoint(s) ----
    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        tune_from_list = args.tune_from.split(',')
        sd = torch.load(tune_from_list[0])
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile key-name differences ('.net' / '.prune' wrappers) between
        # the checkpoint and the current model before loading.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.prune', '') in sd:
                print('=> Load after adding .prune: ', k)
                replace_dict.append((k.replace('.prune', ''), k))
        # Reshape checkpoint tensors when the architecture was pruned/concat-extended.
        if args.prune in ['input', 'inout']:
            sd = adjust_para_shape_prunein(sd, model_dict)
        if args.prune in ['output', 'inout']:
            sd = adjust_para_shape_pruneout(sd, model_dict)
        if args.concat != "" and "concat" not in tune_from_list[0]:
            sd = adjust_para_shape_concat(sd, model_dict)
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in tune_from_list[0]:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality != 'Flow' and 'Flow' in tune_from_list[0]:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        #print(sd.keys())
        #print("*"*50)
        #print(model_dict.keys())
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Frames consumed per segment sample for each modality.
    if args.modality in ['RGB']:
        data_length = 1
    elif args.modality in ['Depth']:
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    '''
    dataRoot = r"/home/share/YouCook/downloadVideo"
    for dirPath, dirnames, filenames in os.walk(dataRoot):
        for filename in filenames:
            print(os.path.join(dirPath,filename) +"is {}".format("exist" if os.path.isfile(os.path.join(dirPath,filename))else "NON"))
            train_data = torchvision.io.read_video(os.path.join(dirPath,filename),start_pts=0,end_pts=1001, )
            tmp = torchvision.io.read_video_timestamps(os.path.join(dirPath,filename),)
            print(tmp)
            print(len(tmp[0]))
            print(train_data[0].size())
            exit()
    exit()
    '''
    '''
    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample, data_fuse = args.data_fuse),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample, twice_sample=twice_sample, data_fuse = args.data_fuse),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    '''
    #global trainDataloader, valDataloader, train_loader, val_loader
    # YouCook-specific dataset objects replace the generic TSNDataSet above.
    trainDataloader = YouCookDataSetRcg(args.root_path, args.train_list,train=True,inputsize=crop_size,hasPreprocess = False,\
                                        clipnums=args.clipnums, hasWordIndex = True,)
    valDataloader = YouCookDataSetRcg(args.root_path, args.val_list,val=True,inputsize=crop_size,hasPreprocess = False,\
                                      clipnums=args.clipnums, hasWordIndex = True,)
    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #exit()
    # NOTE(review): default DataLoader batch_size=1 and shuffle is commented
    # out — confirm this is intended for this dataset.
    train_loader = torch.utils.data.DataLoader(trainDataloader,
                                               #shuffle=True,
                                               )
    val_loader = torch.utils.data.DataLoader(valDataloader)
    #print(train_loader is val_loader)
    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #print(len(train_loader))
    #exit()

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    elif args.loss_type == "MSELoss":
        criterion = torch.nn.MSELoss().cuda()
    elif args.loss_type == "BCELoss":
        #print("BCELoss")
        criterion = torch.nn.BCELoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    # Evaluation-only mode: run one validation pass and exit.
    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    #print(os.path.join(args.root_log, args.store_name, 'args.txt'))
    #exit()
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    # ---- Training loop: one epoch of training, checkpoint every epoch ----
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
        #print("265")
        # train for one epoch
        ######
        #print(trainDataloader._getMode())
        #print(valDataloader._getMode())
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer)
        ######
        #print("268")
        # evaluate on validation set
        #model = model.load_state_dict(torch.load("/home/ubuntu/backup_kevin/myownTSM_git/checkpoint/TSM_youcook_RGB_resnet50_shift8_blockres_concatAll_conv1d_lr0.00025_dropout0.70_wd0.00050_batch16_segment8_e20_finetune_slice_v1_clipnum500/ckpt_"+str(epoch+1)+".pth.tar"))
        #if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
        if False:  # validation branch disabled (debug)
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            #print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
        else:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, False)
        #break
    print("test pass")
modality = 'RGB' else: modality = 'Flow' this_arch = this_weights.split('TSM_')[1].split('_')[2] modality_list.append(modality) num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset( args.dataset, modality) print('=> shift: {}, shift_div: {}, shift_place: {}'.format( is_shift, shift_div, shift_place)) net = TSN( num_class, this_test_segments if is_shift else 1, modality, base_model=this_arch, consensus_type=args.crop_fusion_type, img_feature_dim=args.img_feature_dim, pretrain=args.pretrain, is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, non_local='_nl' in this_weights, ) if 'tpool' in this_weights: from ops.temporal_shift import make_temporal_pool make_temporal_pool(net.base_model, this_test_segments) # since DataParallel checkpoint = torch.load(this_weights) checkpoint = checkpoint['state_dict']
total_num = None for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list): has_tam, modality, backbone = parse_shift_option_from_log_name( this_weights) modality_list.append(modality) num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset( args.dataset, modality) print('=> TAM : {}, {} dense'.format(has_tam, args.sample)) net = TSN( num_class, this_test_segments if has_tam else 1, modality, base_model=backbone, consensus_type=args.crop_fusion_type, img_feature_dim=args.img_feature_dim, tam=has_tam, non_local='_nl' in this_weights, ) checkpoint = torch.load(this_weights, map_location='cpu') checkpoint = checkpoint['state_dict'] base_dict = {} for k, v in list(checkpoint.items()): if k.startswith('module'): base_dict['.'.join(k.split('.')[1:])] = v else: base_dict[k] = v net.load_state_dict(base_dict)
def main():
    """Run CAM (class-activation-map) generation over the validation set.

    Builds a TSN model from the global argparse ``args``, optionally restores
    weights from ``args.resume`` (stripping the DataParallel ``module.``
    prefix), builds a validation DataLoader with the standard center-crop
    preprocessing, and hands everything to ``cam_process``.
    """
    global args, best_prec1
    args = parser.parse_args()
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                      args.modality)
    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                cca3d = args.cca3d
                )

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    # Single-GPU: model is NOT wrapped in DataParallel here, hence the
    # 'module.' prefix stripping below when loading a DataParallel checkpoint.
    model = model.cuda()

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume,map_location='cuda:0')
            parallel_state_dict = checkpoint['state_dict']
            # Drop the leading 'module.' that DataParallel adds to every key.
            cpu_state_dict={}
            for k,v in parallel_state_dict.items():
                cpu_state_dict[k[len('module.'):]] = v
            model.load_state_dict(cpu_state_dict)
            # NOTE(review): message formats args.evaluate; args.resume was
            # probably intended — verify.
            print(("=> loaded checkpoint '{}' (epoch {})".format(args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Frames per sample: 1 for RGB, 5 stacked frames for Flow/RGBDiff.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    preprocess=torchvision.transforms.Compose([
        GroupScale(int(scale_size)),
        GroupCenterCrop(crop_size),
        Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
        normalize,
    ])
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform = preprocess,
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)

    # Normalization stats are forwarded so CAMs can be overlaid on
    # de-normalized frames.
    norm_param = (input_mean, input_std)
    cam_process(val_loader, model,norm_param)
rst = rst.reshape( (batch_size, -1, num_class)).mean(axis=1).reshape( (batch_size, num_class)) return rst num_class, args.train_list, val_list, prefix = dataset_config.return_dataset( args.dataset, args.modality) net = TSN( num_class, args.test_segments, args.modality, base_model=args.arch, consensus_type=args.crop_fusion_type, img_feature_dim=args.img_feature_dim, pretrain=args.pretrain, is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place, ) # import pdb; pdb.set_trace() ''' checkpoint = torch.load(args.weight) checkpoint = checkpoint['state_dict'] base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} replace_dict = { 'base_model.classifier.weight': 'new_fc.weight', 'base_model.classifier.bias': 'new_fc.bias',
def main():
    """Train a multi-label TSM model on the hard-coded MovieNet split.

    Configuration comes from the global argparse ``args``, but the dataset
    paths, class count (21) and frame-name prefix are hard-coded below.
    Uses BCEWithLogitsLoss (multi-label) and tracks mAP/wAP instead of
    top-1 accuracy.  Side effects: writes ``testnew.npy`` prediction
    matrices, checkpoints, and appends to ``args.record_path``.
    """
    global args, best_prec1
    args = parser.parse_args()
    #num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
    #                                                                                                  args.modality)
    # Hard-coded MovieNet configuration (overrides dataset_config).
    num_class = 21
    args.train_list = "/home/jzwang/code/Video_3D/movienet/data/movie/movie_train.txt"
    args.val_list = "/home/jzwang/code/Video_3D/movienet/data/movie/movie_test.txt"
    args.root_path = ""
    prefix = "frame_{:04d}.jpg"

    # ---- Build the experiment/store name from the hyper-parameters ----
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type,
         'segment%d' % args.num_segments, 'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    #check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # ---- Optionally resume from a full checkpoint ----
    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            #best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # NOTE(review): logs args.evaluate; args.resume was probably
            # intended — verify.
            print(("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # ---- Optionally fine-tune from a weights-only checkpoint ----
    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile '.net' key-name differences between checkpoint and model.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSetMovie("", args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="frame_{:04d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSetMovie("", args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="frame_{:04d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=int(args.batch_size/2), shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    # Multi-label classification: BCE over logits, one sigmoid per class.
    criterion = torch.nn.BCEWithLogitsLoss().cuda()

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

    # NOTE(review): optimizer is re-created here with identical settings as
    # above — the first construction appears redundant.
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    zero_time = time.time()
    best_map = 0
    print ('Start training...')
    for epoch in range(args.start_epoch, args.epochs):
        # Validation runs before training each epoch (predictions dumped to
        # testnew.npy both here and after the periodic eval below).
        valloss, mAP, wAP, output_mtx = validate(val_loader, model, criterion)
        adjust_learning_rate(optimizer, epoch, args.lr_steps)
        np.save("testnew.npy", output_mtx)
        print("saving down")
        # train for one epoch
        start_time = time.time()
        trainloss = train(train_loader, model, criterion, optimizer, epoch)
        print('Traing loss %4f Epoch %d'% (trainloss, epoch))
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            valloss, mAP, wAP, output_mtx = validate(val_loader, model, criterion)
            end_time = time.time()
            epoch_time = end_time - start_time
            total_time = end_time - zero_time
            print ('Total time used: %s Epoch %d time uesd: %s'%(
                str(datetime.timedelta(seconds=int(total_time))), epoch,
                str(datetime.timedelta(seconds=int(epoch_time)))))
            print ('Train loss: {0:.4f} val loss: {1:.4f} mAP: {2:.4f} wAP: {3:.4f}'.format(
                trainloss, valloss, mAP, wAP))
            # evaluate on validation set
            # NOTE(review): best_map is never updated, so is_best is True
            # whenever mAP > 0 — confirm whether tracking was meant to be
            # re-enabled (see commented lines).
            is_best = mAP > best_map
            #if mAP > best_map:
            #best_map = mAP
            # checkpoint_name = "%04d_%s" % (epoch+1, "checkpoint.pth.tar")
            checkpoint_name = "best_checkpoint.pth.tar"
            save_checkpoint({
                'epoch': epoch+1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, is_best, epoch)
            np.save("testnew.npy", output_mtx)
            print("saving down")
            with open(args.record_path, 'a') as file:
                file.write('Epoch:[{0}]'
                           'Train loss: {1:.4f} val loss: {2:.4f} map: {3:.4f}\n'.format(
                    epoch+1, trainloss, valloss, mAP))
    print ('************ Done!... ************')
data_iter_list = [] net_list = [] modality_list = args.modalities.split(',') arch_list = args.archs.split('.') total_num = None for this_weights, this_test_segments, test_file, modality, this_arch in zip( weights_list, test_segments_list, test_file_list, modality_list, arch_list): num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset( args.dataset, modality) net = TSN(num_class, this_test_segments, modality, base_model=this_arch, consensus_type=args.crop_fusion_type, img_feature_dim=args.img_feature_dim, pretrain=args.pretrain) checkpoint = torch.load(this_weights) try: net.load_state_dict(checkpoint['state_dict']) except: checkpoint = checkpoint['state_dict'] base_dict = { '.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items()) } replace_dict = { 'base_model.classifier.weight': 'new_fc.weight',
def __init__(self, checkpoint_file, num_classes, max_length=8, trim_net=False,
             checkpoint_is_model=False, bottleneck_size=128):
    """Build a TSM (resnet101, somethingv2 config) feature-extraction backbone.

    Args:
        checkpoint_file: path to the checkpoint; may hold either a plain
            state-dict checkpoint or a pickled wrapper object whose ``.net``
            carries the weights (see ``checkpoint_is_model``).
        num_classes: stored for callers; the underlying TSN is always built
            with the somethingv2 class count from ``dataset_config``.
        max_length: number of temporal segments fed to the network.
        trim_net: when True, ``self.loss`` is not created (feature-extraction
            only use).
        checkpoint_is_model: True if the checkpoint is a whole model object
            rather than a ``{'state_dict': ...}`` dict.
        bottleneck_size: stored; the bottleneck layers that would use it are
            currently commented out below.
    """
    self.is_shift = None
    self.net = None
    self.arch = None
    self.num_classes = num_classes
    self.max_length = max_length
    self.bottleneck_size = bottleneck_size
    #self.feature_idx = feature_idx
    self.transform = None
    # Feature widths of the four ResNet stages.
    self.CNN_FEATURE_COUNT = [256, 512, 1024, 2048]

    # input variables
    this_test_segments = self.max_length
    test_file = None

    #model variables
    # Fixed TSM configuration: temporal shift enabled, 1/8 channels shifted,
    # shift placed inside residual blocks.
    self.is_shift, shift_div, shift_place = True, 8, 'blockres'
    self.arch = 'resnet101'
    modality = 'RGB'

    # dataset variables
    num_class, train_list, val_list, root_path, prefix = dataset_config.return_dataset(
        'somethingv2', modality)
    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(
        self.is_shift, shift_div, shift_place))

    # define model
    net = TSN(
        num_class,
        this_test_segments if self.is_shift else 1,
        modality,
        base_model=self.arch,
        consensus_type='avg',
        img_feature_dim=256,
        pretrain='imagenet',
        is_shift=self.is_shift,
        shift_div=shift_div,
        shift_place=shift_place,
        non_local='_nl' in checkpoint_file,
    )
    '''
    The checkpoint file appears to be an entire TSMBackBone Object. this needs to be handled
    acordingly. Either find a way to convert it back to a weights file or maniuplate it to work
    with the system.
    '''
    # load checkpoint file
    checkpoint = torch.load(checkpoint_file)
    ''' #include
    print("self.bottleneck_size:", self.bottleneck_size, type(self.bottleneck_size))
    net.base_model.avgpool = nn.Sequential(
        nn.Conv2d(2048, self.bottleneck_size, (1,1)),
        nn.ReLU(inplace=True),
        #nn.AdaptiveAvgPool2d(output_size=1)
    )
    if(not trim_net):
        print("no trim")
        net.new_fc = nn.Linear(self.bottleneck_size, 174)
    else:
        print("trim")
        net.consensus = nn.Identity()
        net.new_fc = nn.Identity()
    net.base_model.fc = nn.Identity()  # sets the dropout value to None
    print(net)
    # Combine network together so that the it can have parameters set correctly
    # I think, I'm not 100% what this code section actually does and I don't have
    # the time to figure it out right now
    #print("checkpoint------------------------")
    #print(checkpoint)
    '''
    # Extract a flat state-dict, stripping the leading 'module.' that
    # DataParallel prepends to every parameter name.
    if (checkpoint_is_model):
        checkpoint = checkpoint.net.state_dict()
    else:
        checkpoint = checkpoint['state_dict']
    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in list(checkpoint.items())
    }
    ''' #include
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if v in base_dict:
            base_dict.pop(v)
        if k in base_dict:
            base_dict.pop(k)
        #base_dict[v] = base_dict.pop(k)
    '''
    # strict=False: tolerate missing/unexpected keys after prefix stripping.
    net.load_state_dict(base_dict, strict=False)

    # define image modifications
    self.transform = torchvision.transforms.Compose([
        torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(net.scale_size),
        ]),
        #torchvision.transforms.Compose([ GroupFullResSample(net.scale_size, net.scale_size, flip=False) ]),
        Stack(roll=(self.arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(
            div=(self.arch not in ['BNInception', 'InceptionV3'])),
        GroupNormalize(net.input_mean, net.input_std),
    ])

    # place net onto GPU and finalize network
    # NOTE: self.model keeps the bare TSN; self.net is the DataParallel wrapper.
    self.model = net
    net = torch.nn.DataParallel(net.cuda())
    net.eval()

    # network variable
    self.net = net

    # loss variable (used for generating gradients when ranking)
    if (not trim_net):
        self.loss = torch.nn.CrossEntropyLoss().cuda()
def main():
    """Standard TSM training entry point.

    Parses the global argparse ``args``, builds a TSN model (optionally with
    temporal shift / non-local blocks), resumes or fine-tunes from checkpoints,
    constructs train/val DataLoaders, then runs the epoch loop with periodic
    validation and checkpointing.  Updates the global ``best_prec1``; writes
    logs, TensorBoard events and checkpoints under ``args.root_log`` /
    ``args.root_model``.
    """
    global args, best_prec1
    args = parser.parse_args()
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)

    # ---- Compose a unique experiment/store name from the hyper-parameters ----
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type, 'segment%d' % args.num_segments,
        'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    # Horizontal flip is disabled for direction-sensitive datasets.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # ---- Optionally resume training from a full checkpoint ----
    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # BUGFIX: message previously formatted args.evaluate (a flag),
            # printing the wrong value; report the resumed checkpoint path.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # ---- Optionally fine-tune from a weights-only checkpoint ----
    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile '.net' key-name differences between checkpoint and model.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Frames per sample: 1 for RGB, 5 stacked frames for Flow/RGBDiff.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    # Evaluation-only mode: one validation pass, then exit.
    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    # ---- Epoch loop: train, periodically validate, checkpoint the best ----
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
def main():
    """Train or evaluate a TSM lip-reading model.

    All configuration comes from the module-level ``parser``/``opts``
    globals. Builds the TSN model and data loaders, then either runs
    inference (``args.mode == 'test'``) or the train/validate loop,
    checkpointing on best top-1 precision (tracked in the module-level
    ``best_prec1``).
    """
    global args, best_prec1
    args = parser.parse_args()
    num_class = opts.num_class

    # Compose a unique experiment name out of the active options so logs
    # and checkpoints from different runs never collide.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type, 'segment%d' % args.num_segments,
        'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.non_local > 0:
        args.store_name += '_nl'
    print('storing name: ' + args.store_name)
    # check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_size=args.img_size,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=True,
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    # Capture preprocessing parameters and per-layer LR policies before
    # the model is wrapped in DataParallel (the wrapper hides them).
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset
        or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    model.apply(weights_init)
    optimizer = torch.optim.SGD(policies, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:
            # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            # NOTE(review): optimizer state is deliberately not restored,
            # so momentum restarts from scratch on resume — confirm intended.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Label maps (word->number and number->word). The files hold Python
    # dict literals; eval() is kept for compatibility, but it executes
    # arbitrary code — prefer ast.literal_eval/json for untrusted files.
    # Opened read-only ('r' instead of the original 'r+': nothing is
    # ever written) and via `with` so handles close even if eval raises.
    with open(opts.NUM_LABEL_R, 'r') as fr_r:
        w2n = eval(fr_r.read())
    with open(opts.NUM_LABEL, 'r') as fr:
        n2w = eval(fr.read())

    # BNInception/InceptionV3 backbones take rolled (BGR) channels and
    # un-normalized pixel values; all pipelines below share these flags.
    roll_channels = args.arch in ['BNInception', 'InceptionV3']
    div_255 = not roll_channels

    train_loader, val_loader, test_loader = None, None, None
    if args.mode != 'test':
        lip_dict, video_list = opts.file_deal(opts.TRAIN_DATA, w2n)
        # 95/5 split of the training videos into train/validation.
        train_num = int(len(video_list) * 0.95)
        train_loader = torch.utils.data.DataLoader(TSNDataSet(
            opts.TRAIN_DATA, args.mode,
            num_segments=args.num_segments,
            img_size=args.img_size,
            lip_dict=lip_dict,
            video_list=video_list[:train_num],
            transform=torchvision.transforms.Compose([
                train_augmentation,
                # NOTE(review): train_augmentation typically already
                # contains a multi-scale crop; verify this second crop
                # is intended rather than applied twice.
                GroupMultiScaleCrop(args.img_size, [1, .875, .75, .66]),
                Stack(roll=roll_channels),
                ToTorchFormatTensor(div=div_255),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)
        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            opts.TRAIN_DATA, args.mode,
            num_segments=args.num_segments,
            img_size=args.img_size,
            lip_dict=lip_dict,
            video_list=video_list[train_num:],
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=roll_channels),
                ToTorchFormatTensor(div=div_255),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
    else:
        lip_dict, video_list = opts.file_deal(opts.TEST_DATA, w2n)
        test_loader = torch.utils.data.DataLoader(TSNDataSet_infer(
            opts.TEST_DATA,
            num_segments=args.num_segments,
            img_size=args.img_size,
            lip_dict=lip_dict,
            video_list=video_list,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=roll_channels),
                ToTorchFormatTensor(div=div_255),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    # Pure inference: run the requested inference routine and exit.
    if args.mode == 'test':
        if args.sub == 'sub':
            inference(test_loader, model, n2w)
        else:
            inferencefusion(test_loader, model, n2w)
        return

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, n2w)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)