Example #1
    def __init__(self, modality, checkpoint, arena_mask_path):

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = TSN(2,
                    8,
                    modality,
                    base_model='resnet18',
                    consensus_type='avg',
                    dropout=0.5,
                    img_feature_dim=256,
                    partial_bn=False,
                    pretrain='imagenet',
                    is_shift=True,
                    shift_div=8,
                    shift_place='blockres',
                    fc_lr5=False,
                    temporal_pool=False,
                    non_local=False)

        # Get Model complexity
        macs, params = get_model_complexity_info(
            model, (24, 224, 224),
            as_strings=True,
            print_per_layer_stat=False,
            verbose=True)  # noqa: E128, E501
        print('---{:<30}  {:<8}'.format('Computational complexity: ', macs))
        print('{:<30}  {:<8}'.format('Number of parameters: ', params))

        # Define transforms
        crop_size = model.crop_size
        scale_size = model.scale_size
        input_mean = model.input_mean
        input_std = model.input_std
        self.transform = torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            GroupNormalize(input_mean, input_std),
        ])

        # Load TSM model
        model = torch.nn.DataParallel(model, device_ids=None).to(device)  # device_ids expects a list of GPU ids (or None for all), not an int
        model.load_state_dict(
            torch.load(checkpoint, map_location=device)['state_dict'])
        self.model = model
        self.model.eval()

        # Frame samples to be selected in a clip
        self.action_names = ['explore', 'investigate']
        self.rgb_sample = [2, 6, 9, 13, 17, 20, 24,
                           28]  # [4, 12, 19, 26, 34, 41, 48, 56]

        self.arena_mask = cv2.imread(arena_mask_path)
        if self.arena_mask is None:
            print("Arena Mask not loaded: %s" % arena_mask_path)
            exit(0)
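
For orientation, a minimal sketch of how a wrapper like this might be driven per clip at inference time. The frame loading and the classify_clip helper are hypothetical; only transform, model, rgb_sample, and action_names come from the class above.

import torch
from PIL import Image

def classify_clip(detector, frame_paths):
    # Select the eight frame indices the model was trained on.
    frames = [Image.open(frame_paths[i]).convert('RGB')
              for i in detector.rgb_sample]
    data = detector.transform(frames).unsqueeze(0)  # (1, 8*3, H, W)
    with torch.no_grad():
        logits = detector.model(data)
    return detector.action_names[logits.argmax(dim=1).item()]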
Example #2
    def __init__(self, modality, checkpoint):

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = TSN(2,
                    8,
                    modality,
                    base_model='resnet18',
                    consensus_type='avg',
                    dropout=0.5,
                    img_feature_dim=256,
                    partial_bn=False,
                    pretrain='imagenet',
                    is_shift=True,
                    shift_div=8,
                    shift_place='blockres',
                    fc_lr5=False,
                    temporal_pool=False,
                    non_local=False)

        # Define transforms
        crop_size = model.crop_size
        scale_size = model.scale_size
        input_mean = model.input_mean
        input_std = model.input_std
        self.transform = torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            GroupNormalize(input_mean, input_std),
        ])

        # Load TSM model
        model = torch.nn.DataParallel(model, device_ids=None).to(device)  # device_ids expects a list of GPU ids (or None for all), not an int
        model.load_state_dict(torch.load(checkpoint)['state_dict'])
        self.model = model
        self.model.eval()

        self.action_names = ['explore', 'investigate']
        self.rgb_sample = [4, 12, 19, 26, 34, 41, 48, 56]
Example #3
def load_src_model():
    model = TSN(2, 8, 'RGB',
                base_model='resnet50',
                consensus_type='avg',
                img_feature_dim=256,
                pretrain='imagenet',
                is_shift=True, shift_div=8, shift_place='blockres',
                non_local=False,
                )

    modelpath = '/nfs/volume-95-7/temporal-shift-module/checkpoint/TSM_videos_1218_RGB_resnet50_shift8_blockres_avg_segment8_e120_pr8_ext0.1/ckpt.best.pth.tar'
    checkpoint = torch.load(modelpath)
    checkpoint = checkpoint['state_dict']

    # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    model.load_state_dict(base_dict)
    model.eval()

    # example = torch.ones(1, 8, 3, 224, 224)

    example = torch.eye(224)
    example = example.expand((1, 8, 3, 224, 224))

    y = model(example)
    print("src_model output: ", y)
Example #4
def cvt_model():
    print("===> Loading model")

    model = TSN(2, 8, 'RGB',
              base_model='resnet50',
              consensus_type='avg',
              img_feature_dim=256,
              pretrain='imagenet',
              is_shift=True, shift_div=8, shift_place='blockres',
              non_local=False,
              )

    modelpath = '/nfs/volume-95-7/temporal-shift-module/checkpoint/TSM_videos_1218_RGB_resnet50_shift8_blockres_avg_segment8_e120_pr8_ext0.1/ckpt.best.pth.tar'
    checkpoint = torch.load(modelpath)
    checkpoint = checkpoint['state_dict']

    # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    model.load_state_dict(base_dict)

    # Convert the model to TorchScript
    model.cuda()
    model.eval()
    example = torch.rand(1,8,3,224,224).cuda()
    y = model(example)
    print(y.shape)
    traced_script_module = torch.jit.trace(model, example)
    print(traced_script_module.code)
    # traced_script_module = torch.jit.script(model)
    # print(traced_script_module.code)
    # output = traced_script_module(torch.rand(1, 1, 224, 224))
    traced_script_module.save("tsm_with_1218.pt")

    print("Export of model.pt complete!")
Example #5
def load_model(weights):
    global num_class
    is_shift, shift_div, shift_place = parse_shift_option_from_log_name(
        weights)
    if 'RGB' in weights:
        modality = 'RGB'
    elif 'Depth' in weights:
        modality = 'Depth'
    else:
        modality = 'Flow'

    if 'concatAll' in weights:
        concat = "All"
    elif "concatFirst" in weights:
        concat = "First"
    else:
        concat = ""

    extra_temporal_modeling = False  # default, so the name is always bound
    if 'extra' in weights:
        extra_temporal_modeling = True

    args.prune = ""

    if 'conv1d' in weights:
        args.crop_fusion_type = "conv1d"
    else:
        args.crop_fusion_type = "avg"

    this_arch = weights.split('TSM_')[1].split('_')[2]
    modality_list.append(modality)
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(
        args.dataset, modality)

    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(
        is_shift, shift_div, shift_place))
    net = TSN(num_class,
              int(args.test_segments) if is_shift else 1,
              modality,
              base_model=this_arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              pretrain=args.pretrain,
              is_shift=is_shift,
              shift_div=shift_div,
              shift_place=shift_place,
              non_local='_nl' in weights,
              concat=concat,
              extra_temporal_modeling=extra_temporal_modeling,
              prune_list=[prune_conv1in_list, prune_conv1out_list],
              is_prune=args.prune)

    if 'tpool' in weights:
        from ops.temporal_shift import make_temporal_pool
        make_temporal_pool(net.base_model,
                           args.test_segments)  # since DataParallel

    checkpoint = torch.load(weights)
    checkpoint = checkpoint['state_dict']

    # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in list(checkpoint.items())
    }
    replace_dict = {
        'base_model.classifier.weight': 'new_fc.weight',
        'base_model.classifier.bias': 'new_fc.bias',
    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)

    input_size = net.scale_size if args.full_res else net.input_size
    if args.test_crops == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif args.test_crops == 3:  # do not flip, so only 3 crops
        cropping = torchvision.transforms.Compose(
            [GroupFullResSample(input_size, net.scale_size, flip=False)])
    elif args.test_crops == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose(
            [GroupOverSample(input_size, net.scale_size, flip=False)])
    elif args.test_crops == 10:
        cropping = torchvision.transforms.Compose(
            [GroupOverSample(input_size, net.scale_size)])
    else:
        raise ValueError(
            "Only 1, 3, 5 or 10 crops are supported, but got {}".format(
                args.test_crops))

    transform = torchvision.transforms.Compose([
        cropping,
        Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(
            div=(this_arch not in ['BNInception', 'InceptionV3'])),
        GroupNormalize(net.input_mean, net.input_std),
    ])

    if args.gpus is not None:
        devices = [args.gpus[i] for i in range(args.workers)]
    else:
        devices = list(range(args.workers))

    net = torch.nn.DataParallel(net.cuda())
    return is_shift, net, transform
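
A plausible call site for this loader; the weights path is hypothetical, but it matches the naming pattern the function parses ('TSM_' prefix, architecture as the third underscore-separated token, and an embedded 'shift8_blockres' tag):

# is_shift=True, shift_div=8, shift_place='blockres', arch='resnet50'
is_shift, net, transform = load_model(
    'checkpoint/TSM_ucf101_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth')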
Example #6
def main():
    global args, best_prec1, TRAIN_SAMPLES
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.test_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    if os.path.exists(args.test_list):
        args.val_list = args.test_list


    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                tin=args.tin)


    crop_size = args.crop_size
    scale_size = args.scale_size
    input_mean = [0.485, 0.456, 0.406]
    input_std = [0.229, 0.224, 0.225]

    print(args.gpus)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if os.path.isfile(args.resume_path):
        print(("=> loading checkpoint '{}'".format(args.resume_path)))
        checkpoint = torch.load(args.resume_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print(("=> loaded checkpoint '{}' (epoch {})"
               .format(args.evaluate, checkpoint['epoch'])))
    else:
        print(("=> no checkpoint found at '{}'".format(args.resume_path)))


    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    if args.random_crops == 1:
       crop_aug = GroupCenterCrop(args.crop_size)
    elif args.random_crops == 3:
       crop_aug = GroupFullResSample(args.crop_size, args.scale_size, flip=False)
    elif args.random_crops == 5:
       crop_aug = GroupOverSample(args.crop_size, args.scale_size, flip=False)
    else:
       crop_aug = MultiGroupRandomCrop(args.crop_size, args.random_crops)  # no trailing comma: it would turn this into a tuple


    test_dataset = TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   multi_class=args.multi_class,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(args.scale_size)),
                       crop_aug,
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample,
                   test_mode=True,
                   temporal_clips=args.temporal_clips)


    test_loader = torch.utils.data.DataLoader(
            test_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    test(test_loader, model, args.start_epoch)
Example #7
File: main.py Project: CV-IP/TDN
def main():
    global args, best_prec1
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    args.store_name = '_'.join([
        'TDN_', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)

    if dist.get_rank() == 0:
        check_rootfolders()

    logger = setup_logger(output=os.path.join(args.root_log, args.store_name),
                          distributed_rank=dist.get_rank(),
                          name='TDN')
    logger.info('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                fc_lr5=(args.tune_from and args.dataset in args.tune_from))

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    for group in policies:
        logger.info(
            ('[TDN-{}]group: {} has {} params, lr_mult: {}, decay_mult: {}'.
             format(args.arch, group['name'], len(group['params']),
                    group['lr_mult'], group['decay_mult'])))

    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset else True)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)

    train_dataset = TSNDataSet(
        args.dataset,
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_dataset = TSNDataSet(
        args.dataset,
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)

    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             drop_last=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    scheduler = get_scheduler(optimizer, len(train_loader), args)

    model = DistributedDataParallel(model.cuda(),
                                    device_ids=[args.local_rank],
                                    broadcast_buffers=True,
                                    find_unused_parameters=True)

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            logger.info(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        logger.info(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                logger.info('=> Load after removing .net: {}'.format(k))
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                logger.info('=> Load after adding .net: {}'.format(k))
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        logger.info(
            '#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            logger.info('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    if args.evaluate:
        logger.info(("===========evaluate==========="))
        val_loader.sampler.set_epoch(args.start_epoch)
        prec1, prec5, val_loss = validate(val_loader, model, criterion, logger)
        if dist.get_rank() == 0:
            is_best = prec1 > best_prec1
            best_prec1 = prec1
            logger.info(("Best Prec@1: '{}'".format(best_prec1)))
            save_epoch = args.start_epoch + 1
            save_checkpoint(
                {
                    'epoch': args.start_epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'prec1': prec1,
                    'best_prec1': best_prec1,
                }, save_epoch, is_best)
        return

    for epoch in range(args.start_epoch, args.epochs):
        train_loader.sampler.set_epoch(epoch)
        train_loss, train_top1, train_top5 = train(train_loader,
                                                   model,
                                                   criterion,
                                                   optimizer,
                                                   epoch=epoch,
                                                   logger=logger,
                                                   scheduler=scheduler)
        if dist.get_rank() == 0:
            tf_writer.add_scalar('loss/train', train_loss, epoch)
            tf_writer.add_scalar('acc/train_top1', train_top1, epoch)
            tf_writer.add_scalar('acc/train_top5', train_top5, epoch)
            tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)

        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            val_loader.sampler.set_epoch(epoch)
            prec1, prec5, val_loss = validate(val_loader, model, criterion,
                                              epoch, logger)
            if dist.get_rank() == 0:
                tf_writer.add_scalar('loss/test', val_loss, epoch)
                tf_writer.add_scalar('acc/test_top1', prec1, epoch)
                tf_writer.add_scalar('acc/test_top5', prec5, epoch)

                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

                logger.info(("Best Prec@1: '{}'".format(best_prec1)))
                tf_writer.flush()
                save_epoch = epoch + 1
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'prec1': prec1,
                        'best_prec1': best_prec1,
                    }, save_epoch, is_best)
Example #8

def main():
    # options
    parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
    parser.add_argument('dataset', type=str)

    # may contain splits
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--test_segments', type=str, default=25)
    parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
    parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
    parser.add_argument('--full_res', default=False, action="store_true",
                        help='use full resolution 256x256 for test as in Non-local I3D')

    parser.add_argument('--test_crops', type=int, default=1)
    parser.add_argument('--coeff', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')

    # for true test
    parser.add_argument('--test_list', type=str, default=None)
    parser.add_argument('--csv_file', type=str, default=None)

    parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')

    parser.add_argument('--max_num', type=int, default=-1)
    parser.add_argument('--input_size', type=int, default=224)
    parser.add_argument('--crop_fusion_type', type=str, default='avg')
    parser.add_argument('--gpus', nargs='+', type=int, default=None)
    parser.add_argument('--img_feature_dim',type=int, default=256)
    parser.add_argument('--num_set_segments',type=int, default=1,help='TODO: select multiply set of n-frames from a video')
    parser.add_argument('--pretrain', type=str, default='imagenet')

    args = parser.parse_args()





    def accuracy(output, target, topk=(1,)):
        """Computes the precision@k for the specified values of k"""
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: the transposed slice may be non-contiguous
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
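    # For instance, with output = torch.tensor([[0.1, 0.9], [0.8, 0.2]]) and
    # target = torch.tensor([1, 0]), accuracy(output, target, topk=(1,))
    # returns [tensor(100.)]: both top-1 predictions match their labels.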


    def parse_shift_option_from_log_name(log_name):
        if 'shift' in log_name:
            strings = log_name.split('_')
            for i, s in enumerate(strings):
                if 'shift' in s:
                    break
            return True, int(strings[i].replace('shift', '')), strings[i + 1]
        else:
            return False, None, None
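    # For example, a log name containing '..._shift8_blockres_...' yields
    # (True, 8, 'blockres'); any name without 'shift' yields
    # (False, None, None).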


    weights_list = args.weights.split(',')
    test_segments_list = [int(s) for s in args.test_segments.split(',')]
    assert len(weights_list) == len(test_segments_list)
    if args.coeff is None:
        coeff_list = [1] * len(weights_list)
    else:
        coeff_list = [float(c) for c in args.coeff.split(',')]

    if args.test_list is not None:
        test_file_list = args.test_list.split(',')
    else:
        test_file_list = [None] * len(weights_list)


    data_iter_list = []
    net_list = []
    modality_list = []

    total_num = None
    for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
        is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
        if 'RGB' in this_weights:
            modality = 'RGB'
        else:
            modality = 'Flow'
        this_arch = this_weights.split('TSM_')[1].split('_')[2]
        modality_list.append(modality)
        num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                modality)
        print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
        net = TSN(num_class, this_test_segments if is_shift else 1, modality,
                base_model=this_arch,
                consensus_type=args.crop_fusion_type,
                img_feature_dim=args.img_feature_dim,
                pretrain=args.pretrain,
                is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
                non_local='_nl' in this_weights,
                )

        if 'tpool' in this_weights:
            from ops.temporal_shift import make_temporal_pool
            make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel

        checkpoint = torch.load(this_weights)
        checkpoint = checkpoint['state_dict']

        # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias',
                        }
        for k, v in replace_dict.items():
            if k in base_dict:
                base_dict[v] = base_dict.pop(k)

        net.load_state_dict(base_dict)

        input_size = net.scale_size if args.full_res else net.input_size
        if args.test_crops == 1:
            cropping = torchvision.transforms.Compose([
                GroupScale(net.scale_size),
                GroupCenterCrop(input_size),
            ])
        elif args.test_crops == 3:  # do not flip, so only 3 crops
            cropping = torchvision.transforms.Compose([
                GroupFullResSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 5:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 10:
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size)
            ])
        else:
            raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops))

        data_loader = torch.utils.data.DataLoader(
                TSNDataSet(root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments,
                        new_length=1 if modality == "RGB" else 5,
                        modality=modality,
                        image_tmpl=prefix,
                        test_mode=True,
                        remove_missing=len(weights_list) == 1,
                        transform=torchvision.transforms.Compose([
                            cropping,
                            Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                            ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                            GroupNormalize(net.input_mean, net.input_std),
                        ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True,
        )

        if args.gpus is not None:
            devices = [args.gpus[i] for i in range(args.workers)]
        else:
            devices = list(range(args.workers))

        net = torch.nn.DataParallel(net.cuda())
        net.eval()

        data_gen = enumerate(data_loader)

        if total_num is None:
            total_num = len(data_loader.dataset)
        else:
            assert total_num == len(data_loader.dataset)

        data_iter_list.append(data_gen)
        net_list.append(net)


    output = []


    def eval_video(video_data, net, this_test_segments, modality):
        net.eval()
        with torch.no_grad():
            i, data, label = video_data
            batch_size = label.numel()
            num_crop = args.test_crops
            if args.dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample

            if args.twice_sample:
                num_crop *= 2

            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality "+ modality)

            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3))
            rst = net(data_in)
            rst = rst.reshape(batch_size, num_crop, -1).mean(1)

            if args.softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)

            rst = rst.data.cpu().numpy().copy()

            if net.module.is_shift:
                rst = rst.reshape(batch_size, num_class)
            else:
                rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

            return i, rst, label


    proc_start_time = time.time()
    max_num = args.max_num if args.max_num > 0 else total_num

    top1 = AverageMeter()
    top5 = AverageMeter()

    for i, data_label_pairs in enumerate(zip(*data_iter_list)):
        with torch.no_grad():
            if i >= max_num:
                break
            this_rst_list = []
            this_label = None
            for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list):
                rst = eval_video((i, data, label), net, n_seg, modality)
                this_rst_list.append(rst[1])
                this_label = label
            assert len(this_rst_list) == len(coeff_list)
            for i_coeff in range(len(this_rst_list)):
                this_rst_list[i_coeff] *= coeff_list[i_coeff]
            ensembled_predict = sum(this_rst_list) / len(this_rst_list)

            for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
                output.append([p[None, ...], g])
            cnt_time = time.time() - proc_start_time
            prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
            top1.update(prec1.item(), this_label.numel())
            top5.update(prec5.item(), this_label.numel())
            if i % 20 == 0:
                print('video {} done, total {}/{}, average {:.3f} sec/video, '
                    'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num,
                                                                float(cnt_time) / (i+1) / args.batch_size, top1.avg, top5.avg))

    video_pred = [np.argmax(x[0]) for x in output]
    video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]

    video_labels = [x[1] for x in output]


    if args.csv_file is not None:
        print('=> Writing result to csv file: {}'.format(args.csv_file))
        with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f:
            categories = f.readlines()
        categories = [f.strip() for f in categories]
        with open(test_file_list[0]) as f:
            vid_names = f.readlines()
        vid_names = [n.split(' ')[0] for n in vid_names]
        print(vid_names)
        assert len(vid_names) == len(video_pred)
        if args.dataset != 'somethingv2':  # only output top1
            with open(args.csv_file, 'w') as f:
                for n, pred in zip(vid_names, video_pred):
                    f.write('{};{}\n'.format(n, categories[pred]))
        else:
            with open(args.csv_file, 'w') as f:
                for n, pred5 in zip(vid_names, video_pred_top5):
                    fill = [n]
                    for p in list(pred5):
                        fill.append(p)
                    f.write('{};{};{};{};{};{}\n'.format(*fill))


    cf = confusion_matrix(video_labels, video_pred).astype(float)

    np.save('cm.npy', cf)
    cls_cnt = cf.sum(axis=1)
    cls_hit = np.diag(cf)

    cls_acc = cls_hit / cls_cnt
    print(cls_acc)
    upper = np.mean(np.max(cf, axis=1) / cls_cnt)
    print('upper bound: {}'.format(upper))

    print('-----Evaluation is finished------')
    print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
    print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))
Example #9
def get_tsm(num_classes=3, pretrain_set='kinetics'):
    if pretrain_set == 'kinetics':
        base_model = "resnet50"
        this_weights = "pretrained/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e100_dense_nl.pth"
        original_num_classes = 400
        non_local = True
        print("Using kinetics")
    else:
        base_model = "resnet101"
        this_weights = "pretrained/TSM_somethingv2_RGB_resnet101_shift8_blockres_avg_segment8_e45.pth"
        # base_model = "resnet50"
        # this_weights = "pretrained/TSM_somethingv2_RGB_resnet50_shift8_blockres_avg_segment8_e45.pth"
        original_num_classes = 174
        non_local = False

    modality = "RGB"

    segments = 8

    consensus_type = "avg"
    img_feature_dim = 256
    pretrain = True
    is_shift = True
    shift_div = 8
    shift_place = "blockres"

    net = TSN(
        original_num_classes,
        segments,
        modality,
        base_model=base_model,
        consensus_type=consensus_type,
        img_feature_dim=img_feature_dim,
        pretrain=pretrain,
        is_shift=is_shift,
        shift_div=shift_div,
        shift_place=shift_place,
        non_local=non_local,
    )

    checkpoint = torch.load(this_weights)
    checkpoint = checkpoint['state_dict']
    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in list(checkpoint.items())
    }
    replace_dict = {
        'base_model.classifier.weight': 'new_fc.weight',
        'base_model.classifier.bias': 'new_fc.bias',
    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)
    #
    # for param in net.parameters():
    #     param.requires_grad = False
    #
    # for param in net.base_model.layer4.parameters():
    #     param.requires_grad = True

    net.new_fc = torch.nn.Linear(2048, num_classes)
    return net
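
A possible smoke test for the returned network. TSN flattens segments into the channel dimension, so an 8-segment RGB clip is fed as (N, 8*3, H, W); the shapes here are assumptions based on the constructor arguments above, not verified output:

net = get_tsm(num_classes=3, pretrain_set='kinetics')
net.eval()
with torch.no_grad():
    out = net(torch.rand(2, 8 * 3, 224, 224))
print(out.shape)  # expected: torch.Size([2, 3])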
Example #10

def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        rank = args.rank  # default, so rank is bound even without multiprocessing
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=rank)
    else:
        rank = 0
    # create model
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset, args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type, 'segment%d' % args.num_segments,
         'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    args.store_name += '_lr{}'.format(args.lr)
    args.store_name += '_wd{:.1e}'.format(args.weight_decay)
    args.store_name += '_do{}'.format(args.dropout)
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders(args, rank)

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)
    
    # first synchronization of initial weights
    # sync_initial_weights(model, args.rank, args.world_size)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if rank == 0:
        print(model)
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have on a node
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(policies, args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            if args.start_epoch == 1:
                args.start_epoch = checkpoint['epoch'] + 1
            best_acc1 = checkpoint['best_acc1']
#             if args.gpu is not None:
#                 # best_acc1 may be from a checkpoint from a different GPU
#                 best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_dataset = TSNDataSet(args.dataset, args.root_path, args.train_list, num_segments=args.num_segments,
                       new_length=data_length,
                       modality=args.modality,
                       image_tmpl=prefix,
                       transform=torchvision.transforms.Compose([train_augmentation,
                           Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                           normalize]), 
                      dense_sample=args.dense_sample)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, drop_last=True, sampler=train_sampler)
    
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.dataset, args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return
    
    log_training = open(os.path.join(args.root_model, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_model, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(log_dir=os.path.join(args.root_model, args.store_name))
    for epoch in range(args.start_epoch, args.epochs+1):
        if args.distributed:
            train_sampler.set_epoch(epoch)
#         adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer, args, rank)
        if rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, False, args, rank)
        
        if epoch % 5 == 0 and rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, False, args, rank, e=epoch)

        # evaluate on validation set
        is_best = False
        if epoch % args.eval_freq == 0 or epoch == args.epochs:
            acc1 = validate(val_loader, model, criterion, epoch, args, rank, log_training, tf_writer)

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed and rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best, args, rank)
Example #11
    has_tam = parse_shift_option_from_log_name(this_weights)
    if 'RGB' in this_weights:
        modality = 'RGB'
    else:
        modality = 'Flow'
    this_arch = this_weights.split('/')[-2].split('_')[2]
    modality_list.append(modality)
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(
        args.dataset, modality)
    print('=> TAM : {}, {} sampling'.format(has_tam, args.sample))
    net = TSN(
        num_class,
        this_test_segments if has_tam else 1,
        modality,
        base_model=this_arch,
        consensus_type=args.crop_fusion_type,
        img_feature_dim=args.img_feature_dim,
        pretrain=args.pretrain,
        tam=True,
        non_local='_nl' in this_weights,
    )

    sf_ckpt = torch.load(this_weights, map_location=torch.device('cpu'))
    sf_weights = sf_ckpt['state_dict']
    tam_ckpt = net.state_dict()
    # print(base_dict.keys())
    # exit()
    base_dict = {}
    for k, v in sf_weights.items():
        if 'self_conv.conv_f' in k:
            k = k.replace('self_conv.conv_f', 'tam.G')
Example #12
def main():
    # settings
    global args, best_prec1
    args = parser.parse_args()
    n_class, args.train_list, args.val_list, args.test_list, prefix = dataset_config.dataset(
    )
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}'.format(args.shift_div)
    args.store_name = '_'.join(
        ['tsm', full_arch_name,
         'segment%d' % args.num_segments])
    print('storing name: ' + args.store_name)
    check_rootfolders(args.root_log, args.root_model, args.store_name)

    # tsn model added temporal shift module
    model = TSN(n_class,
                args.num_segments,
                base_model=args.arch,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn,
                is_shift=args.shift,
                shift_div=args.shift_div)

    # preprocessing for input
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False)

    # optimizer
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # cuda and cudnn
    try:
        model = nn.DataParallel(model).cuda()
    except Exception:  # fall back to a single device if DataParallel fails
        model = model.cuda()
    cudnn.benchmark = True

    # data loader
    normalize = GroupNormalize(input_mean, input_std)
    train_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.train_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=False),
            ToTorchFormatTensor(div=True), normalize
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.val_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=False),
            ToTorchFormatTensor(div=True), normalize
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=False)

    test_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.test_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        random_shift=False,
        test_mode=True,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=False),
            ToTorchFormatTensor(div=True), normalize
        ])),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=False)

    # loss function
    criterion = nn.CrossEntropyLoss().cuda()
    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    # tensorboard
    time_stamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())

    # train
    if args.mode == 'train':
        log_training = open(
            os.path.join(args.root_log, args.store_name, time_stamp,
                         'log.csv'), 'w')
        tf_writer = SummaryWriter(
            '{}/{}/'.format(args.root_log, args.store_name) + time_stamp)
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr_steps, args.lr,
                                 args.weight_decay)
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)

            # evaluate on validation set
            if (epoch + 1) % args.eval_freq == 0:
                prec1 = validate(val_loader, model, criterion, epoch,
                                 log_training, tf_writer)

                # remember best precision and save checkpoint
                is_best = prec1 >= best_prec1
                best_prec1 = max(prec1, best_prec1)
                output_best = 'Best Prec@1: %.2f\n' % (best_prec1)
                print(output_best)
                log_training.write(output_best + '\n')
                log_training.flush()

                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best, args.root_model, args.store_name)
        tf_writer.close()  # close after the epoch loop, not inside it

    # test
    checkpoint = '%s/%s/ckpt.best.pth.tar' % (args.root_model, args.store_name)
    test(test_loader, model, checkpoint, time_stamp)
Example #13
			  extra_temporal_modeling = extra_temporal_modeling,
			  prune_list = [prune_conv1in_list, prune_conv1out_list],
			  is_prune = args.prune,
			  )
	'''
    net = TSN(
        num_class,
        this_test_segments if is_shift else 1,
        modality,
        base_model=this_arch,
        new_length=2 if args.data_fuse else None,
        consensus_type=args.crop_fusion_type,
        #dropout=args.dropout,
        img_feature_dim=args.img_feature_dim,
        #partial_bn=not args.no_partialbn,
        pretrain=args.pretrain,
        is_shift=args.shift,
        shift_div=args.shift_div,
        shift_place=args.shift_place,
        #fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
        #temporal_pool=args.temporal_pool,
        non_local='_nl' in this_weights,
        concat=concat,
        extra_temporal_modeling=extra_temporal_modeling,
        prune_list=[prune_conv1in_list, prune_conv1out_list],
        is_prune=args.prune,
    )
    print(net)
    #print(args.shift)
    #exit()
    if 'tpool' in this_weights:
        from ops.temporal_shift import make_temporal_pool
Example #14
                        thickness)
    return image


if __name__ == '__main__':
    args = parser.parse_args()

    arch = 'resnet50'
    tsn = TSN(len(action_to_idx),
              args.num_segments,
              'RGB',
              base_model=arch,
              consensus_type='avg',
              dropout=0.5,
              img_feature_dim=256,
              partial_bn=False,
              pretrain='imagenet',
              is_shift=True,
              shift_div=8,
              shift_place='blockres',
              fc_lr5=False,
              temporal_pool=False,
              non_local=False).to(args.device)
    model = torch.nn.DataParallel(tsn, device_ids=None).to(args.device)
    sd = torch.load(args.model,
                    map_location=torch.device(args.device))['state_dict']
    model.load_state_dict(sd)
    model.eval()

    meta = pd.DataFrame(columns=['action', 'time_start', 'time_end'])
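
    # Hedged sketch of filling the meta table above; the column names come
    # from the DataFrame definition, the row values are purely hypothetical.
    row = {'action': 'some_action', 'time_start': 0.0, 'time_end': 2.5}
    meta = pd.concat([meta, pd.DataFrame([row])], ignore_index=True)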
Exemplo n.º 15
0
def main():
    global args, best_prec1, least_loss
    least_loss = 1000
    args = parser.parse_args()
    if os.path.exists(os.path.join(args.root_log, "error.log")):
        os.remove(os.path.join(args.root_log, "error.log"))
    logging.basicConfig(
        level=logging.DEBUG,
        filename=os.path.join(args.root_log, "error.log"),
        filemode='a',
        format=
        '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )

    # log_handler = open(os.path.join(args.root_log,"error.log"),"w")
    # sys.stdout = log_handler
    if args.root_path:
        num_class, args.train_list, args.val_list, _, prefix = dataset_config.return_dataset(
            args.dataset, args.modality)
        args.train_list = os.path.join(args.root_log,
                                       "kf1_train_anno_lijun_iod.json")
        args.test_list = os.path.join(args.root_log,
                                      "kf1_test_anno_lijun_iod.json")
    else:
        num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
            args.dataset, args.modality)

    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSA', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    # if args.pretrain != 'imagenet':
    #     args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                is_TSA=args.tsa,
                is_sTSA=args.stsa,
                is_tTSA=args.ttsa,
                shift_diff=args.shift_diff,
                shift_groups=args.shift_groups,
                is_ME=args.me,
                is_3D=args.is_3D,
                cfg_file=args.cfg_file)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()  # required by the SGD branch below
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    if args.optimizer == "sgd":
        if args.lr_scheduler:
            optimizer = torch.optim.SGD(model.parameters(),
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.SGD(policies,
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.optimizer == "adam":
        params = get_vmz_fine_tuning_parameters(model,
                                                args.vmz_tune_last_k_layer)
        optimizer = torch.optim.Adam(params,
                                     args.lr,
                                     weight_decay=args.weight_decay)
    else:
        raise RuntimeError("unsupported optimizer")

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.lr_steps, args.lr_scheduler_gamma)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))

            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # if args.lr_scheduler:
            #     scheduler.load_state_dict(checkpoint["lr_scheduler"])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
            logging.info(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))
            logging.error(
                ("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        sd = {k: v for k, v in sd.items() if k in keys2}
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {
                k: v
                for k, v in sd.items()
                if 'fc' not in k and "projection" not in k
            }
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality in ['RGB', "PoseAction"]:
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    if not args.shuffle:
        train_loader = torch.utils.data.DataLoader(
            TSNDataSet(args.root_path,
                       args.train_list,
                       num_segments=args.num_segments,
                       new_length=data_length,
                       modality=args.modality,
                       image_tmpl=prefix,
                       transform=torchvision.transforms.Compose([
                           train_augmentation,
                           Stack(roll=(args.arch
                                       in ['BNInception', 'InceptionV3']),
                                 inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                           ToTorchFormatTensor(
                               div=(args.arch
                                    not in ['BNInception', 'InceptionV3']),
                               inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                           normalize,
                       ]),
                       dense_sample=args.dense_sample,
                       all_sample=args.all_sample),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)  # drop the last incomplete batch so it splits evenly across GPUs

        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(scale_size),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                ToTorchFormatTensor(div=(args.arch
                                         not in ['BNInception',
                                                 'InceptionV3']),
                                    inc_dim=(args.arch in ["R2plus1D",
                                                           "X3D"])),
                normalize,
            ]),
            dense_sample=args.dense_sample,
            all_sample=args.all_sample),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

    # for group in policies:
    #     print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
    #         group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

    if args.evaluate:
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        elif args.loss_type == "bce":
            criterion = torch.nn.BCEWithLogitsLoss().cuda()
        elif args.loss_type == "wbce":
            class_weight, pos_weight = prep_weight(args.train_list)
            criterion = WeightedBCEWithLogitsLoss(class_weight, pos_weight)
        else:
            raise ValueError("Unknown loss type")

        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(scale_size),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                ToTorchFormatTensor(div=(args.arch
                                         not in ['BNInception',
                                                 'InceptionV3']),
                                    inc_dim=(args.arch in ["R2plus1D",
                                                           "X3D"])),
                normalize,
            ]),
            dense_sample=args.dense_sample,
            all_sample=args.all_sample,
            analysis=True),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        test(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    print(model)
    logging.info(model)
    for epoch in range(args.start_epoch, args.epochs):
        logging.info("Train Epoch {}/{} starts, estimated time 5832s".format(
            str(epoch), str(args.epochs)))
        # update data_loader
        if args.shuffle:
            gen_label(args.prop_path,
                      args.label_path,
                      args.trn_name,
                      args.train_list,
                      args.neg_rate,
                      STR=False)
            gen_label(args.prop_path,
                      args.label_path,
                      args.tst_name,
                      args.val_list,
                      args.test_rate,
                      STR=False)
            train_loader = torch.utils.data.DataLoader(
                TSNDataSet(args.root_path,
                           args.train_list,
                           num_segments=args.num_segments,
                           new_length=data_length,
                           modality=args.modality,
                           image_tmpl=prefix,
                           transform=torchvision.transforms.Compose([
                               train_augmentation,
                               Stack(roll=(args.arch
                                           in ['BNInception', 'InceptionV3']),
                                     inc_dim=(args.arch in ["R2plus1D",
                                                            "X3D"])),
                               ToTorchFormatTensor(
                                   div=(args.arch
                                        not in ['BNInception', 'InceptionV3']),
                                   inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               normalize,
                           ]),
                           dense_sample=args.dense_sample,
                           all_sample=args.all_sample),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)

            val_loader = torch.utils.data.DataLoader(
                TSNDataSet(args.root_path,
                           args.val_list,
                           num_segments=args.num_segments,
                           new_length=data_length,
                           modality=args.modality,
                           image_tmpl=prefix,
                           random_shift=False,
                           transform=torchvision.transforms.Compose([
                               GroupScale(scale_size),
                               GroupCenterCrop(crop_size),
                               Stack(roll=(args.arch
                                           in ['BNInception', 'InceptionV3']),
                                     inc_dim=(args.arch in ["R2plus1D",
                                                            "X3D"])),
                               ToTorchFormatTensor(
                                   div=(args.arch
                                        not in ['BNInception', 'InceptionV3']),
                                   inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               normalize,
                           ]),
                           dense_sample=args.dense_sample,
                           all_sample=args.all_sample),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            print(train_loader)

        # define loss function (criterion) and optimizer
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        elif args.loss_type == "bce":
            criterion = torch.nn.BCEWithLogitsLoss().cuda()
        elif args.loss_type == "wbce":
            class_weight, pos_weight = prep_weight(args.train_list)
            criterion = WeightedBCEWithLogitsLoss(class_weight, pos_weight)
        else:
            raise ValueError("Unknown loss type")

        # train for one epoch
        if not args.lr_scheduler:
            adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)
        else:
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)
            scheduler.step()

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            logging.info(
                "Test Epoch {}/{} starts, estimated time 13874s".format(
                    str(epoch // args.eval_freq),
                    str(args.epochs // args.eval_freq)))
            if args.loss_type == "wbce":
                # class_weight,pos_weight = prep_weight(args.val_list)
                criterion = torch.nn.BCEWithLogitsLoss().cuda()
            lossm = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = lossm < least_loss
            least_loss = min(lossm, least_loss)
            tf_writer.add_scalar('lss/test_top1_best', least_loss, epoch)

            output_best = 'Best Loss: %.3f\n' % (least_loss)
            logging.info(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()
            if args.lr_scheduler:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': least_loss,
                        'lr_scheduler': scheduler.state_dict(),
                    }, is_best, epoch)
            else:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': least_loss,
                    }, is_best, epoch)
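
# WeightedBCEWithLogitsLoss and prep_weight() above are project-local
# helpers. A rough stand-in using PyTorch's built-in pos_weight argument
# (an assumption for illustration, not the repository's implementation):
pos_weight = torch.tensor([1.0, 4.0])  # hypothetical per-class weights
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.randn(8, 2)
targets = torch.randint(0, 2, (8, 2)).float()
loss = criterion(logits, targets)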
Exemplo n.º 16
0
        for i, s in enumerate(strings):
            if 'shift' in s:
                break
        return True, int(strings[i].replace('shift', '')), strings[i + 1]
    else:
        return False, None, None
        
is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
print(is_shift, shift_div, shift_place)


with torch.cuda.device(0):
    net = TSN(2, 1, 'RGB',
              base_model=this_arch,
              consensus_type='avg',
              img_feature_dim='225',
              #pretrain=args.pretrain,
              is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
              non_local='_nl' in this_weights,
              )
    macs, params = get_model_complexity_info(net, (1, 3, 224, 224),
                                             as_strings=True,
                                             print_per_layer_stat=False,
                                             verbose=False)
    print("Using ptflops")
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
    

from thop import profile
model = net = TSN(2, 1, 'RGB',
                  base_model=this_arch,
                  consensus_type='avg',
                  img_feature_dim='225',
                  #pretrain=args.pretrain,
                  is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
                  non_local='_nl' in this_weights,
                  )
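
# Hedged completion of the thop profiling this truncated snippet was building
# towards, mirroring the ptflops printout above; the (1, 3, 224, 224) input
# (one single-segment RGB clip) is an assumption.
dummy_input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(dummy_input, ))
print("Using thop")
print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))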
Exemplo n.º 17
0
def doInferecing(cap, args, GPU_FLAG):

    # switch between archs based on selected arch
    if args.get("arch") == "mobilenetv2":
        this_weights = "checkpoint/TSM_ucfcrime_RGB_mobilenetv2_shift8_blockres_avg_segment8_e25/ckpt.best.pth.tar"
    else:
        this_weights = "checkpoint/TSM_ucfcrime_RGB_resnet50_shift8_blockres_avg_segment8_e25/ckpt.best.pth.tar"

    is_shift, shift_div, shift_place = parse_shift_option_from_log_name(
        this_weights)

    modality = "RGB"

    if "RGB" in this_weights:
        modality = "RGB"

    # Get dataset categories.
    categories = ["Normal Activity", "Abnormal Activity"]
    num_class = len(categories)
    this_arch = args.get("arch")

    print("[INFO] >> Model loading weights from disk!!")

    net = TSN(
        num_class,
        1,
        modality,
        base_model=this_arch,
        consensus_type="avg",
        img_feature_dim="225",
        # pretrain=args.pretrain,
        is_shift=is_shift,
        shift_div=shift_div,
        shift_place=shift_place,
        non_local="_nl" in this_weights,
    )

    # GPU_FLAG decides whether the weights are loaded onto GPU or CPU
    if GPU_FLAG == "y":
        checkpoint = torch.load(this_weights)
    else:
        checkpoint = torch.load(this_weights, map_location=torch.device("cpu"))

    checkpoint = checkpoint["state_dict"]

    base_dict = {
        ".".join(k.split(".")[1:]): v
        for k, v in list(checkpoint.items())
    }
    replace_dict = {
        "base_model.classifier.weight": "new_fc.weight",
        "base_model.classifier.bias": "new_fc.bias",
    }

    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)
    net.load_state_dict(base_dict)

    print("\n[INFO] >> Model loading Successfull")

    if GPU_FLAG == "y":
        net.cuda().eval()
        skip_frames = 2
        summary(net, (1, 3, 224, 224))
    else:
        net.eval()
        skip_frames = 4

    transform = torchvision.transforms.Compose([
        Stack(roll=(this_arch in ["BNInception", "InceptionV3"])),
        ToTorchFormatTensor(
            div=(this_arch not in ["BNInception", "InceptionV3"])),
        GroupNormalize(net.input_mean, net.input_std),
    ])
    WINDOW_NAME = "Real-Time Video Action Recognition"
    # set a lower resolution for speed up
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 320)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 240)

    # env variables
    full_screen = False
    cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(WINDOW_NAME, 640, 480)
    cv2.moveWindow(WINDOW_NAME, 0, 0)
    cv2.setWindowTitle(WINDOW_NAME, WINDOW_NAME)

    t = None
    i_frame = -1
    count = 0
    imageName = 0
    # variable to hold writer object
    writer = None
    c = 0

    print("Ready!")

    while cap.isOpened():
        i_frame += 1
        hasFrame, img = cap.read()  # (480, 640, 3) 0 ~ 255

        if hasFrame:
            img_tran = transform([Image.fromarray(img).convert("RGB")])
            if i_frame % skip_frames == 0:  # process every skip_frames-th frame to keep a usable rate
                t1 = time.time()

                if GPU_FLAG == "y":
                    input1 = (img_tran.view(
                        -1, 3, img_tran.size(1),
                        img_tran.size(2)).unsqueeze(0).cuda())
                else:
                    input1 = img_tran.view(-1, 3, img_tran.size(1),
                                           img_tran.size(2)).unsqueeze(0)

                input = input1

                with torch.no_grad():
                    logits = net(input)
                    h_x = torch.mean(F.softmax(logits, 1), dim=0).data
                    print(
                        "<<< [INFO] >>> PROB  - | Normal: {:.2f}".format(
                            h_x[0]),
                        "| Abnormal: {:.2f} |".format(h_x[1]),
                        "Frames Rendered-",
                        count,
                    )
                    pr, li = h_x.sort(0, True)
                    probs = pr.tolist()
                    idx = li.tolist()
                    # print(probs)
                    t2 = time.time()

                print(
                    "<<< [INFO] >>>",
                    "EVENT - |",
                    categories[idx[0]],
                    "  Prob: {:.2f}| ".format(probs[0]),
                    "\n",
                )
                current_time = t2 - t1

            img = cv2.resize(img, (640, 480))
            height, width, _ = img.shape

            if categories[idx[0]] == "Abnormal Activity":
                R = 255
                G = 0
                Abnormality = True
                tempThres = probs[0]
                c += 1
                maxAbnormalProb.append(float(probs[0]))

            else:
                R = 0
                G = 255
                Abnormality = False

            cv2.putText(
                img,
                "EVENT: " + categories[idx[0]],
                (20, int(height / 16)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, int(G), int(R)),
                2,
            )

            cv2.putText(
                img,
                "Confidence: {0:.2f}%".format(probs[0] * 100, "%"),
                (20, int(height - 420)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, int(G), int(R)),
                2,
            )

            fps = 1 / current_time
            # if args.get('f',True):
            FpsList.append(float(fps))
            maxFps = max(FpsList)
            estFps = sum(FpsList) / len(FpsList)
            # else:
            # maxFps=-1
            # estFps=-1
            cv2.putText(
                img,
                "FPS: {0:.1f}".format(fps),
                (width - 150, int(height / 16)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (0, 255, 255),
                2,
            )

            if writer is None:
                fourcc = cv2.VideoWriter_fourcc(*"MJPG")
                (H, W) = img.shape[:2]
                path = "./appData/Anoamly_Clips/"
                name = len(glob.glob(path + "*.avi"))

                getVidName = path + "Abnormal_Event_{0}.avi".format(name + 1)
                writer = cv2.VideoWriter(getVidName, fourcc, 30.0, (W, H),
                                         True)

            # Saving anomalous event image and clip
            if Abnormality:
                writer.write(img)
                # record stats roughly every two seconds (c counts abnormal frames at ~30 fps)
                if c % 60 == 0:
                    getStatsOfAbnormalActivity()
                    # if tempThres > 0.75:

                    path = "./appData/Anoamly_Images/"
                    index = len(glob.glob(path + "*.jpg"))
                    # imageName = getFileName(path+'.jpg')
                    imageName = path + "Abnormal_Event_{0}.jpg".format(index +
                                                                       1)
                    cv2.imwrite(imageName, img)

            cv2.imshow(WINDOW_NAME, img)

            key = cv2.waitKey(1)

            if key & 0xFF == ord("q") or key == 27:  # exit
                break
            elif key == ord("F") or key == ord("f"):  # full screen
                print("Changing full screen option!")
                full_screen = not full_screen
                if full_screen:
                    print("Setting FS!!!")
                    cv2.setWindowProperty(WINDOW_NAME, cv2.WND_PROP_FULLSCREEN,
                                          cv2.WINDOW_FULLSCREEN)
                else:
                    cv2.setWindowProperty(WINDOW_NAME, cv2.WND_PROP_FULLSCREEN,
                                          cv2.WINDOW_NORMAL)
            # resetting time for next frame
            if t is None:
                t = time.time()
            else:
                nt = time.time()
                count += 1
                t = nt
        else:
            # Uncomment the two lines below to loop the video indefinitely,
            # and comment out cap.release() and writer.release()

            # i_frame = 0
            # cap.set(cv2.CAP_PROP_POS_FRAMES,0)
            cap.release()
            writer.release()
            cv2.destroyAllWindows()

            # Clearing Variables for re-running
            # estFps=None
            # maxAbnormalProb.clear()
            # maxFps=None
    # Calculating total execution time
    execTime = time.time() - startime
    print()

    # Display Results
    print("<<< [INFO] >>> Total Abnormal Probs : ", len(maxAbnormalProb))
    print("<<< [INFO] >>> Max Abnormality Prob : {:.2f}".format(
        max(maxAbnormalProb)))
    print("<<< [INFO] >>> Avg Abnormality Prob : {:.2f}".format(
        sum(maxAbnormalProb) / len(maxAbnormalProb)))
    print("<<< [INFO] >>> Max FPS achieved     : {:.1f}".format(maxFps))
    print("<<< [INFO] >>> Averge Estimated FPS : {:.1f}".format(estFps))
    print("<<< [INFO] >>> Total Infernece Time : {:.2f} seconds".format(
        execTime))
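
# Hedged mini-example of the decoding used above: average the softmax scores
# over the batch, then sort to obtain the top class and its confidence
# (the two-category setting of this example is assumed; torch/F as imported
# by the snippet).
categories = ["Normal Activity", "Abnormal Activity"]
logits = torch.randn(1, 2)
h_x = torch.mean(F.softmax(logits, 1), dim=0).data
pr, li = h_x.sort(0, True)
top_label = categories[li.tolist()[0]]
top_prob = pr.tolist()[0]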
Exemplo n.º 18
0
def main():
    global args, best_prec1
    global crop_size
    args = parser.parse_args()

    num_class, train_list, val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    num_class = 1
    if args.train_list == "":
        args.train_list = train_list
    if args.val_list == "":
        args.val_list = val_list

    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.concat != "":
        full_arch_name += '_concat{}'.format(args.concat)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'lr%.5f' % args.lr,
        'dropout%.2f' % args.dropout,
        'wd%.5f' % args.weight_decay,
        'batch%d' % args.batch_size,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.data_fuse:
        args.store_name += '_fuse'
    if args.extra_temporal_modeling:
        args.store_name += '_extra'
    if args.tune_from is not None:
        args.store_name += '_finetune'
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.clipnums:
        #pass
        args.store_name += "_clip{}".format(args.clipnums)
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    if args.prune in ['input', 'inout'] and args.tune_from:
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        sd = input_dim_L2distance(sd, args.shift_div)

    model = TSN(
        num_class,
        args.num_segments,
        args.modality,
        base_model=args.arch,
        new_length=2 if args.data_fuse else None,
        consensus_type=args.consensus_type,
        dropout=args.dropout,
        img_feature_dim=args.img_feature_dim,
        partial_bn=not args.no_partialbn,
        pretrain=args.pretrain,
        is_shift=args.shift,
        shift_div=args.shift_div,
        shift_place=args.shift_place,
        fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
        temporal_pool=args.temporal_pool,
        non_local=args.non_local,
        concat=args.concat,
        extra_temporal_modeling=args.extra_temporal_modeling,
        prune_list=[prune_conv1in_list, prune_conv1out_list],
        is_prune=args.prune,
    )

    #model = torch.load("/home/ubuntu/backup_kevin/myownTSM_git/checkpoint/TSM_youcook_RGB_resnet50_shift8_blockres_concatAll_conv1d_lr0.00025_dropout0.70_wd0.00050_batch16_segment8_e20_finetune_slice_v1_clipnum500/ckpt_"+str(1)+".pth.tar")
    print(model)
    #summary(model, torch.zeros((16, 24, 224, 224)))
    #exit(1)
    if args.dataset == 'ucf101':  #twice sample & full resolution
        twice_sample = True
        crop_size = model.scale_size  #256 x 256
    else:
        twice_sample = False
        crop_size = model.crop_size  #224 x 224
    crop_size = 256
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies(args.concat)
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        or 'nvgesture' in args.dataset else True)

    model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        tune_from_list = args.tune_from.split(',')
        sd = torch.load(tune_from_list[0])
        sd = sd['state_dict']

        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.prune', '') in sd:
                print('=> Load after adding .prune: ', k)
                replace_dict.append((k.replace('.prune', ''), k))

        if args.prune in ['input', 'inout']:
            sd = adjust_para_shape_prunein(sd, model_dict)
        if args.prune in ['output', 'inout']:
            sd = adjust_para_shape_pruneout(sd, model_dict)

        if args.concat != "" and "concat" not in tune_from_list[0]:
            sd = adjust_para_shape_concat(sd, model_dict)

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in tune_from_list[0]:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality != 'Flow' and 'Flow' in tune_from_list[0]:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        #print(sd.keys())
        #print("*"*50)
        #print(model_dict.keys())
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality in ['RGB']:
        data_length = 1
    elif args.modality in ['Depth']:
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    '''
	dataRoot = r"/home/share/YouCook/downloadVideo"
	for dirPath, dirnames, filenames in os.walk(dataRoot):
		for filename in filenames:
			print(os.path.join(dirPath,filename) +"is {}".format("exist" if os.path.isfile(os.path.join(dirPath,filename))else "NON"))
			train_data = torchvision.io.read_video(os.path.join(dirPath,filename),start_pts=0,end_pts=1001, )
			tmp = torchvision.io.read_video_timestamps(os.path.join(dirPath,filename),)
			print(tmp)
			print(len(tmp[0]))
			print(train_data[0].size())
			exit()
	exit()
	'''
    '''
	train_loader = torch.utils.data.DataLoader(
		TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
				   new_length=data_length,
				   modality=args.modality,
				   image_tmpl=prefix,
				   transform=torchvision.transforms.Compose([
					   train_augmentation,
					   Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
					   ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
					   normalize,
				   ]), dense_sample=args.dense_sample, data_fuse = args.data_fuse),
		batch_size=args.batch_size, shuffle=True,
		num_workers=args.workers, pin_memory=True,
		drop_last=True)  # prevent something not % n_GPU
	

	
	val_loader = torch.utils.data.DataLoader(
		TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
				   new_length=data_length,
				   modality=args.modality,
				   image_tmpl=prefix,
				   random_shift=False,
				   transform=torchvision.transforms.Compose([
					   GroupScale(int(scale_size)),
					   GroupCenterCrop(crop_size),
					   Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
					   ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
					   normalize,
				   ]), dense_sample=args.dense_sample, twice_sample=twice_sample, data_fuse = args.data_fuse),
		batch_size=args.batch_size, shuffle=False,
		num_workers=args.workers, pin_memory=True)
	'''
    #global trainDataloader, valDataloader, train_loader, val_loader
    trainDataloader = YouCookDataSetRcg(args.root_path, args.train_list,
                                        train=True, inputsize=crop_size,
                                        hasPreprocess=False,
                                        clipnums=args.clipnums,
                                        hasWordIndex=True)
    valDataloader = YouCookDataSetRcg(args.root_path, args.val_list,
                                      val=True, inputsize=crop_size,
                                      hasPreprocess=False,
                                      clipnums=args.clipnums,
                                      hasWordIndex=True)

    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #exit()
    train_loader = torch.utils.data.DataLoader(trainDataloader,
                                               #shuffle=True,
                                               )
    val_loader = torch.utils.data.DataLoader(valDataloader)
    #print(train_loader is val_loader)
    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())

    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #print(len(train_loader))
    #exit()
    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    elif args.loss_type == "MSELoss":
        criterion = torch.nn.MSELoss().cuda()
    elif args.loss_type == "BCELoss":
        #print("BCELoss")
        criterion = torch.nn.BCELoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    #print(os.path.join(args.root_log, args.store_name, 'args.txt'))
    #exit()
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
        #print("265")
        # train for one epoch
        ######
        #print(trainDataloader._getMode())
        #print(valDataloader._getMode())
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer)
        ######
        #print("268")
        # evaluate on validation set
        #model = model.load_state_dict(torch.load("/home/ubuntu/backup_kevin/myownTSM_git/checkpoint/TSM_youcook_RGB_resnet50_shift8_blockres_concatAll_conv1d_lr0.00025_dropout0.70_wd0.00050_batch16_segment8_e20_finetune_slice_v1_clipnum500/ckpt_"+str(epoch+1)+".pth.tar"))
        #if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
        if False:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            #print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
        else:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, False)
        #break
        print("test pass")
Exemplo n.º 19
0
        modality = 'RGB'
    else:
        modality = 'Flow'
    this_arch = this_weights.split('TSM_')[1].split('_')[2]
    modality_list.append(modality)
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(
        args.dataset, modality)
    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(
        is_shift, shift_div, shift_place))
    net = TSN(
        num_class,
        this_test_segments if is_shift else 1,
        modality,
        base_model=this_arch,
        consensus_type=args.crop_fusion_type,
        img_feature_dim=args.img_feature_dim,
        pretrain=args.pretrain,
        is_shift=is_shift,
        shift_div=shift_div,
        shift_place=shift_place,
        non_local='_nl' in this_weights,
    )

    if 'tpool' in this_weights:
        from ops.temporal_shift import make_temporal_pool
        make_temporal_pool(net.base_model,
                           this_test_segments)  # since DataParallel

    checkpoint = torch.load(this_weights)
    checkpoint = checkpoint['state_dict']
Exemplo n.º 20
0
total_num = None
for this_weights, this_test_segments, test_file in zip(weights_list,
                                                       test_segments_list,
                                                       test_file_list):
    has_tam, modality, backbone = parse_shift_option_from_log_name(
        this_weights)

    modality_list.append(modality)
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(
        args.dataset, modality)
    print('=> TAM : {}, {} dense'.format(has_tam, args.sample))
    net = TSN(
        num_class,
        this_test_segments if has_tam else 1,
        modality,
        base_model=backbone,
        consensus_type=args.crop_fusion_type,
        img_feature_dim=args.img_feature_dim,
        tam=has_tam,
        non_local='_nl' in this_weights,
    )

    checkpoint = torch.load(this_weights, map_location='cpu')
    checkpoint = checkpoint['state_dict']
    base_dict = {}
    for k, v in list(checkpoint.items()):
        if k.startswith('module'):
            base_dict['.'.join(k.split('.')[1:])] = v
        else:
            base_dict[k] = v

    net.load_state_dict(base_dict)
Exemplo n.º 21
0
def main():
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                      args.modality)
    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                cca3d=args.cca3d
                )

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = model.cuda()

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            parallel_state_dict = checkpoint['state_dict']
            cpu_state_dict = {}
            for k, v in parallel_state_dict.items():
                cpu_state_dict[k[len('module.'):]] = v  # strip DataParallel prefix
            model.load_state_dict(cpu_state_dict)
            print(("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    preprocess = torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ])
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=preprocess,
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)
    norm_param = (input_mean, input_std)
    cam_process(val_loader, model, norm_param)
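
    # Hedged usage sketch: the Compose above consumes a list of PIL images
    # (one per segment) and yields a (num_segments * 3, crop, crop) tensor;
    # the dummy frame size below is an assumption for illustration only.
    from PIL import Image
    dummy_frames = [Image.new('RGB', (340, 256))] * args.num_segments
    clip_tensor = preprocess(dummy_frames)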
Exemplo n.º 22
0
            rst = rst.reshape(
                (batch_size, -1, num_class)).mean(axis=1).reshape(
                    (batch_size, num_class))

        return rst


num_class, args.train_list, val_list, prefix = dataset_config.return_dataset(
    args.dataset, args.modality)

net = TSN(
    num_class,
    args.test_segments,
    args.modality,
    base_model=args.arch,
    consensus_type=args.crop_fusion_type,
    img_feature_dim=args.img_feature_dim,
    pretrain=args.pretrain,
    is_shift=args.shift,
    shift_div=args.shift_div,
    shift_place=args.shift_place,
)

# import pdb; pdb.set_trace()
'''
checkpoint = torch.load(args.weight)
checkpoint = checkpoint['state_dict']

base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
replace_dict = {
    'base_model.classifier.weight': 'new_fc.weight',
    'base_model.classifier.bias': 'new_fc.bias',
Exemplo n.º 23
0
def main():
    global args, best_prec1
    args = parser.parse_args()

    #num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
    #                                                                                                  args.modality)
    num_class = 21
    args.train_list = "/home/jzwang/code/Video_3D/movienet/data/movie/movie_train.txt"
    args.val_list = "/home/jzwang/code/Video_3D/movienet/data/movie/movie_test.txt"
    args.root_path = ""
    prefix = "frame_{:04d}.jpg"
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type, 'segment%d' % args.num_segments,
         'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    #check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            #best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSetMovie("", args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="frame_{:04d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSetMovie("", args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="frame_{:04d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=int(args.batch_size/2), shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.BCEWithLogitsLoss().cuda()
    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

    zero_time = time.time()
    best_map = 0
    print('Start training...')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        start_time = time.time()
        trainloss = train(train_loader, model, criterion, optimizer, epoch)
        print('Training loss %.4f Epoch %d' % (trainloss, epoch))

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            valloss, mAP, wAP, output_mtx = validate(val_loader, model, criterion)
            end_time = time.time()
            epoch_time = end_time - start_time
            total_time = end_time - zero_time
            print('Total time used: %s Epoch %d time used: %s' % (
                str(datetime.timedelta(seconds=int(total_time))),
                epoch, str(datetime.timedelta(seconds=int(epoch_time)))))
            print('Train loss: {0:.4f} val loss: {1:.4f} mAP: {2:.4f} wAP: {3:.4f}'.format(
                trainloss, valloss, mAP, wAP))

            # remember best mAP and save checkpoint
            is_best = mAP > best_map
            best_map = max(mAP, best_map)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, is_best, epoch)
            np.save("testnew.npy", output_mtx)
            print("saved validation outputs to testnew.npy")

            with open(args.record_path, 'a') as file:
                file.write('Epoch: [{0}] '
                           'Train loss: {1:.4f} val loss: {2:.4f} mAP: {3:.4f}\n'.format(
                               epoch + 1, trainloss, valloss, mAP))

    print('************ Done!... ************')
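
Unlike the cross-entropy training loops later in this document, the script above optimizes a multi-label objective (hence the mAP/wAP metrics). A minimal, self-contained sketch of the input layout BCEWithLogitsLoss expects; the batch and class sizes here are illustrative only:

import torch

criterion = torch.nn.BCEWithLogitsLoss()
logits = torch.randn(4, 10)        # raw scores: batch of 4 clips, 10 classes
targets = torch.zeros(4, 10)       # float multi-hot targets, same shape
targets[0, [1, 3]] = 1.0           # a clip may carry several labels at once
loss = criterion(logits, targets)
print('loss: %.4f' % loss.item())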
Exemplo n.º 24
0
data_iter_list = []
net_list = []
modality_list = args.modalities.split(',')
arch_list = args.archs.split(',')

total_num = None
for this_weights, this_test_segments, test_file, modality, this_arch in zip(
        weights_list, test_segments_list, test_file_list, modality_list,
        arch_list):
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(
        args.dataset, modality)
    net = TSN(num_class,
              this_test_segments,
              modality,
              base_model=this_arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              pretrain=args.pretrain)
    checkpoint = torch.load(this_weights)
    try:
        net.load_state_dict(checkpoint['state_dict'])
    except Exception:
        checkpoint = checkpoint['state_dict']

        base_dict = {
            '.'.join(k.split('.')[1:]): v
            for k, v in list(checkpoint.items())
        }
        replace_dict = {
            'base_model.classifier.weight': 'new_fc.weight',
            'base_model.classifier.bias': 'new_fc.bias',
        }
        for k, v in replace_dict.items():
            if k in base_dict:
                base_dict[v] = base_dict.pop(k)
        net.load_state_dict(base_dict)
Exemplo n.º 25
0
    def __init__(self,
                 checkpoint_file,
                 num_classes,
                 max_length=8,
                 trim_net=False,
                 checkpoint_is_model=False,
                 bottleneck_size=128):
        self.is_shift = None
        self.net = None
        self.arch = None
        self.num_classes = num_classes
        self.max_length = max_length
        self.bottleneck_size = bottleneck_size
        #self.feature_idx = feature_idx

        self.transform = None

        self.CNN_FEATURE_COUNT = [256, 512, 1024, 2048]

        # input variables
        this_test_segments = self.max_length
        test_file = None

        #model variables
        self.is_shift, shift_div, shift_place = True, 8, 'blockres'

        self.arch = 'resnet101'
        modality = 'RGB'

        # dataset variables
        num_class, train_list, val_list, root_path, prefix = dataset_config.return_dataset(
            'somethingv2', modality)
        print('=> shift: {}, shift_div: {}, shift_place: {}'.format(
            self.is_shift, shift_div, shift_place))

        # define model
        net = TSN(
            num_class,
            this_test_segments if self.is_shift else 1,
            modality,
            base_model=self.arch,
            consensus_type='avg',
            img_feature_dim=256,
            pretrain='imagenet',
            is_shift=self.is_shift,
            shift_div=shift_div,
            shift_place=shift_place,
            non_local='_nl' in checkpoint_file,
        )
        '''
        The checkpoint file may be an entire TSMBackBone object; this needs to
        be handled accordingly: either find a way to convert it back to a
        weights file or manipulate it to work with the system.
        '''

        # load checkpoint file
        checkpoint = torch.load(checkpoint_file)
        '''
        #include
        print("self.bottleneck_size:", self.bottleneck_size, type(self.bottleneck_size))
        net.base_model.avgpool = nn.Sequential(
            nn.Conv2d(2048, self.bottleneck_size, (1,1)),
            nn.ReLU(inplace=True),
            #nn.AdaptiveAvgPool2d(output_size=1)
        )

        if(not trim_net):
            print("no trim")
            net.new_fc = nn.Linear(self.bottleneck_size, 174)
        else:
            print("trim")
            net.consensus = nn.Identity()
            net.new_fc = nn.Identity()

        net.base_model.fc = nn.Identity() # sets the dropout value to None
        print(net) 
        
        # Combine the network together so that it can have parameters set
        # correctly. I'm not 100% sure what this code section actually does
        # and I don't have the time to figure it out right now.
        #print("checkpoint------------------------")
        #print(checkpoint)
        '''
        if checkpoint_is_model:
            checkpoint = checkpoint.net.state_dict()
        else:
            checkpoint = checkpoint['state_dict']

        base_dict = {
            '.'.join(k.split('.')[1:]): v
            for k, v in list(checkpoint.items())
        }
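        # '.'.join(k.split('.')[1:]) drops the leading 'module.' component
        # that nn.DataParallel prepends to every parameter name when a
        # wrapped model is saved, restoring the names the bare TSN expects.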
        '''
        #include
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias',
                        }
        for k, v in replace_dict.items():
            if v in base_dict:
                base_dict.pop(v)
            if k in base_dict:
                base_dict.pop(k)
                #base_dict[v] = base_dict.pop(k)
        '''
        net.load_state_dict(base_dict, strict=False)

        # define image modifications
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Compose([
                GroupScale(net.scale_size),
                GroupCenterCrop(net.scale_size),
            ]),
            #torchvision.transforms.Compose([ GroupFullResSample(net.scale_size, net.scale_size, flip=False) ]),
            Stack(roll=(self.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(self.arch not in ['BNInception', 'InceptionV3'])),
            GroupNormalize(net.input_mean, net.input_std),
        ])

        # place net onto GPU and finalize network
        self.model = net
        net = torch.nn.DataParallel(net.cuda())
        net.eval()

        # network variable
        self.net = net

        # loss variable (used for generating gradients when ranking)
        if not trim_net:
            self.loss = torch.nn.CrossEntropyLoss().cuda()
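
The constructor above accepts either a pickled backbone object (checkpoint_is_model=True) or an ordinary training checkpoint. A minimal sketch of that branching; the file name is hypothetical and the .net attribute mirrors the wrapper convention used above:

import torch

# 'tsm_checkpoint.pth.tar' is a placeholder path for this sketch.
checkpoint = torch.load('tsm_checkpoint.pth.tar', map_location='cpu')
if hasattr(checkpoint, 'net'):                 # whole pickled wrapper object
    state_dict = checkpoint.net.state_dict()
else:                                          # plain {'state_dict': ...} dict
    state_dict = checkpoint['state_dict']
print('%d tensors recovered' % len(state_dict))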
Exemplo n.º 26
0
def main():
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)
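    # (flipping stays disabled for Something-Something / Jester: many of
    # their labels are direction-sensitive, e.g. pushing something left
    # vs. right)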

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # drop the last incomplete batch so it splits evenly across GPUs

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
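
The tune_from branch above reconciles parameter names between shift (TSM) and non-shift checkpoints by inserting or removing the '.net' wrapper segment. A toy, self-contained illustration of that renaming with one hypothetical key:

import torch

sd = {'base_model.layer1.0.conv1.net.weight': torch.zeros(1)}  # checkpoint side
model_keys = {'base_model.layer1.0.conv1.weight'}              # model side
replace = [(k, k.replace('.net', '')) for k in list(sd)
           if k not in model_keys and k.replace('.net', '') in model_keys]
for k, k_new in replace:
    sd[k_new] = sd.pop(k)
print(list(sd))  # ['base_model.layer1.0.conv1.weight']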
Exemplo n.º 27
0
def main():
    global args, best_prec1
    args = parser.parse_args()
    num_class = opts.num_class

    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.non_local > 0:
        args.store_name += '_nl'
    print('storing name: ' + args.store_name)

    # check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_size=args.img_size,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=True,
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    model.apply(weights_init)

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    fr_r = open(opts.NUM_LABEL_R, 'r+')
    w2n = eval(fr_r.read())
    fr_r.close()

    fr = open(opts.NUM_LABEL, 'r+')
    n2w = eval(fr.read())
    fr.close()
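    # NOTE: eval() on file contents executes arbitrary code; ast.literal_eval
    # would parse the same {word: index} dict literals safely.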

    train_loader, val_loader, test_loader = None, None, None
    if args.mode != 'test':
        lip_dict, video_list = opts.file_deal(opts.TRAIN_DATA, w2n)
        train_num = int(len(video_list) * 0.95)
        train_loader = torch.utils.data.DataLoader(TSNDataSet(
            opts.TRAIN_DATA,
            args.mode,
            num_segments=args.num_segments,
            img_size=args.img_size,
            lip_dict=lip_dict,
            video_list=video_list[:train_num],
            transform=torchvision.transforms.Compose([
                train_augmentation,
                GroupMultiScaleCrop(args.img_size, [1, .875, .75, .66]),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            opts.TRAIN_DATA,
            args.mode,
            num_segments=args.num_segments,
            img_size=args.img_size,
            lip_dict=lip_dict,
            video_list=video_list[train_num:],
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
    else:
        lip_dict, video_list = opts.file_deal(opts.TEST_DATA, w2n)
        test_loader = torch.utils.data.DataLoader(TSNDataSet_infer(
            opts.TEST_DATA,
            num_segments=args.num_segments,
            img_size=args.img_size,
            lip_dict=lip_dict,
            video_list=video_list,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.mode == 'test':
        if args.sub == 'sub':
            inference(test_loader, model, n2w)
        else:
            inferencefusion(test_loader, model, n2w)
        return

    # start training
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, n2w)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)