Example #1
def get_data_loaders(model, prefix, args):
    train_transform_flip = torchvision.transforms.Compose([
        model.module.get_augmentation(flip=True),
        Stack(roll=("BNInc" in args.arch)),
        ToTorchFormatTensor(div=("BNInc" not in args.arch)),
        GroupNormalize(model.module.input_mean, model.module.input_std),
    ])

    train_transform_nofl = torchvision.transforms.Compose([
        model.module.get_augmentation(flip=False),
        Stack(roll=("BNInc" in args.arch)),
        ToTorchFormatTensor(div=("BNInc" not in args.arch)),
        GroupNormalize(model.module.input_mean, model.module.input_std),
    ])

    val_transform = torchvision.transforms.Compose([
        GroupScale(int(model.module.scale_size)),
        GroupCenterCrop(model.module.crop_size),
        Stack(roll=("BNInc" in args.arch)),
        ToTorchFormatTensor(div=("BNInc" not in args.arch)),
        GroupNormalize(model.module.input_mean, model.module.input_std),
    ])

    train_dataset = TSNDataSet(args.root_path,
                               args.train_list,
                               num_segments=args.num_segments,
                               image_tmpl=prefix,
                               transform=(train_transform_flip,
                                          train_transform_nofl),
                               dense_sample=args.dense_sample,
                               dataset=args.dataset,
                               filelist_suffix=args.filelist_suffix,
                               folder_suffix=args.folder_suffix,
                               save_meta=args.save_meta,
                               always_flip=args.always_flip,
                               conditional_flip=args.conditional_flip,
                               adaptive_flip=args.adaptive_flip)

    val_dataset = TSNDataSet(args.root_path,
                             args.val_list,
                             num_segments=args.num_segments,
                             image_tmpl=prefix,
                             random_shift=False,
                             transform=(val_transform, val_transform),
                             dense_sample=args.dense_sample,
                             dataset=args.dataset,
                             filelist_suffix=args.filelist_suffix,
                             folder_suffix=args.folder_suffix,
                             save_meta=args.save_meta)

    train_loader = build_dataflow(train_dataset, True, args.batch_size,
                                  args.workers, args.not_pin_memory)
    val_loader = build_dataflow(val_dataset, False, args.batch_size,
                                args.workers, args.not_pin_memory)

    return train_loader, val_loader
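
# Usage sketch (hypothetical `args` namespace with the fields referenced
# above; get_data_loaders expects a DataParallel-wrapped TSN exposing
# get_augmentation/input_mean/input_std):
#     model = torch.nn.DataParallel(TSN(...)).cuda()
#     train_loader, val_loader = get_data_loaders(model, '{:06d}.jpg', args)
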
def get_val_loader(model):
    root_path = '/home/mbc2004/datasets/Something-Something/frames/'
    train_list = '/home/mbc2004/datasets/Something-Something/annotations/val_videofolder.txt'
    num_segments = 8
    modality = 'RGB'
    dense_sample = False
    batch_size = 8  #64
    workers = 16
    arch = 'resnet50'

    prefix = '{:06d}.jpg'

    print('#' * 20, 'NO FLIP!!!')
    train_augmentation = torchvision.transforms.Compose(
        [GroupMultiScaleCrop(224, [1, .875, .75, .66])])

    return torch.utils.data.DataLoader(
        TSNDataSet(root_path,
                   train_list,
                   num_segments=num_segments,
                   new_length=1,
                   modality=modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(
                           div=(arch not in ['BNInception', 'InceptionV3'])),
                       IdentityTransform(),
                   ]),
                   dense_sample=dense_sample),
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
        drop_last=True)  # drop the tail batch so every batch splits evenly across GPUs


# NOTE: fragment; the enclosing function was lost, so a plausible signature is
# reconstructed here from the free variables used below.
def build_test_loader(root_path, val_list, test_file, this_test_segments,
                      modality, prefix, weights_list, cropping, net,
                      ipn_no_class, args):
        test_file = test_file if test_file is not None else val_list

        data_loader = torch.utils.data.DataLoader(
            TSNDataSet(
                root_path,
                test_file,
                num_segments=this_test_segments,
                new_length=1
                if modality in ['RGB', 'RGB-flo', 'RGB-seg'] else 5,
                modality=modality,
                image_tmpl=prefix,
                test_mode=True,
                remove_missing=len(weights_list) == 1,
                transform=torchvision.transforms.Compose([
                    cropping,
                    Stack(roll=(this_arch in ['BNInception', 'InceptionV3']),
                          mask=(modality in ['RGB-flo', 'RGB-seg'])),
                    ToTorchFormatTensor(
                        div=(this_arch not in ['BNInception', 'InceptionV3'])),
                    GroupNormalize(net.input_mean, net.input_std),
                ]),
                dense_sample=args.dense_sample,
                twice_sample=args.twice_sample,
                dense_window=args.dense_window,
                full_sample=args.full_sample,
                ipn=args.dataset == 'ipn',
                ipn_no_class=ipn_no_class),
            batch_size=args.batch_size,
            shuffle=False,
            # num_workers=args.workers, pin_memory=True,
        )
        return data_loader
Example #4
def main():
    global args, best_prec1, TRAIN_SAMPLES
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.test_list, args.root_path, prefix = \
        dataset_config.return_dataset(args.dataset, args.modality)
    if os.path.exists(args.test_list):
        args.val_list = args.test_list


    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                tin=args.tin)


    crop_size = args.crop_size
    scale_size = args.scale_size
    input_mean = [0.485, 0.456, 0.406]
    input_std = [0.229, 0.224, 0.225]

    print(args.gpus)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if os.path.isfile(args.resume_path):
        print(("=> loading checkpoint '{}'".format(args.resume_path)))
        checkpoint = torch.load(args.resume_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print(("=> loaded checkpoint '{}' (epoch {})"
               .format(args.evaluate, checkpoint['epoch'])))
    else:
        print(("=> no checkpoint found at '{}'".format(args.resume_path)))


    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    if args.random_crops == 1:
        crop_aug = GroupCenterCrop(args.crop_size)
    elif args.random_crops == 3:
        crop_aug = GroupFullResSample(args.crop_size, args.scale_size, flip=False)
    elif args.random_crops == 5:
        crop_aug = GroupOverSample(args.crop_size, args.scale_size, flip=False)
    else:
        crop_aug = MultiGroupRandomCrop(args.crop_size, args.random_crops)


    test_dataset = TSNDataSet(args.root_path, args.val_list,
                              num_segments=args.num_segments,
                              new_length=data_length,
                              modality=args.modality,
                              image_tmpl=prefix,
                              multi_class=args.multi_class,
                              transform=torchvision.transforms.Compose([
                                  GroupScale(int(args.scale_size)),
                                  crop_aug,
                                  Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                                  ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                                  normalize,
                              ]),
                              dense_sample=args.dense_sample,
                              test_mode=True,
                              temporal_clips=args.temporal_clips)


    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    test(test_loader, model, args.start_epoch)
Example #5
def main():
    t_start = time.time()
    global args, best_prec1, num_class, use_ada_framework  # , model
    wandb.init(
        project="arnet-reproduce",
        name=args.exp_header,
        entity="video_channel"
    )
    wandb.config.update(args)
    set_random_seed(args.random_seed)
    use_ada_framework = (args.ada_reso_skip and not args.offline_lstm_last
                         and not args.offline_lstm_all and not args.real_scsampler)

    if args.ablation:
        logger = None
    else:
        if not test_mode:
            logger = Logger()
            sys.stdout = logger
        else:
            logger = None

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                      args.data_dir)
    

    if args.ada_reso_skip:
        if len(args.ada_crop_list) == 0:
            args.ada_crop_list = [1 for _ in args.reso_list]

    if use_ada_framework:
        init_gflops_table()

    model = TSN_Ada(num_class, args.num_segments,
                    base_model=args.arch,
                    consensus_type=args.consensus_type,
                    dropout=args.dropout,
                    partial_bn=not args.no_partialbn,
                    pretrain=args.pretrain,
                    fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                    args=args)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    # TODO(yue) freeze some params in the policy + lstm layers
    if args.freeze_policy:
        for name, param in model.module.named_parameters():
            if "lite_fc" in name or "lite_backbone" in name or "rnn" in name or "linear" in name:
                param.requires_grad = False

    if args.freeze_backbone:
        for name, param in model.module.named_parameters():
            if "base_model" in name:
                param.requires_grad = False
    if len(args.frozen_list) > 0:
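        # keyword patterns (see the matching below): "*foo*" freezes any param
        # name containing "foo", "*foo" any name ending in "foo", "foo*" any
        # name starting with "foo", and a bare "foo" must match exactly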
        for name, param in model.module.named_parameters():
            for keyword in args.frozen_list:
                if keyword[0] == "*":
                    if keyword[-1] == "*":  # TODO middle
                        if keyword[1:-1] in name:
                            param.requires_grad = False
                            print(keyword, "->", name, "frozen")
                    else:  # TODO suffix
                        if name.endswith(keyword[1:]):
                            param.requires_grad = False
                            print(keyword, "->", name, "frozen")
                elif keyword[-1] == "*":  # TODO prefix
                    if name.startswith(keyword[:-1]):
                        param.requires_grad = False
                        print(keyword, "->", name, "frozen")
                else:  # TODO exact word
                    if name == keyword:
                        param.requires_grad = False
                        print(keyword, "->", name, "frozen")
        print("=" * 80)
        for name, param in model.module.named_parameters():
            print(param.requires_grad, "\t", name)

        print("=" * 80)
        for name, param in model.module.named_parameters():
            print(param.requires_grad, "\t", name)

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
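        # e.g. a checkpoint key 'base_model.layer1.0.conv1.net.weight' is
        # renamed to 'base_model.layer1.0.conv1.weight' when only the latter
        # exists in the current model (and the reverse mapping likewise)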
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    # TODO(yue) ada_model loading process
    if args.ada_reso_skip:
        if test_mode:
            print("Test mode load from pretrained model")
            the_model_path = args.test_from
            if ".pth.tar" not in the_model_path:
                the_model_path = ospj(the_model_path, "models", "ckpt.best.pth.tar")
            model_dict = model.state_dict()
            sd = load_to_sd(model_dict, the_model_path, "foo", "bar", -1, apple_to_apple=True)
            model_dict.update(sd)
            model.load_state_dict(model_dict)
        elif args.base_pretrained_from != "":
            print("Adaptively load from pretrained whole")
            model_dict = model.state_dict()
            sd = load_to_sd(model_dict, args.base_pretrained_from, "foo", "bar", -1, apple_to_apple=True)

            model_dict.update(sd)
            model.load_state_dict(model_dict)

        elif len(args.model_paths) != 0:
            print("Adaptively load from model_path_list")
            model_dict = model.state_dict()
            # TODO(yue) policy net
            sd = load_to_sd(model_dict, args.policy_path, "lite_backbone", "lite_fc",
                            args.reso_list[args.policy_input_offset])
            model_dict.update(sd)
            # TODO(yue) backbones
            for i, tmp_path in enumerate(args.model_paths):
                base_model_index = i
                new_i = i

                sd = load_to_sd(model_dict, tmp_path, "base_model_list.%d" % base_model_index, "new_fc_list.%d" % new_i,
                                args.reso_list[i])
                model_dict.update(sd)
            model.load_state_dict(model_dict)
    else:
        if test_mode:
            the_model_path = args.test_from
            if ".pth.tar" not in the_model_path:
                the_model_path = ospj(the_model_path, "models", "ckpt.best.pth.tar")
            model_dict = model.state_dict()
            sd = load_to_sd(model_dict, the_model_path, "foo", "bar", -1, apple_to_apple=True)
            model_dict.update(sd)
            model.load_state_dict(model_dict)

    if not args.ada_reso_skip and args.base_pretrained_from != "":
        print("Baseline: load from pretrained model")
        model_dict = model.state_dict()
        sd = load_to_sd(model_dict, args.base_pretrained_from, "base_model", "new_fc", 224)

        if args.ignore_new_fc_weight:
            print("@ IGNORE NEW FC WEIGHT !!!")
            del sd["module.new_fc.weight"]
            del sd["module.new_fc.bias"]

        model_dict.update(sd)
        model.load_state_dict(model_dict)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)
    data_length = 1
    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=False),
                       ToTorchFormatTensor(div=True),
                       normalize,
                   ]), dense_sample=args.dense_sample,
                   dataset=args.dataset,
                   partial_fcvid_eval=args.partial_fcvid_eval,
                   partial_ratio=args.partial_ratio,
                   ada_reso_skip=args.ada_reso_skip,
                   reso_list=args.reso_list,
                   random_crop=args.random_crop,
                   center_crop=args.center_crop,
                   ada_crop_list=args.ada_crop_list,
                   rescale_to=args.rescale_to,
                   policy_input_offset=args.policy_input_offset,
                   save_meta=args.save_meta),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True)  # drop the tail batch so every batch splits evenly across GPUs

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=False),
                       ToTorchFormatTensor(div=True),
                       normalize,
                   ]), dense_sample=args.dense_sample,
                   dataset=args.dataset,
                   partial_fcvid_eval=args.partial_fcvid_eval,
                   partial_ratio=args.partial_ratio,
                   ada_reso_skip=args.ada_reso_skip,
                   reso_list=args.reso_list,
                   random_crop=args.random_crop,
                   center_crop=args.center_crop,
                   ada_crop_list=args.ada_crop_list,
                   rescale_to=args.rescale_to,
                   policy_input_offset=args.policy_input_offset,
                   save_meta=args.save_meta
                   ),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    if not test_mode:
        exp_full_path = setup_log_directory(logger, args.log_dir, args.exp_header)
    else:
        exp_full_path = None

    if not args.ablation:
        if not test_mode:
            with open(os.path.join(exp_full_path, 'args.txt'), 'w') as f:
                f.write(str(args))
            tf_writer = SummaryWriter(log_dir=exp_full_path)
        else:
            tf_writer = None
    else:
        tf_writer = None

    # TODO(yue)
    map_record = Recorder()
    mmap_record = Recorder()
    prec_record = Recorder()
    best_train_usage_str = None
    best_val_usage_str = None

    wandb.watch(model)
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        if not args.skip_training:
            set_random_seed(args.random_seed + epoch)
            adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
            train_usage_str = train(train_loader, model, criterion, optimizer, epoch, logger, exp_full_path, tf_writer)
        else:
            train_usage_str = "No training usage stats (Eval Mode)"

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            set_random_seed(args.random_seed)
            mAP, mmAP, prec1, val_usage_str, val_gflops = validate(val_loader, model, criterion, epoch, logger,
                                                                   exp_full_path, tf_writer)

            # remember best prec@1 and save checkpoint
            map_record.update(mAP)
            mmap_record.update(mmAP)
            prec_record.update(prec1)

            if mmap_record.is_current_best():
                best_train_usage_str = train_usage_str
                best_val_usage_str = val_usage_str

            print('Best mAP: %.3f (epoch=%d)\t\tBest mmAP: %.3f(epoch=%d)\t\tBest Prec@1: %.3f (epoch=%d)' % (
                map_record.best_val, map_record.best_at,
                mmap_record.best_val, mmap_record.best_at,
                prec_record.best_val, prec_record.best_at))

            if args.skip_training:
                break

            if (not args.ablation) and (not test_mode):
                tf_writer.add_scalar('acc/test_top1_best', prec_record.best_val, epoch)
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': prec_record.best_val,
                }, mmap_record.is_current_best(), exp_full_path)

    if use_ada_framework and not test_mode:
        print("Best train usage:")
        print(best_train_usage_str)
        print()
        print("Best val usage:")
        print(best_val_usage_str)

    print("Finished in %.4f seconds\n" % (time.time() - t_start))
Example #6
def main():
    global args, best_prec1
    args = parser.parse_args()

    print("------------------------------------")
    print("Environment Versions:")
    print("- Python: {}".format(sys.version))
    print("- PyTorch: {}".format(torch.__version__))
    print("- TorchVison: {}".format(torchvision.__version__))

    args_dict = args.__dict__
    print("------------------------------------")
    print(args.arch + " Configurations:")
    for key in args_dict.keys():
        print("- {}: {}".format(key, args_dict[key]))
    print("------------------------------------")
    print(args.mode)
    if args.dataset == 'ucf101':
        num_class = 101
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'hmdb51':
        num_class = 51
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'kinetics':
        num_class = 400
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'something':
        num_class = 174
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'somethingv2':
        num_class = 174
        rgb_read_format = "img_{:05d}.jpg"
    elif args.dataset == 'NTU_RGBD':
        num_class = 120
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'tinykinetics':
        num_class = 150
        rgb_read_format = "{:05d}.jpg"
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                partial_bn=not args.no_partialbn,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    # Optimizers also support specifying per-parameter options. To do this,
    # pass in an iterable of dicts. Each dict defines a separate parameter
    # group and should contain a 'params' key with the list of parameters
    # belonging to it. Other keys should match the keyword arguments accepted
    # by the optimizer and are used as optimization options for that group.
    policies = model.get_optim_policies(args.dataset)
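    # A minimal sketch (hypothetical submodule names and a hypothetical
    # base_lr) of such per-parameter groups, analogous to what the policies
    # returned here feed into torch.optim.SGD further below:
    #     optimizer = torch.optim.SGD(
    #         [{'params': model.base_model.parameters()},
    #          {'params': model.new_fc.parameters(), 'lr': 10 * base_lr}],
    #         lr=base_lr, momentum=0.9)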

    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    model_dict = model.state_dict()

    if args.arch == "resnet50":
        new_state_dict = {}  #model_dict
        div = False
        roll = True
    elif args.arch == "resnet34":
        pretrained_dict = {}
        new_state_dict = {}  #model_dict
        for k, v in model_dict.items():
            if ('fc' not in k):
                new_state_dict.update({k: v})
        div = False
        roll = True
    elif (args.arch[:3] == "TCM"):
        pretrained_dict = {}
        new_state_dict = {}  #model_dict
        for k, v in model_dict.items():
            if ('fc' not in k):
                new_state_dict.update({k: v})
        div = True
        roll = False
    else:
        raise ValueError('Unknown arch ' + args.arch)

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 1

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "",
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            mode=args.mode,
            image_tmpl=args.rgb_prefix + rgb_read_format if args.modality
            in ["RGB", "RGBDiff"] else args.flow_prefix + rgb_read_format,
            img_start_idx=args.img_start_idx,
            transform=torchvision.transforms.Compose([
                GroupScale((240, 320)),
                #                        GroupScale(int(scale_size)),
                train_augmentation,
                Stack(roll=roll),
                ToTorchFormatTensor(div=div),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            "",
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            mode=args.mode,
            image_tmpl=args.rgb_prefix + rgb_read_format if args.modality
            in ["RGB", "RGBDiff"] else args.flow_prefix + rgb_read_format,
            img_start_idx=args.img_start_idx,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale((240, 320)),
                #                        GroupScale((224)),
                #                        GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=roll),
                ToTorchFormatTensor(div=div),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()

    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)

    output_list = []
    if args.evaluate:
        prec1, score_tensor = validate(val_loader,
                                       model,
                                       criterion,
                                       temperature=100)
        output_list.append(score_tensor)
        save_validation_score(output_list, filename='score.pt')
        print("validation score saved in {}".format('/'.join(
            (args.val_output_folder, 'score_inf5.pt'))))
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)
        # train for one epoch
        temperature = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1, score_tensor = validate(val_loader,
                                           model,
                                           criterion,
                                           temperature=temperature)

            output_list.append(score_tensor)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)

    # save validation score
    save_validation_score(output_list)
    print("validation score saved in {}".format('/'.join(
        (args.val_output_folder, 'score.pt'))))


# NOTE: fragment; the crop-selection branches and the enclosing function are
# missing here, so a guard and a plausible signature are reconstructed from
# the free variables used below.
def build_test_loader(root_path, val_list, test_file, this_test_segments,
                      new_length, modality, prefix, weights_list, cropping,
                      net, args):
    if args.test_crops not in (1, 3, 5, 10):
        raise ValueError(
            "Only 1, 5, 10 crops are supported while we got {}".format(
                args.test_crops))

    data_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            root_path, [test_file] if test_file is not None else val_list,
            num_segments=this_test_segments,
            new_length=[new_length],
            modality=[modality],
            image_tmpl=prefix,
            test_mode=True,
            remove_missing=len(weights_list) == 1,
            transform=torchvision.transforms.Compose([
                cropping,
                Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(this_arch not in ['BNInception', 'InceptionV3'])),
                GroupNormalize(net.input_mean, net.input_std)
                if modality != 'RGBDiff' else IdentityTransform(),
            ]),
            dense_sample=args.dense_sample,
            dense_length=args.dense_length,
            dense_number=args.dense_number,
            twice_sample=args.twice_sample,
            random_sample=args.random_sample),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )
    return data_loader
def main():
    # options
    parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
    parser.add_argument('dataset', type=str)

    # may contain splits
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--test_segments', type=str, default='25')
    parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
    parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
    parser.add_argument('--full_res', default=False, action="store_true",
                        help='use full resolution 256x256 for test as in Non-local I3D')

    parser.add_argument('--test_crops', type=int, default=1)
    parser.add_argument('--coeff', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')

    # for true test
    parser.add_argument('--test_list', type=str, default=None)
    parser.add_argument('--csv_file', type=str, default=None)

    parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')

    parser.add_argument('--max_num', type=int, default=-1)
    parser.add_argument('--input_size', type=int, default=224)
    parser.add_argument('--crop_fusion_type', type=str, default='avg')
    parser.add_argument('--gpus', nargs='+', type=int, default=None)
    parser.add_argument('--img_feature_dim', type=int, default=256)
    parser.add_argument('--num_set_segments', type=int, default=1,
                        help='TODO: select multiple sets of n-frames from a video')
    parser.add_argument('--pretrain', type=str, default='imagenet')

    args = parser.parse_args()

    def accuracy(output, target, topk=(1,)):
        """Computes the precision@k for the specified values of k"""
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
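    # Usage sketch: for logits of shape [B, C] and integer targets of shape
    # [B], accuracy(logits, targets, topk=(1, 5)) returns two percentages,
    # e.g. [tensor(87.5), tensor(100.)] for a batch of 8 with 7 top-1 hits.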


    def parse_shift_option_from_log_name(log_name):
        if 'shift' in log_name:
            strings = log_name.split('_')
            for i, s in enumerate(strings):
                if 'shift' in s:
                    break
            return True, int(strings[i].replace('shift', '')), strings[i + 1]
        else:
            return False, None, None
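    # e.g. parse_shift_option_from_log_name(
    #     'TSM_something_RGB_resnet50_shift8_blockres_avg')
    # returns (True, 8, 'blockres'); names without 'shift' give (False, None, None)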


    weights_list = args.weights.split(',')
    test_segments_list = [int(s) for s in args.test_segments.split(',')]
    assert len(weights_list) == len(test_segments_list)
    if args.coeff is None:
        coeff_list = [1] * len(weights_list)
    else:
        coeff_list = [float(c) for c in args.coeff.split(',')]

    if args.test_list is not None:
        test_file_list = args.test_list.split(',')
    else:
        test_file_list = [None] * len(weights_list)


    data_iter_list = []
    net_list = []
    modality_list = []

    total_num = None
    for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
        is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
        if 'RGB' in this_weights:
            modality = 'RGB'
        else:
            modality = 'Flow'
        this_arch = this_weights.split('TSM_')[1].split('_')[2]
        modality_list.append(modality)
        num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                modality)
        print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
        net = TSN(num_class, this_test_segments if is_shift else 1, modality,
                base_model=this_arch,
                consensus_type=args.crop_fusion_type,
                img_feature_dim=args.img_feature_dim,
                pretrain=args.pretrain,
                is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
                non_local='_nl' in this_weights,
                )

        if 'tpool' in this_weights:
            from ops.temporal_shift import make_temporal_pool
            make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel

        checkpoint = torch.load(this_weights)
        checkpoint = checkpoint['state_dict']

        # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias',
                        }
        for k, v in replace_dict.items():
            if k in base_dict:
                base_dict[v] = base_dict.pop(k)

        net.load_state_dict(base_dict)

        input_size = net.scale_size if args.full_res else net.input_size
        if args.test_crops == 1:
            cropping = torchvision.transforms.Compose([
                GroupScale(net.scale_size),
                GroupCenterCrop(input_size),
            ])
        elif args.test_crops == 3:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupFullResSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 5:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 10:
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size)
            ])
        else:
            raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops))

        data_loader = torch.utils.data.DataLoader(
                TSNDataSet(root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments,
                        new_length=1 if modality == "RGB" else 5,
                        modality=modality,
                        image_tmpl=prefix,
                        test_mode=True,
                        remove_missing=len(weights_list) == 1,
                        transform=torchvision.transforms.Compose([
                            cropping,
                            Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                            ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                            GroupNormalize(net.input_mean, net.input_std),
                        ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True,
        )

        if args.gpus is not None:
            devices = [args.gpus[i] for i in range(args.workers)]
        else:
            devices = list(range(args.workers))

        net = torch.nn.DataParallel(net.cuda())
        net.eval()

        data_gen = enumerate(data_loader)

        if total_num is None:
            total_num = len(data_loader.dataset)
        else:
            assert total_num == len(data_loader.dataset)

        data_iter_list.append(data_gen)
        net_list.append(net)


    output = []


    def eval_video(video_data, net, this_test_segments, modality):
        net.eval()
        with torch.no_grad():
            i, data, label = video_data
            batch_size = label.numel()
            num_crop = args.test_crops
            if args.dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample

            if args.twice_sample:
                num_crop *= 2
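            # e.g. test_crops=3 with dense sampling (x10) and twice sampling
            # (x2) would average num_crop = 60 clips per video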

            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality "+ modality)

            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3))
            rst = net(data_in)
            rst = rst.reshape(batch_size, num_crop, -1).mean(1)

            if args.softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)

            rst = rst.data.cpu().numpy().copy()

            if net.module.is_shift:
                rst = rst.reshape(batch_size, num_class)
            else:
                rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

            return i, rst, label


    proc_start_time = time.time()
    max_num = args.max_num if args.max_num > 0 else total_num

    top1 = AverageMeter()
    top5 = AverageMeter()

    for i, data_label_pairs in enumerate(zip(*data_iter_list)):
        with torch.no_grad():
            if i >= max_num:
                break
            this_rst_list = []
            this_label = None
            for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list):
                rst = eval_video((i, data, label), net, n_seg, modality)
                this_rst_list.append(rst[1])
                this_label = label
            assert len(this_rst_list) == len(coeff_list)
            for i_coeff in range(len(this_rst_list)):
                this_rst_list[i_coeff] *= coeff_list[i_coeff]
            ensembled_predict = sum(this_rst_list) / len(this_rst_list)

            for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
                output.append([p[None, ...], g])
            cnt_time = time.time() - proc_start_time
            prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
            top1.update(prec1.item(), this_label.numel())
            top5.update(prec5.item(), this_label.numel())
            if i % 20 == 0:
                print('video {} done, total {}/{}, average {:.3f} sec/video, '
                    'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num,
                                                                float(cnt_time) / (i+1) / args.batch_size, top1.avg, top5.avg))

    video_pred = [np.argmax(x[0]) for x in output]
    video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]

    video_labels = [x[1] for x in output]


    if args.csv_file is not None:
        print('=> Writing result to csv file: {}'.format(args.csv_file))
        with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f:
            categories = f.readlines()
        categories = [f.strip() for f in categories]
        with open(test_file_list[0]) as f:
            vid_names = f.readlines()
        vid_names = [n.split(' ')[0] for n in vid_names]
        print(vid_names)
        assert len(vid_names) == len(video_pred)
        if args.dataset != 'somethingv2':  # only output top1
            with open(args.csv_file, 'w') as f:
                for n, pred in zip(vid_names, video_pred):
                    f.write('{};{}\n'.format(n, categories[pred]))
        else:
            with open(args.csv_file, 'w') as f:
                for n, pred5 in zip(vid_names, video_pred_top5):
                    fill = [n]
                    for p in list(pred5):
                        fill.append(p)
                    f.write('{};{};{};{};{};{}\n'.format(*fill))


    cf = confusion_matrix(video_labels, video_pred).astype(float)

    np.save('cm.npy', cf)
    cls_cnt = cf.sum(axis=1)
    cls_hit = np.diag(cf)

    cls_acc = cls_hit / cls_cnt
    print(cls_acc)
    upper = np.mean(np.max(cf, axis=1) / cls_cnt)
    print('upper bound: {}'.format(upper))

    print('-----Evaluation is finished------')
    print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
    print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))
Example #9
def main():
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                      args.modality)
    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                cca3d=args.cca3d)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = model.cuda()

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            parallel_state_dict = checkpoint['state_dict']
            cpu_state_dict = {}
            for k, v in parallel_state_dict.items():
                cpu_state_dict[k[len('module.'):]] = v
            model.load_state_dict(cpu_state_dict)
            print(("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    preprocess = torchvision.transforms.Compose([
        GroupScale(int(scale_size)),
        GroupCenterCrop(crop_size),
        Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
        normalize,
    ])
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=preprocess,
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)
    norm_param = (input_mean, input_std)
    cam_process(val_loader, model, norm_param)


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            rank = args.rank * ngpus_per_node + gpu
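            # e.g. node rank 1 on an 8-GPU node, local gpu 3 -> global rank 11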
        else:
            rank = args.rank
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=rank)
    else:
        rank = 0
    # create model
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset, args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type, 'segment%d' % args.num_segments,
         'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    args.store_name += '_lr{}'.format(args.lr)
    args.store_name += '_wd{:.1e}'.format(args.weight_decay)
    args.store_name += '_do{}'.format(args.dropout)
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders(args, rank)

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)
    
    # first synchronization of initial weights
    # sync_initial_weights(model, args.rank, args.world_size)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if rank == 0:
        print(model)
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have on a node
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
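            # e.g. a global batch_size of 128 on an 8-GPU node gives each
            # process a per-GPU batch of 16 and a proportional worker count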
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(policies, args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            if args.start_epoch == 1:
                args.start_epoch = checkpoint['epoch'] + 1
            best_acc1 = checkpoint['best_acc1']
#             if args.gpu is not None:
#                 # best_acc1 may be from a checkpoint from a different GPU
#                 best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_dataset = TSNDataSet(args.dataset, args.root_path, args.train_list, num_segments=args.num_segments,
                       new_length=data_length,
                       modality=args.modality,
                       image_tmpl=prefix,
                       transform=torchvision.transforms.Compose([train_augmentation,
                           Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                           normalize]), 
                      dense_sample=args.dense_sample)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, drop_last=True, sampler=train_sampler)
    
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.dataset, args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return
    
    log_training = open(os.path.join(args.root_model, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_model, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(log_dir=os.path.join(args.root_model, args.store_name))
    for epoch in range(args.start_epoch, args.epochs+1):
        if args.distributed:
            train_sampler.set_epoch(epoch)
#         adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer, args, rank)
        if rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, False, args, rank)
        
        if epoch % 5 == 0 and rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, False, args, rank, e=epoch)

        # evaluate on validation set
        is_best = False
        if epoch % args.eval_freq == 0 or epoch == args.epochs:
            acc1 = validate(val_loader, model, criterion, epoch, args, rank, log_training, tf_writer)

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed and rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best, args, rank)
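The resume block at the top of this listing restores model and optimizer state with a device-pinned map_location so tensors saved from cuda:0 do not all land on GPU 0 in a multi-process run. A minimal self-contained sketch of that pattern; the helper name and return values are illustrative, not from the listing:

import torch

def resume_if_available(path, model, optimizer, gpu=None):
    # pin the checkpoint to one device (or CPU) before loading
    loc = 'cuda:{}'.format(gpu) if gpu is not None else 'cpu'
    checkpoint = torch.load(path, map_location=loc)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # resume at the epoch after the one that was saved
    return checkpoint['epoch'] + 1, checkpoint['best_acc1']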
        [GroupOverSample(input_size, net.scale_size)])
else:
    raise ValueError(
        "Only 1, 5, 10 crops are supported while we got {}".format(
            args.test_crops))

data_loader = torch.utils.data.DataLoader(
    TSNDataSet(
        val_list,
        num_segments=args.test_segments,
        new_length=1 if args.modality == "RGB" else 5,
        modality=args.modality,
        image_tmpl=prefix,
        test_mode=True,
        remove_missing=True,
        multi_clip_test=args.multi_clip_test,
        transform=torchvision.transforms.Compose([
            cropping,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            GroupNormalize(net.input_mean, net.input_std),
        ]),
        dense_sample=args.dense_sample,
    ),
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
)

if args.gpus is not None:
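Every listing on this page toggles Stack(roll=...) and ToTorchFormatTensor(div=...) on the same architecture test: the Caffe-ported BNInception/InceptionV3 weights expect BGR channel order and 0-255 pixel values, while torchvision backbones expect RGB scaled to [0, 1]. A tiny sketch of that convention (the helper name is illustrative):

def stack_flags(arch):
    caffe_style = arch in ['BNInception', 'InceptionV3']
    return {'roll': caffe_style,      # BGR channel roll for Caffe-ported weights
            'div': not caffe_style}   # divide by 255 only for torchvision weights

assert stack_flags('BNInception') == {'roll': True, 'div': False}
assert stack_flags('resnet50') == {'roll': False, 'div': True}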
Example #12
0
def main():
    global args, best_prec1, least_loss
    least_loss = float('inf')  # sentinel: any real validation loss is smaller
    args = parser.parse_args()
    if os.path.exists(os.path.join(args.root_log, "error.log")):
        os.remove(os.path.join(args.root_log, "error.log"))
    logging.basicConfig(
        level=logging.DEBUG,
        filename=os.path.join(args.root_log, "error.log"),
        filemode='a',
        format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')

    # log_handler = open(os.path.join(args.root_log,"error.log"),"w")
    # sys.stdout = log_handler
    if args.root_path:
        num_class, args.train_list, args.val_list, _, prefix = dataset_config.return_dataset(
            args.dataset, args.modality)
        args.train_list = os.path.join(args.root_log,
                                       "kf1_train_anno_lijun_iod.json")
        args.test_list = os.path.join(args.root_log,
                                      "kf1_test_anno_lijun_iod.json")
    else:
        num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
            args.dataset, args.modality)

    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSA', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    # if args.pretrain != 'imagenet':
    #     args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                is_TSA=args.tsa,
                is_sTSA=args.stsa,
                is_tTSA=args.ttsa,
                shift_diff=args.shift_diff,
                shift_groups=args.shift_groups,
                is_ME=args.me,
                is_3D=args.is_3D,
                cfg_file=args.cfg_file)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()  # used by the SGD branch below
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    if args.optimizer == "sgd":
        if args.lr_scheduler:
            optimizer = torch.optim.SGD(model.parameters(),
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.SGD(policies,
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.optimizer == "adam":
        params = get_vmz_fine_tuning_parameters(model,
                                                args.vmz_tune_last_k_layer)
        optimizer = torch.optim.Adam(params,
                                     args.lr,
                                     weight_decay=args.weight_decay)
    else:
        raise RuntimeError("not supported optimizer")

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.lr_steps, args.lr_scheduler_gamma)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))

            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # if args.lr_scheduler:
            #     scheduler.load_state_dict(checkpoint["lr_scheduler"])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
            logging.info(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))
            logging.error(
                ("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        # sd = {k:v for k, v in sd.items() if k in keys2}
        sd = {k: v for k, v in sd.items() if k in keys2}
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {
                k: v
                for k, v in sd.items()
                if 'fc' not in k and "projection" not in k
            }
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality in ['RGB', "PoseAction"]:
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    if not args.shuffle:
        train_loader = torch.utils.data.DataLoader(
            TSNDataSet(args.root_path,
                       args.train_list,
                       num_segments=args.num_segments,
                       new_length=data_length,
                       modality=args.modality,
                       image_tmpl=prefix,
                       transform=torchvision.transforms.Compose([
                           train_augmentation,
                           Stack(roll=(args.arch
                                       in ['BNInception', 'InceptionV3']),
                                 inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                           ToTorchFormatTensor(
                               div=(args.arch
                                    not in ['BNInception', 'InceptionV3']),
                               inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                           normalize,
                       ]),
                       dense_sample=args.dense_sample,
                       all_sample=args.all_sample),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)  # prevent something not % n_GPU

        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(scale_size),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                ToTorchFormatTensor(div=(args.arch
                                         not in ['BNInception',
                                                 'InceptionV3']),
                                    inc_dim=(args.arch in ["R2plus1D",
                                                           "X3D"])),
                normalize,
            ]),
            dense_sample=args.dense_sample,
            all_sample=args.all_sample),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

    # for group in policies:
    #     print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
    #         group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

    if args.evaluate:
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        elif args.loss_type == "bce":
            criterion = torch.nn.BCEWithLogitsLoss().cuda()
        elif args.loss_type == "wbce":
            class_weight, pos_weight = prep_weight(args.train_list)
            criterion = WeightedBCEWithLogitsLoss(class_weight, pos_weight)
        else:
            raise ValueError("Unknown loss type")

        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(scale_size),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                ToTorchFormatTensor(div=(args.arch
                                         not in ['BNInception',
                                                 'InceptionV3']),
                                    inc_dim=(args.arch in ["R2plus1D",
                                                           "X3D"])),
                normalize,
            ]),
            dense_sample=args.dense_sample,
            all_sample=args.all_sample,
            analysis=True),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        test(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    print(model)
    logging.info(model)
    for epoch in range(args.start_epoch, args.epochs):
        logging.info("Train Epoch {}/{} starts, estimated time 5832s".format(
            str(epoch), str(args.epochs)))
        # update data_loader
        if args.shuffle:
            gen_label(args.prop_path,
                      args.label_path,
                      args.trn_name,
                      args.train_list,
                      args.neg_rate,
                      STR=False)
            gen_label(args.prop_path,
                      args.label_path,
                      args.tst_name,
                      args.val_list,
                      args.test_rate,
                      STR=False)
            train_loader = torch.utils.data.DataLoader(
                TSNDataSet(args.root_path,
                           args.train_list,
                           num_segments=args.num_segments,
                           new_length=data_length,
                           modality=args.modality,
                           image_tmpl=prefix,
                           transform=torchvision.transforms.Compose([
                               train_augmentation,
                               Stack(roll=(args.arch
                                           in ['BNInception', 'InceptionV3']),
                                     inc_dim=(args.arch in ["R2plus1D",
                                                            "X3D"])),
                               ToTorchFormatTensor(
                                   div=(args.arch
                                        not in ['BNInception', 'InceptionV3']),
                                   inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               normalize,
                           ]),
                           dense_sample=args.dense_sample,
                           all_sample=args.all_sample),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)

            val_loader = torch.utils.data.DataLoader(
                TSNDataSet(args.root_path,
                           args.val_list,
                           num_segments=args.num_segments,
                           new_length=data_length,
                           modality=args.modality,
                           image_tmpl=prefix,
                           random_shift=False,
                           transform=torchvision.transforms.Compose([
                               GroupScale(scale_size),
                               GroupCenterCrop(crop_size),
                               Stack(roll=(args.arch
                                           in ['BNInception', 'InceptionV3']),
                                     inc_dim=(args.arch in ["R2plus1D",
                                                            "X3D"])),
                               ToTorchFormatTensor(
                                   div=(args.arch
                                        not in ['BNInception', 'InceptionV3']),
                                   inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               normalize,
                           ]),
                           dense_sample=args.dense_sample,
                           all_sample=args.all_sample),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            print(train_loader)

        # define loss function (criterion) and optimizer
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        elif args.loss_type == "bce":
            criterion = torch.nn.BCEWithLogitsLoss().cuda()
        elif args.loss_type == "wbce":
            class_weight, pos_weight = prep_weight(args.train_list)
            criterion = WeightedBCEWithLogitsLoss(class_weight, pos_weight)
        else:
            raise ValueError("Unknown loss type")

        # train for one epoch
        if not args.lr_scheduler:
            adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)
        else:
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)
            scheduler.step()

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            logging.info(
                "Test Epoch {}/{} starts, estimated time 13874s".format(
                    str(epoch // args.eval_freq),
                    str(args.epochs // args.eval_freq)))
            if args.loss_type == "wbce":
                # class_weight,pos_weight = prep_weight(args.val_list)
                criterion = torch.nn.BCEWithLogitsLoss().cuda()
            lossm = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = lossm < least_loss
            least_loss = min(lossm, least_loss)
            tf_writer.add_scalar('lss/test_top1_best', least_loss, epoch)

            output_best = 'Best Loss: %.3f\n' % least_loss
            logging.info(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()
            if args.lr_scheduler:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': least_loss,
                        'lr_scheduler': scheduler.state_dict(),
                    }, is_best, epoch)
            else:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': least_loss,
                    }, is_best, epoch)
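The tune_from blocks in these listings reconcile checkpoints saved with and without the temporal-shift wrapper, which inserts a '.net' segment into parameter names (the reverse direction, adding '.net', is handled symmetrically). A self-contained toy sketch of the forward remapping, using dummy keys rather than real weights:

def remap_state_dict(sd, model_keys):
    sd = dict(sd)
    for k in list(sd):
        if k not in model_keys and k.replace('.net', '') in model_keys:
            sd[k.replace('.net', '')] = sd.pop(k)  # strip the wrapper level
    # keep only keys the target model actually has
    return {k: v for k, v in sd.items() if k in model_keys}

sd = {'module.base_model.layer1.0.conv1.net.weight': 0}
print(remap_state_dict(sd, {'module.base_model.layer1.0.conv1.weight'}))
# {'module.base_model.layer1.0.conv1.weight': 0}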
Example #13
0
def main():
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'I3D', args.dataset, full_arch_name, 'batch{}'.format(args.batch_size),
        'wd{}'.format(args.weight_decay), args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs),
        'dropout{}'.format(args.dropout), args.pretrain,
        'lr{}'.format(args.lr), '_warmup{}'.format(args.warmup)
    ])
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    else:
        step_str = [str(int(x)) for x in args.lr_steps]
        args.store_name += '_step' + '_'.join(step_str)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.spatial_dropout:
        sigmoid_layer_str = '_'.join(args.sigmoid_layer)
        args.store_name += '_spatial_drop3d_{}_group{}_layer{}'.format(
            args.sigmoid_thres, args.sigmoid_group, sigmoid_layer_str)
        if args.sigmoid_random:
            args.store_name += '_RandomSigmoid'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = i3d(num_class,
                args.num_segments,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                spatial_dropout=args.spatial_dropout,
                sigmoid_thres=args.sigmoid_thres,
                sigmoid_group=args.sigmoid_group,
                sigmoid_random=args.sigmoid_random,
                sigmoid_layer=args.sigmoid_layer,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model,
                                  device_ids=list(range(args.gpus))).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path,
                   args.train_list,
                   num_segments=args.num_segments,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       GroupScale((256, 340)),
                       train_augmentation,
                       Stack('3D'),
                       ToTorchFormatTensor(),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack('3D'),
            ToTorchFormatTensor(),
            normalize,
        ]),
        dense_sample=args.dense_sample),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        # multi-label targets (mAP metric), hence BCE rather than cross-entropy
        criterion = torch.nn.BCEWithLogitsLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        # adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer, args.warmup, args.lr_type, args.lr_steps)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_mAP_best', best_prec1, epoch)

            output_best = 'Best mAP: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
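save_checkpoint itself is never shown in these listings; a minimal sketch of the usual TSM-style helper, with the output directory layout assumed rather than taken from the source:

import os
import shutil
import torch

def save_checkpoint(state, is_best, out_dir='checkpoint'):
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, 'ckpt.pth.tar')
    torch.save(state, path)
    if is_best:
        # keep a separate copy of the best-so-far weights
        shutil.copyfile(path, os.path.join(out_dir, 'ckpt.best.pth.tar'))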
Example #14
0
            [GroupOverSample(input_size, net.scale_size)])
    else:
        raise ValueError(
            "Only 1, 5, 10 crops are supported while we got {}".format(
                args.test_crops))

    data_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            root_path,
            test_file if test_file is not None else val_list,
            num_segments=this_test_segments,
            new_length=1 if modality == "RGB" else 5,
            modality=modality,
            image_tmpl=prefix,
            test_mode=True,
            remove_missing=len(weights_list) == 1,
            transform=torchvision.transforms.Compose([
                cropping,
                Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(this_arch not in ['BNInception', 'InceptionV3'])),
                GroupNormalize(net.input_mean, net.input_std),
            ]),
            dense_sample=args.dense_sample,
            twice_sample=args.twice_sample),
        batch_size=args.batch_size,
        shuffle=False,
        # num_workers=args.workers, pin_memory=True,
    )

    if args.gpus is not None:
        devices = [args.gpus[i] for i in range(args.workers)]
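Both truncated fragments above are the tail of the standard TSM test-time cropping switch. A sketch of the branch they belong to, assuming the project's GroupScale/GroupCenterCrop/GroupOverSample transforms are importable from its ops.transforms module:

import torchvision
from ops.transforms import GroupScale, GroupCenterCrop, GroupOverSample

def build_test_cropping(net, test_crops):
    if test_crops == 1:
        return torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(net.input_size)])
    if test_crops == 5:   # 5 spatial crops, no flips
        return torchvision.transforms.Compose([
            GroupOverSample(net.input_size, net.scale_size, flip=False)])
    if test_crops == 10:  # 5 crops plus their horizontal flips
        return torchvision.transforms.Compose([
            GroupOverSample(net.input_size, net.scale_size)])
    raise ValueError(
        "Only 1, 5, 10 crops are supported while we got {}".format(test_crops))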
Example #15
0
File: main.py Project: CV-IP/TDN
def main():
    global args, best_prec1
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    args.store_name = '_'.join([
        'TDN_', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)

    if dist.get_rank() == 0:
        check_rootfolders()

    logger = setup_logger(output=os.path.join(args.root_log, args.store_name),
                          distributed_rank=dist.get_rank(),
                          name=f'TDN')
    logger.info('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                fc_lr5=(args.tune_from and args.dataset in args.tune_from))

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    for group in policies:
        logger.info(
            ('[TDN-{}]group: {} has {} params, lr_mult: {}, decay_mult: {}'.
             format(args.arch, group['name'], len(group['params']),
                    group['lr_mult'], group['decay_mult'])))

    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset else True)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)

    train_dataset = TSNDataSet(
        args.dataset,
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_dataset = TSNDataSet(
        args.dataset,
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)

    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             drop_last=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    scheduler = get_scheduler(optimizer, len(train_loader), args)

    model = DistributedDataParallel(model.cuda(),
                                    device_ids=[args.local_rank],
                                    broadcast_buffers=True,
                                    find_unused_parameters=True)

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            logger.info(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        logger.info(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                logger.info('=> Load after remove .net: {}'.format(k))
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                logger.info('=> Load after adding .net: {}'.format(k))
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        logger.info(
            '#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            logger.info('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    if args.evaluate:
        logger.info(("===========evaluate==========="))
        val_loader.sampler.set_epoch(args.start_epoch)
        prec1, prec5, val_loss = validate(val_loader, model, criterion, logger)
        if dist.get_rank() == 0:
            is_best = prec1 > best_prec1
            best_prec1 = prec1
            logger.info(("Best Prec@1: '{}'".format(best_prec1)))
            save_epoch = args.start_epoch + 1
            save_checkpoint(
                {
                    'epoch': args.start_epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'prec1': prec1,
                    'best_prec1': best_prec1,
                }, save_epoch, is_best)
        return

    for epoch in range(args.start_epoch, args.epochs):
        train_loader.sampler.set_epoch(epoch)
        train_loss, train_top1, train_top5 = train(train_loader,
                                                   model,
                                                   criterion,
                                                   optimizer,
                                                   epoch=epoch,
                                                   logger=logger,
                                                   scheduler=scheduler)
        if dist.get_rank() == 0:
            tf_writer.add_scalar('loss/train', train_loss, epoch)
            tf_writer.add_scalar('acc/train_top1', train_top1, epoch)
            tf_writer.add_scalar('acc/train_top5', train_top5, epoch)
            tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)

        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            val_loader.sampler.set_epoch(epoch)
            prec1, prec5, val_loss = validate(val_loader, model, criterion,
                                              epoch, logger)
            if dist.get_rank() == 0:
                tf_writer.add_scalar('loss/test', val_loss, epoch)
                tf_writer.add_scalar('acc/test_top1', prec1, epoch)
                tf_writer.add_scalar('acc/test_top5', prec5, epoch)

                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

                logger.info(("Best Prec@1: '{}'".format(best_prec1)))
                tf_writer.flush()
                save_epoch = epoch + 1
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'prec1': prec1,
                        'best_prec1': best_prec1,
                    }, save_epoch, is_best)
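Example #15 drives both loaders through DistributedSampler and calls sampler.set_epoch before every pass; without set_epoch, every epoch replays the same shuffle order on every rank. A runnable single-process sketch (world size 1, gloo backend; names are illustrative):

import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

dataset = TensorDataset(torch.arange(16).float())
sampler = DistributedSampler(dataset)  # shuffle=True by default
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # re-seeds the shuffle so epochs differ
    order = [int(v) for (batch,) in loader for v in batch]
    print('epoch', epoch, order)

dist.destroy_process_group()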
Example #16
0
def main():
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
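adjust_learning_rate is called at the top of this training loop but not listed. A sketch of the usual TSM-style step/cosine schedule, which also honors the per-group lr_mult/decay_mult set by get_optim_policies; base_lr, total_epochs, and base_wd are assumed arguments standing in for the script's args values:

import math

def adjust_learning_rate(optimizer, epoch, lr_type, lr_steps,
                         base_lr=0.01, total_epochs=50, base_wd=5e-4):
    if lr_type == 'step':
        # decay by 10x at each milestone in lr_steps
        decay = 0.1 ** sum(1 for step in lr_steps if epoch >= step)
        lr = base_lr * decay
    elif lr_type == 'cos':
        lr = 0.5 * base_lr * (1 + math.cos(math.pi * epoch / total_epochs))
    else:
        raise NotImplementedError(lr_type)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group.get('lr_mult', 1)
        param_group['weight_decay'] = base_wd * param_group.get('decay_mult', 1)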
Example #17
0
def main():
    global args, best_prec1
    args = parser.parse_args()

    #num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
    #                                                                                                  args.modality)
    num_class = 21
    args.train_list = "/home/jzwang/code/Video_3D/movienet/data/movie/movie_val_new.txt"
    args.val_list = "/home/jzwang/code/RGB-FLOW/movie_new/data/datanew/movie_val_new.txt"
    args.root_path = ""
    prefix = "frame_{:04d}.jpg"
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    #check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            #best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(TSNDataSet(
        "",
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl="frame_{:04d}.jpg" if args.modality in ["RGB", "RGBDiff"]
        else args.flow_prefix + "{}_{:05d}.jpg",
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=args.arch == 'BNInception'),
            ToTorchFormatTensor(div=args.arch != 'BNInception'),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        "",
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl="frame_{:04d}.jpg" if args.modality in ["RGB", "RGBDiff"]
        else args.flow_prefix + "{}_{:05d}.jpg",
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=args.arch == 'BNInception'),
            ToTorchFormatTensor(div=args.arch != 'BNInception'),
            normalize,
        ])),
                                             batch_size=16,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.BCEWithLogitsLoss().cuda()
    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    zero_time = time.time()
    best_map = 0
    print('Start training...')
    for epoch in range(args.start_epoch, args.epochs):
        start_time = time.time()  # was undefined in the original listing
        #adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # NOTE: this listing never calls train(); trainloss is a placeholder
        # so that the logging below does not raise a NameError.
        trainloss = float('nan')
        #if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
        valloss, mAP, output_mtx = validate(val_loader, model, criterion)
        end_time = time.time()
        epoch_time = end_time - start_time
        total_time = end_time - zero_time
        print('Total time used: %s Epoch %d time used: %s' %
              (str(datetime.timedelta(seconds=int(total_time))), epoch,
               str(datetime.timedelta(seconds=int(epoch_time)))))
        print('Train loss: {0:.4f} val loss: {1:.4f} mAP: {2:.4f}'.format(
            trainloss, valloss, mAP))
        # evaluate on validation set
        is_best = mAP > best_map
        best_map = max(mAP, best_map)
        # checkpoint_name = "%04d_%s" % (epoch+1, "checkpoint.pth.tar")
        checkpoint_name = "best_checkpoint.pth.tar"
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, is_best, epoch)
        np.save("val1.npy", output_mtx)
    with open(args.record_path, 'a') as file:
        file.write(
            'Epoch:[{0}]'
            'Train loss: {1:.4f} val loss: {2:.4f} map: {3:.4f}\n'.format(
                epoch + 1, trainloss, valloss, mAP))

    print('************ Done!... ************')
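Example #17 selects its checkpoint by validation mAP, but none of these listings show how mAP is computed. A common sketch using scikit-learn's average_precision_score averaged over classes; the dependency and function name are assumptions, not from the source:

import numpy as np
from sklearn.metrics import average_precision_score

def mean_average_precision(scores, labels):
    # scores: (N, C) per-class confidences; labels: (N, C) binary ground truth
    aps = [average_precision_score(labels[:, c], scores[:, c])
           for c in range(labels.shape[1]) if labels[:, c].any()]
    return float(np.mean(aps))

scores = np.array([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])
labels = np.array([[1, 0], [0, 1], [1, 1]])
print(mean_average_precision(scores, labels))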