# Esempio n. 1 (Example no. 1)
# 0
def main():
    """Train and evaluate an adaptive-resolution TSN model (AR-Net reproduce).

    Configuration comes from the module-level ``args``. The function wires up
    wandb logging, builds the model / optimizer / data loaders, restores
    weights from one of several checkpoint sources, then runs the
    train-validate loop while tracking the best mAP / mmAP / Prec@1 seen.

    Side effects: mutates the globals declared below, may redirect
    ``sys.stdout`` into a ``Logger``, writes checkpoints and TensorBoard
    summaries under the experiment directory.
    """
    t_start = time.time()
    global args, best_prec1, num_class, use_ada_framework
    wandb.init(
        project="arnet-reproduce",
        name=args.exp_header,
        entity="video_channel"
    )
    wandb.config.update(args)
    set_random_seed(args.random_seed)
    # The adaptive framework is active only when none of the offline-LSTM /
    # SCSampler baseline modes are requested.
    use_ada_framework = (args.ada_reso_skip
                         and not args.offline_lstm_last
                         and not args.offline_lstm_all
                         and not args.real_scsampler)

    if args.ablation:
        logger = None
    else:
        if not test_mode:
            # Tee stdout into the experiment log file.
            logger = Logger()
            sys.stdout = logger
        else:
            logger = None

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                      args.data_dir)

    if args.ada_reso_skip:
        if len(args.ada_crop_list) == 0:
            # Default: one crop per entry of the adaptive resolution list.
            args.ada_crop_list = [1 for _ in args.reso_list]

    if use_ada_framework:
        init_gflops_table()

    model = TSN_Ada(num_class, args.num_segments,
                    base_model=args.arch,
                    consensus_type=args.consensus_type,
                    dropout=args.dropout,
                    partial_bn=not args.no_partialbn,
                    pretrain=args.pretrain,
                    fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                    args=args)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    # something/jester datasets are direction-sensitive, so disable flipping.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    # TODO(yue) freeze some params in the policy + lstm layers
    if args.freeze_policy:
        for name, param in model.module.named_parameters():
            if "lite_fc" in name or "lite_backbone" in name or "rnn" in name or "linear" in name:
                param.requires_grad = False

    if args.freeze_backbone:
        for name, param in model.module.named_parameters():
            if "base_model" in name:
                param.requires_grad = False
    if len(args.frozen_list) > 0:
        # Each keyword may use '*' as a wildcard: "*x*" = substring match,
        # "*x" = suffix, "x*" = prefix, plain "x" = exact parameter name.
        for name, param in model.module.named_parameters():
            for keyword in args.frozen_list:
                if keyword[0] == "*":
                    if keyword[-1] == "*":  # substring match
                        if keyword[1:-1] in name:
                            param.requires_grad = False
                            print(keyword, "->", name, "frozen")
                    else:  # suffix match
                        if name.endswith(keyword[1:]):
                            param.requires_grad = False
                            print(keyword, "->", name, "frozen")
                elif keyword[-1] == "*":  # prefix match
                    if name.startswith(keyword[:-1]):
                        param.requires_grad = False
                        print(keyword, "->", name, "frozen")
                else:  # exact match
                    if name == keyword:
                        param.requires_grad = False
                        print(keyword, "->", name, "frozen")
        # BUGFIX: this requires_grad summary was printed twice (copy-paste
        # duplication); print it once.
        print("=" * 80)
        for name, param in model.module.named_parameters():
            print(param.requires_grad, "\t", name)
        print("=" * 80)

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # BUGFIX: report the checkpoint actually loaded (was args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        # Checkpoints saved with/without TSM's ".net" wrapper need key renaming
        # in both directions before they can be loaded.
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(sd.keys())
        keys2 = set(model_dict.keys())
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    # TODO(yue) ada_model loading process
    if args.ada_reso_skip:
        if test_mode:
            print("Test mode load from pretrained model")
            the_model_path = args.test_from
            if ".pth.tar" not in the_model_path:
                the_model_path = ospj(the_model_path, "models", "ckpt.best.pth.tar")
            model_dict = model.state_dict()
            sd = load_to_sd(model_dict, the_model_path, "foo", "bar", -1, apple_to_apple=True)
            model_dict.update(sd)
            model.load_state_dict(model_dict)
        elif args.base_pretrained_from != "":
            print("Adaptively load from pretrained whole")
            model_dict = model.state_dict()
            sd = load_to_sd(model_dict, args.base_pretrained_from, "foo", "bar", -1, apple_to_apple=True)
            model_dict.update(sd)
            model.load_state_dict(model_dict)
        elif len(args.model_paths) != 0:
            print("Adaptively load from model_path_list")
            model_dict = model.state_dict()
            # Policy net weights go into the lightweight branch.
            sd = load_to_sd(model_dict, args.policy_path, "lite_backbone", "lite_fc",
                            args.reso_list[args.policy_input_offset])
            model_dict.update(sd)
            # One backbone checkpoint per resolution in reso_list.
            for i, tmp_path in enumerate(args.model_paths):
                sd = load_to_sd(model_dict, tmp_path, "base_model_list.%d" % i, "new_fc_list.%d" % i,
                                args.reso_list[i])
                model_dict.update(sd)
            model.load_state_dict(model_dict)
    else:
        if test_mode:
            the_model_path = args.test_from
            if ".pth.tar" not in the_model_path:
                the_model_path = ospj(the_model_path, "models", "ckpt.best.pth.tar")
            model_dict = model.state_dict()
            sd = load_to_sd(model_dict, the_model_path, "foo", "bar", -1, apple_to_apple=True)
            model_dict.update(sd)
            model.load_state_dict(model_dict)

    if not args.ada_reso_skip and args.base_pretrained_from != "":
        print("Baseline: load from pretrained model")
        model_dict = model.state_dict()
        sd = load_to_sd(model_dict, args.base_pretrained_from, "base_model", "new_fc", 224)

        if args.ignore_new_fc_weight:
            print("@ IGNORE NEW FC WEIGHT !!!")
            del sd["module.new_fc.weight"]
            del sd["module.new_fc.bias"]

        model_dict.update(sd)
        model.load_state_dict(model_dict)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)
    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=False),
                       ToTorchFormatTensor(div=True),
                       normalize,
                   ]), dense_sample=args.dense_sample,
                   dataset=args.dataset,
                   partial_fcvid_eval=args.partial_fcvid_eval,
                   partial_ratio=args.partial_ratio,
                   ada_reso_skip=args.ada_reso_skip,
                   reso_list=args.reso_list,
                   random_crop=args.random_crop,
                   center_crop=args.center_crop,
                   ada_crop_list=args.ada_crop_list,
                   rescale_to=args.rescale_to,
                   policy_input_offset=args.policy_input_offset,
                   save_meta=args.save_meta),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=False),
                       ToTorchFormatTensor(div=True),
                       normalize,
                   ]), dense_sample=args.dense_sample,
                   dataset=args.dataset,
                   partial_fcvid_eval=args.partial_fcvid_eval,
                   partial_ratio=args.partial_ratio,
                   ada_reso_skip=args.ada_reso_skip,
                   reso_list=args.reso_list,
                   random_crop=args.random_crop,
                   center_crop=args.center_crop,
                   ada_crop_list=args.ada_crop_list,
                   rescale_to=args.rescale_to,
                   policy_input_offset=args.policy_input_offset,
                   save_meta=args.save_meta
                   ),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    if not test_mode:
        exp_full_path = setup_log_directory(logger, args.log_dir, args.exp_header)
    else:
        exp_full_path = None

    if not args.ablation:
        if not test_mode:
            with open(os.path.join(exp_full_path, 'args.txt'), 'w') as f:
                f.write(str(args))
            tf_writer = SummaryWriter(log_dir=exp_full_path)
        else:
            tf_writer = None
    else:
        tf_writer = None

    # Best-so-far trackers; "usage" strings record adaptive resolution usage
    # at the epoch that achieved the best mmAP.
    map_record = Recorder()
    mmap_record = Recorder()
    prec_record = Recorder()
    best_train_usage_str = None
    best_val_usage_str = None

    wandb.watch(model)
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        if not args.skip_training:
            set_random_seed(args.random_seed + epoch)
            adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
            train_usage_str = train(train_loader, model, criterion, optimizer, epoch, logger, exp_full_path, tf_writer)
        else:
            train_usage_str = "No training usage stats (Eval Mode)"

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            set_random_seed(args.random_seed)
            mAP, mmAP, prec1, val_usage_str, val_gflops = validate(val_loader, model, criterion, epoch, logger,
                                                                   exp_full_path, tf_writer)

            # remember best prec@1 and save checkpoint
            map_record.update(mAP)
            mmap_record.update(mmAP)
            prec_record.update(prec1)

            if mmap_record.is_current_best():
                best_train_usage_str = train_usage_str
                best_val_usage_str = val_usage_str

            print('Best mAP: %.3f (epoch=%d)\t\tBest mmAP: %.3f(epoch=%d)\t\tBest Prec@1: %.3f (epoch=%d)' % (
                map_record.best_val, map_record.best_at,
                mmap_record.best_val, mmap_record.best_at,
                prec_record.best_val, prec_record.best_at))

            if args.skip_training:
                break

            if (not args.ablation) and (not test_mode):
                tf_writer.add_scalar('acc/test_top1_best', prec_record.best_val, epoch)
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': prec_record.best_val,
                }, mmap_record.is_current_best(), exp_full_path)

    if use_ada_framework and not test_mode:
        print("Best train usage:")
        print(best_train_usage_str)
        print()
        print("Best val usage:")
        print(best_val_usage_str)

    print("Finished in %.4f seconds\n" % (time.time() - t_start))
# Esempio n. 2 (Example no. 2)
# 0
def main():
    """Evaluate a TSN/TSM model on the test (or validation) split.

    Parses CLI arguments, builds the network, optionally restores a
    checkpoint from ``args.resume_path``, and runs ``test()`` over the
    chosen split. Mutates the globals declared below.
    """
    global args, best_prec1, TRAIN_SAMPLES
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.test_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                      args.modality)
    # Prefer the dedicated test split when one exists on disk.
    if os.path.exists(args.test_list):
        args.val_list = args.test_list

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                tin=args.tin)

    # ImageNet normalization statistics.
    input_mean = [0.485, 0.456, 0.406]
    input_std = [0.229, 0.224, 0.225]

    print(args.gpus)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if os.path.isfile(args.resume_path):
        print(("=> loading checkpoint '{}'".format(args.resume_path)))
        checkpoint = torch.load(args.resume_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        # BUGFIX: report the checkpoint actually loaded (was args.evaluate).
        print(("=> loaded checkpoint '{}' (epoch {})"
               .format(args.resume_path, checkpoint['epoch'])))
    else:
        print(("=> no checkpoint found at '{}'".format(args.resume_path)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        # RGBDiff is normalized implicitly by the frame-difference operation.
        normalize = IdentityTransform()

    # Frames fed per segment: 1 RGB image, or a 5-frame stack for Flow/RGBDiff.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    if args.random_crops == 1:
        crop_aug = GroupCenterCrop(args.crop_size)
    elif args.random_crops == 3:
        crop_aug = GroupFullResSample(args.crop_size, args.scale_size, flip=False)
    elif args.random_crops == 5:
        crop_aug = GroupOverSample(args.crop_size, args.scale_size, flip=False)
    else:
        # BUGFIX: a stray trailing comma made crop_aug a 1-tuple here, which
        # breaks torchvision.transforms.Compose below.
        crop_aug = MultiGroupRandomCrop(args.crop_size, args.random_crops)

    test_dataset = TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   multi_class=args.multi_class,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(args.scale_size)),
                       crop_aug,
                       # BNInception/InceptionV3 expect BGR uint8 input;
                       # other backbones take normalized RGB floats.
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample,
                   test_mode=True,
                   temporal_clips=args.temporal_clips)

    test_loader = torch.utils.data.DataLoader(
            test_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    test(test_loader, model, args.start_epoch)
# Esempio n. 3 (Example no. 3)
# 0
def main():
    """Train a TSM model end to end.

    Parses CLI arguments, derives a storage name, builds the model,
    optimizer and loaders, optionally resumes/fine-tunes from checkpoints,
    then runs the train/validate loop, checkpointing on the best Prec@1.
    Mutates the globals declared below and writes logs/TensorBoard files
    under ``args.root_log/args.store_name``.
    """
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    # Build a descriptive experiment name from the active options.
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    # something/jester datasets are direction-sensitive, so disable flipping.
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    print("使用的GPU是:", args.gpus)
    # NOTE: DataParallel wrapping is intentionally disabled here; the model
    # runs on a single GPU. Code below must therefore not assume `.module`.
    model = model.cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            # BUGFIX: the model is not wrapped in DataParallel above, so
            # `model.module` raised AttributeError; handle both cases.
            make_temporal_pool(getattr(model, 'module', model).base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # BUGFIX: report the checkpoint actually loaded (was args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        model_dict = model.state_dict()

        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        # Only keep keys the current model actually has.
        pretrained_dict = {k: v for k, v in sd.items() if k in model_dict}

        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        # BUGFIX: same `.module` issue as above — model may be unwrapped.
        make_temporal_pool(getattr(model, 'module', model).base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        # RGBDiff is normalized implicitly by the frame-difference operation.
        normalize = IdentityTransform()

    # Frames fed per segment: 1 RGB image, or a 5-frame stack for Flow/RGBDiff.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            # BNInception/InceptionV3 expect BGR uint8 input; other
            # backbones take normalized RGB floats.
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        print("start validate....")
        # HACK: hard-coded absolute Windows path — make this a CLI argument.
        pretrained_dict = torch.load(
            'E://GSM//hmdb//hmdb-5-rgb-new1.pth')
        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
        start_time = time.time()
        validate(val_loader, model, criterion, 0)
        end_time = time.time()
        print("running time is:", end_time - start_time)
        print("end valitade....")
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
# Esempio n. 4 (Example no. 4)
# 0
def main():
    print(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    global args, best_prec1
    args = parser.parse_args()

    ##asset check ####
    if args.use_finetuning:
        assert args.finetune_start_epoch > args.sup_thresh

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TCL',
        datetime.datetime.now().strftime("%Y%m%d-%H%M%S"), args.dataset,
        full_arch_name,
        'p%.2f' % args.percentage,
        'th%.2f' % args.threshold,
        'gamma%0.2f' % args.gamma,
        'mu%0.2f' % args.mu,
        'seed%d' % args.seed,
        'seg%d' % args.num_segments,
        'bs%d' % args.batch_size, 'e{}'.format(args.epochs)
    ])
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    args.labeled_train_list, args.unlabeled_train_list = get_training_filenames(
        args.train_list)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                second_segments=args.second_segments,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)
    print("==============model desccription=============")
    print(model)
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=args.flip)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    labeled_trainloader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.labeled_train_list,
            unlabeled=False,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            second_segments=args.second_segments,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False)  # prevent something not % n_GPU

    unlabeled_trainloader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.unlabeled_train_list,
            unlabeled=True,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            second_segments=args.second_segments,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            dense_sample=args.dense_sample),
        batch_size=np.int(np.round(args.mu * args.batch_size)),
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        unlabeled=False,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        second_segments=args.second_segments,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample),
                                             batch_size=args.valbatchsize,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    default_start = 0
    is_finetune_lr_set = False
    for epoch in range(args.start_epoch, args.epochs):
        if args.use_finetuning and epoch >= args.finetune_start_epoch:
            args.eval_freq = args.finetune_stage_eval_freq
        if args.use_finetuning and epoch >= args.finetune_start_epoch and args.finetune_lr > 0.0 and not is_finetune_lr_set:
            args.lr = args.finetune_lr
            default_start = args.finetune_start_epoch
            is_finetune_lr_set = True
        adjust_learning_rate(optimizer,
                             epoch,
                             args.lr_type,
                             args.lr_steps,
                             default_start,
                             using_policy=True)

        # train for one epoch
        train(labeled_trainloader, unlabeled_trainloader, model, criterion,
              optimizer, epoch, log_training)

        # evaluate on validation set
        if ((epoch + 1) % args.eval_freq == 0 or epoch == (args.epochs - 1)
                or (epoch + 1) == args.finetune_start_epoch):
            prec1 = validate(val_loader, model, criterion, epoch, log_training)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()
            if args.use_finetuning and (epoch +
                                        1) == args.finetune_start_epoch:
                one_stage_pl = True
            else:
                one_stage_pl = False
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, one_stage_pl)
Esempio n. 5
0
def main():
    """Load a TSN/TSM checkpoint and run CAM processing over the validation set.

    Builds the model from command-line args, optionally restores weights from
    ``--resume`` (stripping the ``module.`` prefix that DataParallel training
    adds to state-dict keys), constructs the center-crop validation loader and
    hands everything to :func:`cam_process`.
    """
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                      args.modality)
    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                cca3d=args.cca3d
                )

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std

    # Single-GPU model: NOT wrapped in DataParallel, so weights/attributes live
    # directly on `model` (no `.module` indirection).
    model = model.cuda()

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            # Bug fix: the model is not a DataParallel wrapper here, so the
            # original `model.module.base_model` would raise AttributeError.
            make_temporal_pool(model.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            parallel_state_dict = checkpoint['state_dict']
            # Checkpoints were saved from a DataParallel model; strip the
            # leading 'module.' prefix so the keys match the bare model.
            cpu_state_dict = {}
            for k, v in parallel_state_dict.items():
                cpu_state_dict[k[len('module.'):]] = v
            model.load_state_dict(cpu_state_dict)
            # Bug fix: report the checkpoint that was actually loaded
            # (previously printed args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Frames stacked per sample: 1 RGB frame, or 5 for Flow/RGBDiff.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Previously data_length was silently left undefined here, causing a
        # confusing NameError later; fail fast instead.
        raise ValueError("Unknown modality: {}".format(args.modality))

    # Deterministic evaluation pipeline: scale, center crop, tensorize, normalize.
    preprocess = torchvision.transforms.Compose([
        GroupScale(int(scale_size)),
        GroupCenterCrop(crop_size),
        Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
        normalize,
    ])
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,  # fixed sampling for reproducible CAMs
                   transform=preprocess,
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)
    norm_param = (input_mean, input_std)
    cam_process(val_loader, model, norm_param)
def main():
    """Ensemble-test one or more TSM/TSN checkpoints on the full validation set.

    Parses its own command-line options, builds one network and one data
    loader per comma-separated entry in ``--weights``, averages the
    (optionally coefficient-weighted) softmax/logit outputs across models,
    reports moving Prec@1/Prec@5 plus per-class accuracy, and optionally
    writes a submission csv.
    """
        

    # options
    parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
    parser.add_argument('dataset', type=str)

    # may contain splits
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--test_segments', type=str, default=25)
    parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
    parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
    parser.add_argument('--full_res', default=False, action="store_true",
                        help='use full resolution 256x256 for test as in Non-local I3D')

    parser.add_argument('--test_crops', type=int, default=1)
    parser.add_argument('--coeff', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')

    # for true test
    parser.add_argument('--test_list', type=str, default=None)
    parser.add_argument('--csv_file', type=str, default=None)

    parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')

    parser.add_argument('--max_num', type=int, default=-1)
    parser.add_argument('--input_size', type=int, default=224)
    parser.add_argument('--crop_fusion_type', type=str, default='avg')
    parser.add_argument('--gpus', nargs='+', type=int, default=None)
    parser.add_argument('--img_feature_dim',type=int, default=256)
    parser.add_argument('--num_set_segments',type=int, default=1,help='TODO: select multiply set of n-frames from a video')
    parser.add_argument('--pretrain', type=str, default='imagenet')

    args = parser.parse_args()




    def accuracy(output, target, topk=(1,)):
        """Computes the precision@k for the specified values of k"""
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


    def parse_shift_option_from_log_name(log_name):
        # Recover (is_shift, shift_div, shift_place) from a checkpoint file
        # name such as '..._shift8_blockres_...'. Returns (False, None, None)
        # when the name carries no 'shift' tag.
        if 'shift' in log_name:
            strings = log_name.split('_')
            for i, s in enumerate(strings):
                if 'shift' in s:
                    break
            return True, int(strings[i].replace('shift', '')), strings[i + 1]
        else:
            return False, None, None


    # Parallel per-model lists: weights, segment counts, test-file overrides,
    # ensemble coefficients (default: equal weight 1 per model).
    weights_list = args.weights.split(',')
    test_segments_list = [int(s) for s in args.test_segments.split(',')]
    assert len(weights_list) == len(test_segments_list)
    if args.coeff is None:
        coeff_list = [1] * len(weights_list)
    else:
        coeff_list = [float(c) for c in args.coeff.split(',')]

    if args.test_list is not None:
        test_file_list = args.test_list.split(',')
    else:
        test_file_list = [None] * len(weights_list)


    data_iter_list = []
    net_list = []
    modality_list = []

    total_num = None
    # Build one network + evaluation loader per checkpoint in the ensemble.
    # Model hyper-parameters (shift config, modality, arch) are recovered from
    # the checkpoint *file name*, which must follow the
    # 'TSM_<dataset>_<modality>_<arch>_...' naming convention — TODO confirm.
    for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
        is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
        if 'RGB' in this_weights:
            modality = 'RGB'
        else:
            modality = 'Flow'
        this_arch = this_weights.split('TSM_')[1].split('_')[2]
        modality_list.append(modality)
        num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                modality)
        print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
        net = TSN(num_class, this_test_segments if is_shift else 1, modality,
                base_model=this_arch,
                consensus_type=args.crop_fusion_type,
                img_feature_dim=args.img_feature_dim,
                pretrain=args.pretrain,
                is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
                non_local='_nl' in this_weights,
                )

        if 'tpool' in this_weights:
            from ops.temporal_shift import make_temporal_pool
            make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel

        checkpoint = torch.load(this_weights)
        checkpoint = checkpoint['state_dict']

        # Strip the leading 'module.' (DataParallel) prefix from every key,
        # then remap legacy classifier names onto the new_fc head.
        # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias',
                        }
        for k, v in replace_dict.items():
            if k in base_dict:
                base_dict[v] = base_dict.pop(k)

        net.load_state_dict(base_dict)

        # Cropping strategy: 1 = center crop, 3/5 = multi-crop without flip,
        # 10 = 5 crops + horizontal flips.
        input_size = net.scale_size if args.full_res else net.input_size
        if args.test_crops == 1:
            cropping = torchvision.transforms.Compose([
                GroupScale(net.scale_size),
                GroupCenterCrop(input_size),
            ])
        elif args.test_crops == 3:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupFullResSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 5:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 10:
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size)
            ])
        else:
            # NOTE(review): message omits the supported value 3 handled above.
            raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops))

        data_loader = torch.utils.data.DataLoader(
                TSNDataSet(root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments,
                        new_length=1 if modality == "RGB" else 5,
                        modality=modality,
                        image_tmpl=prefix,
                        test_mode=True,
                        remove_missing=len(weights_list) == 1,
                        transform=torchvision.transforms.Compose([
                            cropping,
                            Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                            ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                            GroupNormalize(net.input_mean, net.input_std),
                        ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True,
        )

        # NOTE(review): `devices` is built from args.workers (loader workers),
        # not the GPU count, and is never passed to DataParallel below —
        # looks vestigial; verify before relying on --gpus.
        if args.gpus is not None:
            devices = [args.gpus[i] for i in range(args.workers)]
        else:
            devices = list(range(args.workers))

        net = torch.nn.DataParallel(net.cuda())
        net.eval()

        data_gen = enumerate(data_loader)

        # All per-model loaders must iterate the same number of videos so they
        # can be zipped together below.
        if total_num is None:
            total_num = len(data_loader.dataset)
        else:
            assert total_num == len(data_loader.dataset)

        data_iter_list.append(data_gen)
        net_list.append(net)


    output = []


    def eval_video(video_data, net, this_test_segments, modality):
        """Run one batch through `net`, average over crops/clips, and return
        (batch_index, per-class scores as a numpy array, labels)."""
        net.eval()
        with torch.no_grad():
            i, data, label = video_data
            batch_size = label.numel()
            num_crop = args.test_crops
            if args.dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample

            if args.twice_sample:
                num_crop *= 2

            # Channels per frame sample for each modality.
            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality "+ modality)

            data_in = data.view(-1, length, data.size(2), data.size(3))
            # NOTE(review): `is_shift` here is captured from the enclosing
            # model-building loop (its last iteration), not derived from the
            # `net` argument — fine for homogeneous ensembles, suspect
            # otherwise; confirm.
            if is_shift:
                data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3))
            rst = net(data_in)
            rst = rst.reshape(batch_size, num_crop, -1).mean(1)

            if args.softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)

            rst = rst.data.cpu().numpy().copy()

            if net.module.is_shift:
                rst = rst.reshape(batch_size, num_class)
            else:
                rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

            return i, rst, label


    proc_start_time = time.time()
    max_num = args.max_num if args.max_num > 0 else total_num

    top1 = AverageMeter()
    top5 = AverageMeter()

    # Iterate all per-model loaders in lockstep; each model scores the batch,
    # scores are coefficient-weighted and averaged into the ensemble output.
    for i, data_label_pairs in enumerate(zip(*data_iter_list)):
        with torch.no_grad():
            if i >= max_num:
                break
            this_rst_list = []
            this_label = None
            for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list):
                rst = eval_video((i, data, label), net, n_seg, modality)
                this_rst_list.append(rst[1])
                this_label = label
            assert len(this_rst_list) == len(coeff_list)
            for i_coeff in range(len(this_rst_list)):
                this_rst_list[i_coeff] *= coeff_list[i_coeff]
            ensembled_predict = sum(this_rst_list) / len(this_rst_list)

            for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
                output.append([p[None, ...], g])
            cnt_time = time.time() - proc_start_time
            prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
            top1.update(prec1.item(), this_label.numel())
            top5.update(prec5.item(), this_label.numel())
            if i % 20 == 0:
                print('video {} done, total {}/{}, average {:.3f} sec/video, '
                    'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num,
                                                                float(cnt_time) / (i+1) / args.batch_size, top1.avg, top5.avg))

    # Final predictions: argmax for top-1, sorted indices for top-5.
    video_pred = [np.argmax(x[0]) for x in output]
    video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]

    video_labels = [x[1] for x in output]


    # Optional csv export for challenge submission.
    # NOTE(review): assumes --test_list was given (test_file_list[0] is None
    # otherwise) and that it sits next to a 'category.txt' — confirm.
    if args.csv_file is not None:
        print('=> Writing result to csv file: {}'.format(args.csv_file))
        with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f:
            categories = f.readlines()
        categories = [f.strip() for f in categories]
        with open(test_file_list[0]) as f:
            vid_names = f.readlines()
        vid_names = [n.split(' ')[0] for n in vid_names]
        print(vid_names)
        assert len(vid_names) == len(video_pred)
        if args.dataset != 'somethingv2':  # only output top1
            with open(args.csv_file, 'w') as f:
                for n, pred in zip(vid_names, video_pred):
                    f.write('{};{}\n'.format(n, categories[pred]))
        else:
            with open(args.csv_file, 'w') as f:
                for n, pred5 in zip(vid_names, video_pred_top5):
                    fill = [n]
                    for p in list(pred5):
                        fill.append(p)
                    f.write('{};{};{};{};{};{}\n'.format(*fill))


    # Confusion-matrix-based per-class accuracy and its upper bound.
    cf = confusion_matrix(video_labels, video_pred).astype(float)

    np.save('cm.npy', cf)
    cls_cnt = cf.sum(axis=1)
    cls_hit = np.diag(cf)

    cls_acc = cls_hit / cls_cnt
    print(cls_acc)
    upper = np.mean(np.max(cf, axis=1) / cls_cnt)
    print('upper bound: {}'.format(upper))

    print('-----Evaluation is finished------')
    print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
    print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))
    def __init__(self,
                 checkpoint_file,
                 num_classes,
                 max_length=8,
                 trim_net=False,
                 checkpoint_is_model=False,
                 bottleneck_size=128):
        """Build a TSM backbone (ResNet-101, Something-Something v2 config)
        from a checkpoint and prepare its evaluation transform.

        checkpoint_file -- path to a torch checkpoint; either a dict with a
            'state_dict' entry or, when ``checkpoint_is_model`` is True, a
            saved model object whose ``.net`` attribute holds the weights.
        num_classes -- stored on the instance; the network's class count is
            taken from the dataset config, not from this argument.
        max_length -- number of temporal segments fed to the TSN.
        trim_net -- when False, a CrossEntropyLoss is created for ranking
            gradients; the rest of its effect is in commented-out code.
        bottleneck_size -- stored; used only by the commented-out bottleneck.
        """
        self.is_shift = None
        self.net = None
        self.arch = None
        self.num_classes = num_classes
        self.max_length = max_length
        self.bottleneck_size = bottleneck_size
        #self.feature_idx = feature_idx

        self.transform = None

        # Feature widths of ResNet stages (conv2_x .. conv5_x outputs).
        self.CNN_FEATURE_COUNT = [256, 512, 1024, 2048]

        # input variables
        this_test_segments = self.max_length
        test_file = None

        # model variables: temporal shift enabled, 1/8 of channels shifted,
        # residual-block placement (hard-coded, not read from the checkpoint).
        self.is_shift, shift_div, shift_place = True, 8, 'blockres'

        self.arch = 'resnet101'
        modality = 'RGB'

        # dataset variables
        num_class, train_list, val_list, root_path, prefix = dataset_config.return_dataset(
            'somethingv2', modality)
        print('=> shift: {}, shift_div: {}, shift_place: {}'.format(
            self.is_shift, shift_div, shift_place))

        # define model
        net = TSN(
            num_class,
            this_test_segments if self.is_shift else 1,
            modality,
            base_model=self.arch,
            consensus_type='avg',
            img_feature_dim=256,
            pretrain='imagenet',
            is_shift=self.is_shift,
            shift_div=shift_div,
            shift_place=shift_place,
            non_local='_nl' in checkpoint_file,
        )
        '''
        The checkpoint file appears to be an entire TSMBackBone Object. this needs to be
        handled acordingly. Either find a way to convert it back to a weights file or maniuplate it 
        to work with the system.
        '''

        # load checkpoint file
        checkpoint = torch.load(checkpoint_file)
        '''
        #include
        print("self.bottleneck_size:", self.bottleneck_size, type(self.bottleneck_size))
        net.base_model.avgpool = nn.Sequential(
            nn.Conv2d(2048, self.bottleneck_size, (1,1)),
            nn.ReLU(inplace=True),
            #nn.AdaptiveAvgPool2d(output_size=1)
        )

        if(not trim_net):
            print("no trim")
            net.new_fc = nn.Linear(self.bottleneck_size, 174)
        else:
            print("trim")
            net.consensus = nn.Identity()
            net.new_fc = nn.Identity()

        net.base_model.fc = nn.Identity() # sets the dropout value to None
        print(net) 
        
        # Combine network together so that the it can have parameters set correctly
        # I think, I'm not 100% what this code section actually does and I don't have 
        # the time to figure it out right now
        #print("checkpoint------------------------")
        #print(checkpoint)
		'''
        # Extract a plain state dict from either checkpoint flavor.
        if (checkpoint_is_model):
            checkpoint = checkpoint.net.state_dict()
        else:
            checkpoint = checkpoint['state_dict']

        # Strip the leading 'module.' (DataParallel) prefix from every key.
        base_dict = {
            '.'.join(k.split('.')[1:]): v
            for k, v in list(checkpoint.items())
        }
        '''
        #include
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias',
                        }
        for k, v in replace_dict.items():
            if v in base_dict:
                base_dict.pop(v)
            if k in base_dict:
                base_dict.pop(k)
                #base_dict[v] = base_dict.pop(k)
		'''
        # strict=False tolerates keys that are missing or unexpected.
        net.load_state_dict(base_dict, strict=False)

        # define image modifications
        # NOTE(review): center crop uses net.scale_size (256), not
        # net.input_size (224) — presumably deliberate full-res eval; confirm.
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Compose([
                GroupScale(net.scale_size),
                GroupCenterCrop(net.scale_size),
            ]),
            #torchvision.transforms.Compose([ GroupFullResSample(net.scale_size, net.scale_size, flip=False) ]),
            Stack(roll=(self.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(self.arch not in ['BNInception', 'InceptionV3'])),
            GroupNormalize(net.input_mean, net.input_std),
        ])

        # place net onto GPU and finalize network
        # `self.model` keeps the unwrapped module; `self.net` is the
        # DataParallel wrapper set to eval mode.
        self.model = net
        net = torch.nn.DataParallel(net.cuda())
        net.eval()

        # network variable
        self.net = net

        # loss variable (used for generating gradients when ranking)
        if (not trim_net):
            self.loss = torch.nn.CrossEntropyLoss().cuda()
Esempio n. 8
0
def main():
    global args, best_prec1
    global crop_size
    args = parser.parse_args()

    num_class, train_list, val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    num_class = 1
    if args.train_list == "":
        args.train_list = train_list
    if args.val_list == "":
        args.val_list = val_list

    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.concat != "":
        full_arch_name += '_concat{}'.format(args.concat)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name += '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'lr%.5f' % args.lr,
        'dropout%.2f' % args.dropout,
        'wd%.5f' % args.weight_decay,
        'batch%d' % args.batch_size,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.data_fuse:
        args.store_name += '_fuse'
    if args.extra_temporal_modeling:
        args.store_name += '_extra'
    if args.tune_from is not None:
        args.store_name += '_finetune'
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.clipnums:
        args.store_name += "_clip{}".format(args.clipnums)
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    if args.prune in ['input', 'inout'] and args.tune_from:
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        sd = input_dim_L2distance(sd, args.shift_div)

    model = TSN(
        num_class,
        args.num_segments,
        args.modality,
        base_model=args.arch,
        new_length=2 if args.data_fuse else None,
        consensus_type=args.consensus_type,
        dropout=args.dropout,
        img_feature_dim=args.img_feature_dim,
        partial_bn=not args.no_partialbn,
        pretrain=args.pretrain,
        is_shift=args.shift,
        shift_div=args.shift_div,
        shift_place=args.shift_place,
        fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
        temporal_pool=args.temporal_pool,
        non_local=args.non_local,
        concat=args.concat,
        extra_temporal_modeling=args.extra_temporal_modeling,
        prune_list=[prune_conv1in_list, prune_conv1out_list],
        is_prune=args.prune,
    )

    print(model)
    #summary(model, torch.zeros((16, 24, 224, 224)))
    #exit(1)
    if args.dataset == 'ucf101':  #twice sample & full resolution
        twice_sample = True
        crop_size = model.scale_size  #256 x 256
    else:
        twice_sample = False
        crop_size = model.crop_size  #224 x 224
    crop_size = 256
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies(args.concat)
    #print(type(policies))
    #print(policies)
    #exit()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        or 'nvgesture' in args.dataset else True)

    model = torch.nn.DataParallel(model).cuda()

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        tune_from_list = args.tune_from.split(',')
        sd = torch.load(tune_from_list[0])
        sd = sd['state_dict']

        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.prune', '') in sd:
                print('=> Load after adding .prune: ', k)
                replace_dict.append((k.replace('.prune', ''), k))

        if args.prune in ['input', 'inout']:
            sd = adjust_para_shape_prunein(sd, model_dict)
        if args.prune in ['output', 'inout']:
            sd = adjust_para_shape_pruneout(sd, model_dict)

        if args.concat != "" and "concat" not in tune_from_list[0]:
            sd = adjust_para_shape_concat(sd, model_dict)

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in tune_from_list[0]:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality != 'Flow' and 'Flow' in tune_from_list[0]:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        #print(sd.keys())
        #print("*"*50)
        #print(model_dict.keys())
        for k, v in list(sd.items()):
            if k not in model_dict:
                sd.pop(k)
        sd.pop("module.base_model.embedding.weight")

        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    decoder = TransformerModel().cuda()
    if args.decoder_resume:
        decoder_chkpoint = torch.load(args.decoder_resume)

        decoder.load_state_dict(decoder_chkpoint["state_dict"])
    print("decoder parameters = ", decoder.parameters())
    policies.append({
        "params": decoder.parameters(),
        "lr_mult": 5,
        "decay_mult": 1,
        "name": "Attndecoder_weight"
    })
    cudnn.benchmark = True
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality in ['RGB']:
        data_length = 1
    elif args.modality in ['Depth']:
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    '''
	dataRoot = r"/home/share/YouCook/downloadVideo"
	for dirPath, dirnames, filenames in os.walk(dataRoot):
		for filename in filenames:
			print(os.path.join(dirPath,filename) +"is {}".format("exist" if os.path.isfile(os.path.join(dirPath,filename))else "NON"))
			train_data = torchvision.io.read_video(os.path.join(dirPath,filename),start_pts=0,end_pts=1001, )
			tmp = torchvision.io.read_video_timestamps(os.path.join(dirPath,filename),)
			print(tmp)
			print(len(tmp[0]))
			print(train_data[0].size())
			exit()
	exit()
	'''
    '''
	train_loader = torch.utils.data.DataLoader(
		TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
				   new_length=data_length,
				   modality=args.modality,
				   image_tmpl=prefix,
				   transform=torchvision.transforms.Compose([
					   train_augmentation,
					   Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
					   ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
					   normalize,
				   ]), dense_sample=args.dense_sample, data_fuse = args.data_fuse),
		batch_size=args.batch_size, shuffle=True,
		num_workers=args.workers, pin_memory=True,
		drop_last=True)  # prevent something not % n_GPU
	

	
	val_loader = torch.utils.data.DataLoader(
		TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
				   new_length=data_length,
				   modality=args.modality,
				   image_tmpl=prefix,
				   random_shift=False,
				   transform=torchvision.transforms.Compose([
					   GroupScale(int(scale_size)),
					   GroupCenterCrop(crop_size),
					   Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
					   ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
					   normalize,
				   ]), dense_sample=args.dense_sample, twice_sample=twice_sample, data_fuse = args.data_fuse),
		batch_size=args.batch_size, shuffle=False,
		num_workers=args.workers, pin_memory=True)
	'''
    #global trainDataloader, valDataloader, train_loader, val_loader
    trainDataloader = YouCookDataSetRcg(args.root_path, args.train_list,train=True,inputsize=crop_size,hasPreprocess = False,\
      clipnums=args.clipnums,
      hasWordIndex = True,)
    valDataloader = YouCookDataSetRcg(args.root_path, args.val_list,val=True,inputsize=crop_size,hasPreprocess = False,\
			#clipnums=args.clipnums,

      hasWordIndex = True,)

    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #exit()
    train_loader = torch.utils.data.DataLoader(trainDataloader,
                                               #shuffle=True,
                                               )
    val_loader = torch.utils.data.DataLoader(valDataloader)
    index2wordDict = trainDataloader.getIndex2wordDict()
    #print(train_loader is val_loader)
    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())

    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #print(len(train_loader))

    #exit()
    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.NLLLoss().cuda()
    elif args.loss_type == "MSELoss":
        criterion = torch.nn.MSELoss().cuda()
    elif args.loss_type == "BCELoss":
        #print("BCELoss")
        criterion = torch.nn.BCELoss().cuda()
    elif args.loss_type == "CrossEntropyLoss":
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    #print(os.path.join(args.root_log, args.store_name, 'args.txt'))
    #exit()
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
        #print("265")
        # train for one epoch
        ######
        #print(trainDataloader._getMode())
        #print(valDataloader._getMode())
        train(train_loader, model, decoder, criterion, optimizer, epoch,
              log_training, tf_writer, index2wordDict)
        ######
        #print("268")
        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader,
                             model,
                             decoder,
                             criterion,
                             epoch,
                             log_training,
                             tf_writer,
                             index2wordDict=index2wordDict)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            #print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': decoder.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename="decoder")
        else:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, False)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': decoder.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename="decoder")
        #break
        print("test pass")
# Esempio n. 9 ("Example no. 9" — snippet-aggregator separator; vote count: 0)
def main():
    """Entry point: build a TSN model, optionally resume or fine-tune from a
    checkpoint, set up the train/val data pipelines, and run the training
    loop, checkpointing whenever validation top-1 accuracy improves.

    Configuration comes from the module-level ``parser``; mutates the
    globals ``args`` and ``best_prec1``.
    """
    torch.cuda.empty_cache()
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    # store_name encodes the key hyper-parameters so the log/checkpoint
    # directories are self-describing.
    args.store_name = '_'.join([
        args.experiment_name, args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(
        num_class,
        args.num_segments,
        args.modality,
        base_model=args.arch,
        consensus_type=args.consensus_type,
        dropout=args.dropout,
        img_feature_dim=args.img_feature_dim,
        partial_bn=not args.no_partialbn,
        pretrain=args.pretrain,
        is_shift=args.shift,
        shift_div=args.shift_div,
        shift_place=args.shift_place,
        fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
    )

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    # Horizontal flip is wrong for direction-sensitive datasets
    # (Something-Something), so disable it there.
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # Add specific initialized lr and weight_decay for each group
    for param_group in policies:
        param_group['lr'] = args.lr * param_group['lr_mult']
        param_group[
            'weight_decay'] = args.weight_decay * param_group['decay_mult']

    # lr/weight_decay are carried per-group (set above), so only momentum
    # is passed globally here.
    optimizer = torch.optim.SGD(policies, momentum=args.momentum)
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # FIX: report the checkpoint actually loaded (was args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile '.net' key-name differences between the checkpoint and
        # the current model (TSM wraps some sub-modules in '.net').
        # FIX: removed leftover `import pdb; pdb.set_trace()` breakpoints
        # that hung any non-interactive run hitting a key mismatch.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            # RGB conv1 weights are incompatible with Flow input channels.
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    cudnn.benchmark = True

    # Data loading code
    if not args.modality == 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        # RGBDiff computes frame differences on the fly; no normalization.
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        # FIX: 'Flow' previously fell through and left data_length unbound.
        # Flow/RGBDiff stack 5 consecutive frames per segment.
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            multi_clip_test=True,
            dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        multi_clip_test=False,
        dense_sample=args.dense_sample),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    # Separate log files when resuming so the original run's log survives.
    if args.resume:
        log_training = open(
            os.path.join(args.root_log, args.store_name, 'log_resume.csv'),
            'w')
        with open(
                os.path.join(args.root_log, args.store_name,
                             'args_resume.txt'), 'w') as f:
            f.write(str(args))
    else:
        log_training = open(
            os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
        with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
                  'w') as f:
            f.write(str(args))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training)
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
# Esempio n. 10 ("Example no. 10" — snippet-aggregator separator; vote count: 0)
def main():
    """Entry point for the TSA training/evaluation pipeline.

    Builds a TSN/TSA model, optionally resumes or fine-tunes from a
    checkpoint, constructs the data loaders (regenerated every epoch when
    ``--shuffle`` is set, via ``gen_label``), then trains and validates,
    checkpointing on the lowest validation loss.

    Configuration comes from the module-level ``parser``; mutates the
    globals ``args``, ``best_prec1`` and ``least_loss``.
    """
    global args, best_prec1, least_loss
    least_loss = 1000
    args = parser.parse_args()
    # Start each run with a fresh error.log; filemode 'a' then only
    # appends within this run.
    if os.path.exists(os.path.join(args.root_log, "error.log")):
        os.remove(os.path.join(args.root_log, "error.log"))
    logging.basicConfig(
        level=logging.DEBUG,
        filename=os.path.join(args.root_log, "error.log"),
        filemode='a',
        format=
        '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )

    if args.root_path:
        # KF1 annotation files live under root_log; override the dataset
        # lists returned by dataset_config.
        num_class, args.train_list, args.val_list, _, prefix = dataset_config.return_dataset(
            args.dataset, args.modality)
        args.train_list = os.path.join(args.root_log,
                                       "kf1_train_anno_lijun_iod.json")
        args.test_list = os.path.join(args.root_log,
                                      "kf1_test_anno_lijun_iod.json")
    else:
        num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
            args.dataset, args.modality)

    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    # store_name encodes the key hyper-parameters so the log/checkpoint
    # directories are self-describing.
    args.store_name = '_'.join([
        'TSA', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                is_TSA=args.tsa,
                is_sTSA=args.stsa,
                is_tTSA=args.ttsa,
                shift_diff=args.shift_diff,
                shift_groups=args.shift_groups,
                is_ME=args.me,
                is_3D=args.is_3D,
                cfg_file=args.cfg_file)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    # get_optim_policies() is intentionally not used in this variant; all
    # optimizers below operate on model.parameters() directly.
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    if args.optimizer == "sgd":
        if args.lr_scheduler:
            optimizer = torch.optim.SGD(model.parameters(),
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        else:
            # FIX: this branch referenced `policies`, which is never defined
            # (get_optim_policies() is disabled above) and crashed with a
            # NameError; optimize all parameters like the sibling branch.
            optimizer = torch.optim.SGD(model.parameters(),
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.optimizer == "adam":
        params = get_vmz_fine_tuning_parameters(model,
                                                args.vmz_tune_last_k_layer)
        optimizer = torch.optim.Adam(params,
                                     args.lr,
                                     weight_decay=args.weight_decay)
    else:
        raise RuntimeError("not supported optimizer")

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.lr_steps, args.lr_scheduler_gamma)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))

            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # if args.lr_scheduler:
            #     scheduler.load_state_dict(checkpoint["lr_scheduler"])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
            logging.info(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))
            logging.error(
                ("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile '.net' key-name differences between the checkpoint and
        # the current model (TSM wraps some sub-modules in '.net').
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        # Keep only keys the current model actually has.
        sd = {k: v for k, v in sd.items() if k in keys2}
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {
                k: v
                for k, v in sd.items()
                if 'fc' not in k and "projection" not in k
            }
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            # RGB conv1 weights are incompatible with Flow input channels.
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        # RGBDiff computes frame differences on the fly; no normalization.
        normalize = IdentityTransform()

    if args.modality in ['RGB', "PoseAction"]:
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        # Flow/RGBDiff stack 5 consecutive frames per segment.
        data_length = 5
    # With --shuffle, loaders are instead regenerated inside the epoch loop
    # from freshly sampled labels (see gen_label below).
    if not args.shuffle:
        train_loader = torch.utils.data.DataLoader(
            TSNDataSet(args.root_path,
                       args.train_list,
                       num_segments=args.num_segments,
                       new_length=data_length,
                       modality=args.modality,
                       image_tmpl=prefix,
                       transform=torchvision.transforms.Compose([
                           train_augmentation,
                           Stack(roll=(args.arch
                                       in ['BNInception', 'InceptionV3']),
                                 inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                           ToTorchFormatTensor(
                               div=(args.arch
                                    not in ['BNInception', 'InceptionV3']),
                               inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                           normalize,
                       ]),
                       dense_sample=args.dense_sample,
                       all_sample=args.all_sample),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)  # prevent something not % n_GPU

        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(scale_size),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                ToTorchFormatTensor(div=(args.arch
                                         not in ['BNInception',
                                                 'InceptionV3']),
                                    inc_dim=(args.arch in ["R2plus1D",
                                                           "X3D"])),
                normalize,
            ]),
            dense_sample=args.dense_sample,
            all_sample=args.all_sample),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

    if args.evaluate:
        # Evaluation-only path: build criterion and an analysis val loader,
        # run test() and exit.
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        elif args.loss_type == "bce":
            criterion = torch.nn.BCEWithLogitsLoss().cuda()
        elif args.loss_type == "wbce":
            class_weight, pos_weight = prep_weight(args.train_list)
            criterion = WeightedBCEWithLogitsLoss(class_weight, pos_weight)
        else:
            raise ValueError("Unknown loss type")

        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(scale_size),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                ToTorchFormatTensor(div=(args.arch
                                         not in ['BNInception',
                                                 'InceptionV3']),
                                    inc_dim=(args.arch in ["R2plus1D",
                                                           "X3D"])),
                normalize,
            ]),
            dense_sample=args.dense_sample,
            all_sample=args.all_sample,
            analysis=True),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        test(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    print(model)
    logging.info(model)
    for epoch in range(args.start_epoch, args.epochs):
        logging.info("Train Epoch {}/{} starts, estimated time 5832s".format(
            str(epoch), str(args.epochs)))
        # update data_loader
        if args.shuffle:
            # Resample negatives and rebuild both loaders for this epoch.
            gen_label(args.prop_path,
                      args.label_path,
                      args.trn_name,
                      args.train_list,
                      args.neg_rate,
                      STR=False)
            gen_label(args.prop_path,
                      args.label_path,
                      args.tst_name,
                      args.val_list,
                      args.test_rate,
                      STR=False)
            train_loader = torch.utils.data.DataLoader(
                TSNDataSet(args.root_path,
                           args.train_list,
                           num_segments=args.num_segments,
                           new_length=data_length,
                           modality=args.modality,
                           image_tmpl=prefix,
                           transform=torchvision.transforms.Compose([
                               train_augmentation,
                               Stack(roll=(args.arch
                                           in ['BNInception', 'InceptionV3']),
                                     inc_dim=(args.arch in ["R2plus1D",
                                                            "X3D"])),
                               ToTorchFormatTensor(
                                   div=(args.arch
                                        not in ['BNInception', 'InceptionV3']),
                                   inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               normalize,
                           ]),
                           dense_sample=args.dense_sample,
                           all_sample=args.all_sample),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)

            val_loader = torch.utils.data.DataLoader(
                TSNDataSet(args.root_path,
                           args.val_list,
                           num_segments=args.num_segments,
                           new_length=data_length,
                           modality=args.modality,
                           image_tmpl=prefix,
                           random_shift=False,
                           transform=torchvision.transforms.Compose([
                               GroupScale(scale_size),
                               GroupCenterCrop(crop_size),
                               Stack(roll=(args.arch
                                           in ['BNInception', 'InceptionV3']),
                                     inc_dim=(args.arch in ["R2plus1D",
                                                            "X3D"])),
                               ToTorchFormatTensor(
                                   div=(args.arch
                                        not in ['BNInception', 'InceptionV3']),
                                   inc_dim=(args.arch in ["R2plus1D", "X3D"])),
                               normalize,
                           ]),
                           dense_sample=args.dense_sample,
                           all_sample=args.all_sample),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            print(train_loader)

        # define loss function (criterion) and optimizer
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        elif args.loss_type == "bce":
            criterion = torch.nn.BCEWithLogitsLoss().cuda()
        elif args.loss_type == "wbce":
            # Class weights depend on the (possibly regenerated) train list.
            class_weight, pos_weight = prep_weight(args.train_list)
            criterion = WeightedBCEWithLogitsLoss(class_weight, pos_weight)
        else:
            raise ValueError("Unknown loss type")

        # train for one epoch; manual LR schedule and MultiStepLR are
        # mutually exclusive.
        if not args.lr_scheduler:
            adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)

        else:
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training, tf_writer)
            scheduler.step()

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            logging.info(
                "Test Epoch {}/{} starts, estimated time 13874s".format(
                    str(epoch // args.eval_freq),
                    str(args.epochs / args.eval_freq)))
            if args.loss_type == "wbce":
                # Validate with the unweighted loss so epochs are comparable.
                criterion = torch.nn.BCEWithLogitsLoss().cuda()
            lossm = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best (lowest) loss and save checkpoint
            is_best = lossm < least_loss
            least_loss = min(lossm, least_loss)
            tf_writer.add_scalar('lss/test_top1_best', least_loss, epoch)

            # FIX: log the best loss, not the current epoch's loss.
            output_best = 'Best Loss: %.3f\n' % (least_loss)
            logging.info(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()
            if args.lr_scheduler:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': least_loss,
                        'lr_scheduler': scheduler,
                    }, is_best, epoch)
            else:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': least_loss,
                    }, is_best, epoch)
def main():
    """Distributed multi-modality TSM training entry point.

    Runs one process per modality: ``args.rank`` selects this process's
    modality out of the comma-separated ``args.modality`` string. Rank 0
    additionally builds the joint (all-modality) training loader; every
    rank builds a validation loader for its own modality. Trains with SGD
    and checkpoints whenever validation top-1 improves.
    """
    global args, best_prec1
    args = parser.parse_args()

    # One-hour timeout: ranks can be far apart while rank 0 trains.
    distr.init_process_group(backend='nccl',
                             init_method=args.init_method,
                             rank=args.rank,
                             world_size=args.world_size,
                             timeout=datetime.timedelta(hours=1.))

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)

    # From here on args.modality is a list; this rank uses
    # args.modality[args.rank].
    args.modality = args.modality.split(',')
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality[args.rank], full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense{}'.format(args.dense_length)
    elif args.random_sample:
        args.store_name += '_random{}'.format(args.dense_length)
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    if len(args.modality) > 1:
        args.store_name += '_ML{}'.format(args.rank)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality[args.rank],
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model,
                                  device_ids=args.gpus).cuda(args.gpus[0])

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # BUGFIX: report the checkpoint actually loaded (args.resume),
            # not args.evaluate.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile TSM's '.net' key-name wrapping differences between the
        # checkpoint and the current model.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        model_dict.update(sd)

        if args.modality[args.rank] not in args.tune_from or (
                args.modality[args.rank] == 'RGB'
                and 'RGBDiff' in args.tune_from):
            # BUGFIX: model is wrapped in DataParallel at this point, which
            # exposes neither base_model nor _construct_*_model -- reach
            # through .module for both. First rebuild the first conv to the
            # checkpoint's modality so the weights load, then rebuild it to
            # this rank's modality.
            if 'Flow' in args.tune_from:
                model.module._construct_flow_model(model.module.base_model)
            elif 'RGBDiff' in args.tune_from:
                model.module._construct_diff_model(model.module.base_model)
            else:
                model.module._construct_rgb_model(model.module.base_model)
            model.load_state_dict(model_dict)
            if args.modality[args.rank] == 'Flow':
                model.module._construct_flow_model(model.module.base_model)
            elif args.modality[args.rank] == 'RGBDiff':
                model.module._construct_diff_model(model.module.base_model)
            else:
                model.module._construct_rgb_model(model.module.base_model)
        else:
            model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    # Per-modality normalization stats and clip lengths. Computed on every
    # rank (BUGFIX: data_length is needed below for the validation loader,
    # but was previously only defined on rank 0, raising NameError on the
    # other ranks).
    input_mean = []
    input_std = []
    data_length = []
    for moda in args.modality:
        if moda == 'RGB':
            input_mean += [0.485, 0.456, 0.406]
            input_std += [0.229, 0.224, 0.225]
            data_length += [1]
        elif moda == 'Flow':
            input_mean += [0.5] * 10
            input_std += [0.226] * 10
            data_length += [5]
        elif moda == 'RGBDiff':
            input_mean += [0.] * 18
            input_std += [1.] * 18
            data_length += [6]

    train_loader = None
    if args.rank == 0:
        # Only rank 0 trains; it consumes all modalities at once.
        normalize = GroupNormalize(input_mean, input_std)
        train_loader = torch.utils.data.DataLoader(
            TSNDataSet(
                args.root_path,
                args.train_list,
                num_segments=args.num_segments,
                modality=args.modality,
                new_length=data_length,
                image_tmpl=prefix,
                transform=torchvision.transforms.Compose([
                    train_augmentation,
                    Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                    ToTorchFormatTensor(
                        div=(args.arch not in ['BNInception', 'InceptionV3'])),
                    normalize,
                ]),
                dense_sample=args.dense_sample,
                random_sample=args.random_sample,
                dense_length=args.dense_length),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)  # prevent something not % n_GPU

    if args.modality[args.rank] == 'RGB':
        normalize_val = GroupNormalize([0.485, 0.456, 0.406],
                                       [0.229, 0.224, 0.225])
    elif args.modality[args.rank] == 'Flow':
        normalize_val = GroupNormalize([0.5] * 10, [0.226] * 10)
    elif args.modality[args.rank] == 'RGBDiff':
        normalize_val = IdentityTransform()

    # Validation is per-rank: single-modality lists/paths for this rank.
    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        [args.root_path[args.rank]], [args.val_list[args.rank]],
        num_segments=args.num_segments,
        new_length=[data_length[args.rank]],
        modality=[args.modality[args.rank]],
        image_tmpl=[prefix[args.rank]],
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize_val,
        ]),
        dense_sample=args.dense_sample,
        dense_length=args.dense_length),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda(args.gpus[-1])
    else:
        raise ValueError("Unknown loss type")

    # Mutual-learning losses are only needed when training more than one
    # modality jointly.
    if len(args.modality) > 1:
        kl_loss = torch.nn.KLDivLoss(reduction='batchmean').cuda(args.gpus[-1])
        logsoftmax = torch.nn.LogSoftmax(dim=1).cuda(args.gpus[-1])
        softmax = torch.nn.Softmax(dim=1).cuda(args.gpus[-1])
    else:
        kl_loss = None
        logsoftmax = None
        softmax = None

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, kl_loss, logsoftmax, softmax,
              optimizer, epoch, log_training, tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
# Esempio n. 12
# 0
def main():
    """I3D training entry point.

    Parses CLI args, builds an I3D model (optionally with spatial dropout
    sigmoid gating), optionally resumes / fine-tunes from a checkpoint,
    then runs the train/validate loop, checkpointing on the best mAP.
    """
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.temporal_pool:
        full_arch_name += '_tpool'
    # Encode the full hyper-parameter configuration in the run name so
    # logs/checkpoints from different runs never collide.
    args.store_name = '_'.join([
        'I3D', args.dataset, full_arch_name, 'batch{}'.format(args.batch_size),
        'wd{}'.format(args.weight_decay), args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs),
        'dropout{}'.format(args.dropout), args.pretrain,
        'lr{}'.format(args.lr), '_warmup{}'.format(args.warmup)
    ])
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    else:
        step_str = [str(int(x)) for x in args.lr_steps]
        args.store_name += '_step' + '_'.join(step_str)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.spatial_dropout:
        sigmoid_layer_str = '_'.join(args.sigmoid_layer)
        args.store_name += '_spatial_drop3d_{}_group{}_layer{}'.format(
            args.sigmoid_thres, args.sigmoid_group, sigmoid_layer_str)
        if args.sigmoid_random:
            args.store_name += '_RandomSigmoid'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = i3d(num_class,
                args.num_segments,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                spatial_dropout=args.spatial_dropout,
                sigmoid_thres=args.sigmoid_thres,
                sigmoid_group=args.sigmoid_group,
                sigmoid_random=args.sigmoid_random,
                sigmoid_layer=args.sigmoid_layer,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model,
                                  device_ids=list(range(args.gpus))).cuda()

    # NOTE: unlike the TSM scripts, this one optimizes all parameters with a
    # single LR instead of per-group policies.
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # BUGFIX: report the checkpoint actually loaded (args.resume),
            # not args.evaluate.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile '.net' key-name wrapping differences between the
        # checkpoint and the current model.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            # Input channel count differs between RGB and Flow; drop conv1.
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path,
                   args.train_list,
                   num_segments=args.num_segments,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       GroupScale((256, 340)),
                       train_augmentation,
                       Stack('3D'),
                       ToTorchFormatTensor(),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack('3D'),
            ToTorchFormatTensor(),
            normalize,
        ]),
        dense_sample=args.dense_sample),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        # NOTE(review): despite the 'nll' label this deliberately uses
        # BCEWithLogitsLoss -- consistent with the mAP-based evaluation
        # below (multi-label setting). Confirm against the dataset's label
        # format before changing.
        criterion = torch.nn.BCEWithLogitsLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        # LR scheduling (incl. warmup) is handled inside train() here.
        # adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer, args.warmup, args.lr_type, args.lr_steps)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_mAP_best', best_prec1, epoch)

            output_best = 'Best mAP: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
# Esempio n. 13
# 0
def main():
    """AR-Net style adaptive-resolution training / evaluation entry point.

    Builds a gated model (TSN_Gate) when ``args.ada_reso_skip`` is set,
    otherwise a plain TSN_Ada; optionally loads pretrained or test-mode
    weights, then runs the train/validate loop while recording mAP, mmAP,
    prec@1 and prec@5. In test mode (``args.test_from`` non-empty) a single
    evaluation pass runs and the log file is renamed with the final scores.
    """
    args = parser.parse_args()
    common.set_manual_data_path(args.data_path, args.exps_path)
    test_mode = (args.test_from != "")

    set_random_seed(args.random_seed, args)

    args.num_class, args.train_list, args.val_list, args.root_path, prefix = \
        dataset_config.return_dataset(args.dataset, args.data_path)

    if args.gpus is not None:
        print("Use GPU: {} for training".format(args.gpus))

    # Tee stdout into a log file for the whole run.
    logger = Logger()
    sys.stdout = logger

    if args.ada_reso_skip:
        model = TSN_Gate(args=args)
    else:
        model = TSN_Ada(args=args)

    # GFLOPs bookkeeping used by train/validate to report compute usage.
    base_model_gflops, gflops_list, g_meta = init_gflops_table(model, args)
    policies = model.get_optim_policies()
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if test_mode or args.base_pretrained_from != "":
        the_model_path = args.base_pretrained_from
        if test_mode:
            # test_from may be an experiment directory or a direct .pth.tar.
            if "pth.tar" not in args.test_from:
                the_model_path = ospj(args.test_from, "models",
                                      "ckpt.best.pth.tar")
            else:
                the_model_path = args.test_from
        the_model_path = common.EXPS_PATH + "/" + the_model_path
        sd = torch.load(the_model_path)['state_dict']
        model_dict = model.state_dict()
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    cudnn.benchmark = True

    train_loader, val_loader = get_data_loaders(model, prefix, args)
    criterion = torch.nn.CrossEntropyLoss().cuda()

    exp_full_path = setup_log_directory(args.exp_header, test_mode, args,
                                        logger)
    if not test_mode:
        with open(os.path.join(exp_full_path, 'args.txt'), 'w') as f:
            f.write(str(args))

    map_record, mmap_record, prec_record, prec5_record = get_recorders(4)
    best_train_usage_str = None
    best_val_usage_str = None

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        if not args.skip_training and not test_mode:
            # Re-seed per epoch so runs are reproducible yet epochs differ.
            set_random_seed(args.train_random_seed + epoch, args)
            adjust_learning_rate(optimizer, epoch, -1, -1, args.lr_type,
                                 args.lr_steps, args)
            train_usage_str = train(train_loader, model, criterion, optimizer,
                                    epoch, base_model_gflops, gflops_list,
                                    g_meta, args)
        else:
            train_usage_str = "(Eval mode)"
        torch.cuda.empty_cache()

        # evaluation
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            set_random_seed(args.random_seed, args)
            mAP, mmAP, prec1, prec5, val_usage_str = \
                validate(val_loader, model, criterion, epoch, base_model_gflops, gflops_list, g_meta, exp_full_path,
                         args)

            # remember best prec@1 and save checkpoint
            map_record.update(mAP)
            mmap_record.update(mmAP)
            prec_record.update(prec1)
            prec5_record.update(prec5)

            if prec_record.is_current_best():
                best_train_usage_str = train_usage_str if not args.skip_training else "(Eval Mode)"
                best_val_usage_str = val_usage_str

            print('Best Prec@1: %.3f (epoch=%d) w. Prec@5: %.3f' %
                  (prec_record.best_val, prec_record.best_at,
                   prec5_record.at(prec_record.best_at)))

            if test_mode or args.skip_training:  # only runs for one epoch
                break
            else:
                saved_things = {'state_dict': model.state_dict()}
                save_checkpoint(saved_things, prec_record.is_current_best(),
                                False, exp_full_path, "ckpt.best")
                save_checkpoint(saved_things, True, False, exp_full_path,
                                "ckpt.latest")

                if epoch in args.backup_epoch_list:
                    save_checkpoint(None, False, True, exp_full_path,
                                    str(epoch))
                torch.cuda.empty_cache()

    # after fininshing all the epochs
    if test_mode:
        if not args.skip_log:  # idiomatic truthiness test (was `== False`)
            # Embed the final scores in the log file name.
            os.rename(
                logger._log_path,
                ospj(
                    logger._log_dir_name,
                    logger._log_file_name[:-4] + "_mm_%.2f_a_%.2f_f.txt" %
                    (mmap_record.best_val, prec_record.best_val)))
    else:
        if args.ada_reso_skip:
            print("Best train usage:%s\nBest val usage:%s" %
                  (best_train_usage_str, best_val_usage_str))
# Esempio n. 14
# 0
def main():
    """TSM (optionally soft-shift) training entry point.

    Builds a TSN model with (soft) temporal shift, optionally resumes or
    fine-tunes from a checkpoint, then runs the train/validate loop with
    either step-style LR adjustment or ReduceLROnPlateau scheduling, and
    reports total wall-clock time at the end.
    """
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.deploy:
        full_arch_name += '_deploy'
    if args.shift:
        if args.soft_shift:
            full_arch_name += '_softshift{}_{}'.format(args.shift_div,
                                                       args.shift_place)
        else:
            full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                                   args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'

    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                is_softshift=args.soft_shift,
                shift_init_mode=args.shift_init_mode,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                deploy=args.deploy)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # BUGFIX: report the checkpoint actually loaded (args.resume),
            # not args.evaluate.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        # Reconcile '.net' key-name wrapping differences between the
        # checkpoint and the current model.
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            # Input channel count differs between RGB and Flow; drop conv1.
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    #* Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        # RGBDiff is already a difference signal; no normalization.
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    elif args.loss_type == "focalloss":
        from ops.focalloss import FocalLoss
        # BUGFIX: was args.num_classes, an attribute that exists nowhere
        # else in this script; use the class count returned by
        # dataset_config.return_dataset (same value the model was built with).
        criterion = FocalLoss(num_class=num_class).cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    # if args.evaluate:
    #     validate(val_loader, model, criterion, 0)
    #     return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    if args.lr_type == 'ReduceLROnPlateau':
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        scheduler = ReduceLROnPlateau(optimizer, 'min')

    time_start = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.lr_type != 'ReduceLROnPlateau':
            adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            # if 1:
            if args.lr_type != 'ReduceLROnPlateau':
                prec1 = validate(val_loader, model, criterion, epoch,
                                 log_training, tf_writer)
            else:
                # Plateau scheduler steps on the validation metric inside
                # validate().
                prec1 = validate(val_loader,
                                 model,
                                 criterion,
                                 epoch,
                                 log_training,
                                 tf_writer,
                                 scheduler=scheduler)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)

    time_end = time.time()
    time_cost = time_end - time_start
    # Reports total hours, total minutes, and the seconds remainder.
    output_time = 'Totally cost %.3f h: %.3f min: %.3f s' % (
        time_cost / 60 / 60, time_cost / 60, time_cost % 60)
    print(output_time)

    log_training.write(output_time + '\n')
    log_training.flush()
def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point for (optionally distributed) TSM training.

    Builds the TSN model, wraps it for DDP / DataParallel, optionally
    restores a checkpoint, then runs the train / checkpoint / validate loop.

    Args:
        gpu: local GPU index assigned to this process, or ``None`` to use
            all visible devices.
        ngpus_per_node: number of GPUs on this node; used to derive the
            global rank and to shard ``args.batch_size`` / ``args.workers``.
        args: parsed CLI namespace. Mutated in place (``train_list``,
            ``val_list``, ``root_path``, ``store_name``, ``batch_size``,
            ``workers``, ``gpu``, ``start_epoch``, ...).
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes.
            rank = args.rank * ngpus_per_node + gpu
        else:
            # BUGFIX: `rank` used to be undefined on this branch, so the
            # init_process_group() call below raised NameError. Fall back to
            # the externally supplied rank.
            rank = args.rank
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=rank)
    else:
        rank = 0
    # Resolve dataset file lists and build a descriptive run name.
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset, args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type, 'segment%d' % args.num_segments,
         'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    args.store_name += '_lr{}'.format(args.lr)
    args.store_name += '_wd{:.1e}'.format(args.weight_decay)
    args.store_name += '_do{}'.format(args.dropout)
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders(args, rank)

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    # first synchronization of initial weights
    # sync_initial_weights(model, args.rank, args.world_size)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if rank == 0:
        print(model)
    # Grab preprocessing metadata and per-layer optimizer policies before the
    # model gets wrapped (the wrapper hides these attributes behind .module).
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have on a node.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set.
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs.
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(policies, args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            # Only override start_epoch when the user kept the default (1),
            # so an explicit --start_epoch still wins over the checkpoint.
            if args.start_epoch == 1:
                args.start_epoch = checkpoint['epoch'] + 1
            best_acc1 = checkpoint['best_acc1']
#             if args.gpu is not None:
#                 # best_acc1 may be from a checkpoint from a different GPU
#                 best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        # RGBDiff already encodes differences; normalization is a no-op.
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_dataset = TSNDataSet(args.dataset, args.root_path, args.train_list, num_segments=args.num_segments,
                       new_length=data_length,
                       modality=args.modality,
                       image_tmpl=prefix,
                       transform=torchvision.transforms.Compose([train_augmentation,
                           Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                           normalize]),
                      dense_sample=args.dense_sample)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, drop_last=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.dataset, args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    log_training = open(os.path.join(args.root_model, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_model, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(log_dir=os.path.join(args.root_model, args.store_name))
    for epoch in range(args.start_epoch, args.epochs+1):
        if args.distributed:
            # Reseed the sampler so each epoch sees a different shard order.
            train_sampler.set_epoch(epoch)
#         adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer, args, rank)
        # Only the first process on each node writes checkpoints.
        if rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, False, args, rank)

        # Additionally keep a numbered snapshot every 5 epochs.
        if epoch % 5 == 0 and rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, False, args, rank, e=epoch)

        # evaluate on validation set
        is_best = False
        if epoch % args.eval_freq == 0 or epoch == args.epochs:
            acc1 = validate(val_loader, model, criterion, epoch, args, rank, log_training, tf_writer)

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed and rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best, args, rank)
Esempio n. 16
0
if args.test_list is not None:
    test_file_list = args.test_list.split(',')
else:
    test_file_list = [None] * len(weights_list)

data_iter_list = []
net_list = []
modality_list = args.modalities.split(',')
arch_list = args.archs.split('.')

total_num = None
for this_weights, this_test_segments, test_file, modality, this_arch in zip(
        weights_list, test_segments_list, test_file_list, modality_list,
        arch_list):
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(
        args.dataset, modality)
    net = TSN(num_class,
              this_test_segments,
              modality,
              base_model=this_arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              pretrain=args.pretrain)
    checkpoint = torch.load(this_weights)
    try:
        net.load_state_dict(checkpoint['state_dict'])
    except:
        checkpoint = checkpoint['state_dict']

        base_dict = {
            '.'.join(k.split('.')[1:]): v
Esempio n. 17
0
File: main.py Progetto: CV-IP/TDN
def main():
    """Distributed TDN training entry point.

    Parses CLI args, initializes the NCCL process group (expects
    ``torchrun``-style env vars and ``--local_rank``), builds the model and
    loaders, then either evaluates once (``--evaluate``) or runs the full
    training loop, checkpointing on rank 0.
    """
    global args, best_prec1
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    args.store_name = '_'.join([
        'TDN_', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)

    # Only rank 0 touches the filesystem layout.
    if dist.get_rank() == 0:
        check_rootfolders()

    logger = setup_logger(output=os.path.join(args.root_log, args.store_name),
                          distributed_rank=dist.get_rank(),
                          name=f'TDN')
    logger.info('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                fc_lr5=(args.tune_from and args.dataset in args.tune_from))

    # Preprocessing metadata and optimizer param groups, read before wrapping.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    for group in policies:
        logger.info(
            ('[TDN-{}]group: {} has {} params, lr_mult: {}, decay_mult: {}'.
             format(args.arch, group['name'], len(group['params']),
                    group['lr_mult'], group['decay_mult'])))

    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset else True)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)

    train_dataset = TSNDataSet(
        args.dataset,
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_dataset = TSNDataSet(
        args.dataset,
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)

    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             drop_last=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    scheduler = get_scheduler(optimizer, len(train_loader), args)

    model = DistributedDataParallel(model.cuda(),
                                    device_ids=[args.local_rank],
                                    broadcast_buffers=True,
                                    find_unused_parameters=True)

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            # BUGFIX: message previously interpolated args.evaluate, which is
            # not the path that was loaded.
            logger.info(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            logger.info(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        logger.info(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        # Reconcile '.net' naming differences between checkpoint and model.
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                # BUGFIX: the key was passed as a stray positional arg with no
                # placeholder, so it never appeared in the log.
                logger.info('=> Load after remove .net: %s', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                logger.info('=> Load after adding .net: %s', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        logger.info(
            '#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            logger.info('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    if args.evaluate:
        logger.info(("===========evaluate==========="))
        val_loader.sampler.set_epoch(args.start_epoch)
        prec1, prec5, val_loss = validate(val_loader, model, criterion, logger)
        if dist.get_rank() == 0:
            is_best = prec1 > best_prec1
            # BUGFIX: was an unconditional assignment, which could lower the
            # recorded best; keep it monotone like the training loop below.
            best_prec1 = max(prec1, best_prec1)
            logger.info(("Best Prec@1: '{}'".format(best_prec1)))
            save_epoch = args.start_epoch + 1
            save_checkpoint(
                {
                    'epoch': args.start_epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'prec1': prec1,
                    'best_prec1': best_prec1,
                }, save_epoch, is_best)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # Reshuffle the distributed shards each epoch.
        train_loader.sampler.set_epoch(epoch)
        train_loss, train_top1, train_top5 = train(train_loader,
                                                   model,
                                                   criterion,
                                                   optimizer,
                                                   epoch=epoch,
                                                   logger=logger,
                                                   scheduler=scheduler)
        if dist.get_rank() == 0:
            tf_writer.add_scalar('loss/train', train_loss, epoch)
            tf_writer.add_scalar('acc/train_top1', train_top1, epoch)
            tf_writer.add_scalar('acc/train_top5', train_top5, epoch)
            tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)

        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            val_loader.sampler.set_epoch(epoch)
            prec1, prec5, val_loss = validate(val_loader, model, criterion,
                                              epoch, logger)
            if dist.get_rank() == 0:
                tf_writer.add_scalar('loss/test', val_loss, epoch)
                tf_writer.add_scalar('acc/test_top1', prec1, epoch)
                tf_writer.add_scalar('acc/test_top5', prec5, epoch)

                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

                logger.info(("Best Prec@1: '{}'".format(best_prec1)))
                tf_writer.flush()
                save_epoch = epoch + 1
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'prec1': prec1,
                        'best_prec1': best_prec1,
                    }, save_epoch, is_best)
 for this_weights, this_test_segments, test_file in zip(
         weights_list, test_segments_list, test_file_list):
     is_shift, shift_div, shift_place = parse_shift_option_from_log_name(
         this_weights)
     if 'Flow' in this_weights:
         modality = 'Flow'
     elif 'RGB-seg' in this_weights:
         modality = 'RGB-seg'
     elif 'RGB-flo' in this_weights:
         modality = 'RGB-flo'
     elif 'RGB' in this_weights:
         modality = 'RGB'
     this_arch = this_weights.split('TSM_')[1].split('_')[2]
     ipn_no_class, bio_validation = parse_biovid_from_log_name(this_weights)
     modality_list.append(modality)
     num_class, categories, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(
         args.dataset, modality, ipn_no_class, str(bio_validation))
     print('=> shift: {}, shift_div: {}, shift_place: {}'.format(
         is_shift, shift_div, shift_place))
     net = TSN(
         num_class,
         this_test_segments if is_shift else 1,
         modality,
         base_model=this_arch,
         consensus_type=args.crop_fusion_type,
         img_feature_dim=args.img_feature_dim,
         pretrain=args.pretrain,
         is_shift=is_shift,
         shift_div=shift_div,
         shift_place=shift_place,
         non_local='_nl' in this_weights,
     )
Esempio n. 19
0
def load_model(weights):
    """Reconstruct a TSM network from a checkpoint path and build its
    evaluation transform.

    Model hyper-parameters (modality, shift options, architecture, fusion
    type, concat mode, ...) are parsed out of the checkpoint *filename*,
    following the naming scheme produced at training time.

    Args:
        weights: path to a ``.pth`` checkpoint whose name encodes the config.

    Returns:
        (is_shift, net, transform): shift flag, the DataParallel-wrapped
        network on GPU, and the torchvision preprocessing pipeline.

    Side effects: sets the module-level ``num_class``, appends to
    ``modality_list``, and mutates ``args.prune`` / ``args.crop_fusion_type``.
    """
    global num_class
    is_shift, shift_div, shift_place = parse_shift_option_from_log_name(
        weights)
    if 'RGB' in weights:
        modality = 'RGB'
    elif 'Depth' in weights:
        modality = 'Depth'
    else:
        modality = 'Flow'

    if 'concatAll' in weights:
        concat = "All"
    elif "concatFirst" in weights:
        concat = "First"
    else:
        concat = ""

    # BUGFIX: default before the conditional — previously the name was only
    # bound when 'extra' appeared in the path, and the TSN(...) call below
    # raised UnboundLocalError otherwise. Also use the `weights` parameter,
    # not the undefined `this_weights`.
    extra_temporal_modeling = False
    if 'extra' in weights:
        extra_temporal_modeling = True

    args.prune = ""

    if 'conv1d' in weights:
        args.crop_fusion_type = "conv1d"
    else:
        args.crop_fusion_type = "avg"

    # Architecture token sits after 'TSM_' in the training run name.
    this_arch = weights.split('TSM_')[1].split('_')[2]
    modality_list.append(modality)
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(
        args.dataset, modality)

    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(
        is_shift, shift_div, shift_place))
    net = TSN(num_class,
              int(args.test_segments) if is_shift else 1,
              modality,
              base_model=this_arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              pretrain=args.pretrain,
              is_shift=is_shift,
              shift_div=shift_div,
              shift_place=shift_place,
              non_local='_nl' in weights,
              concat=concat,
              extra_temporal_modeling=extra_temporal_modeling,
              prune_list=[prune_conv1in_list, prune_conv1out_list],
              is_prune=args.prune)

    if 'tpool' in weights:
        from ops.temporal_shift import make_temporal_pool
        make_temporal_pool(net.base_model,
                           args.test_segments)  # since DataParallel

    checkpoint = torch.load(weights)
    checkpoint = checkpoint['state_dict']

    # Strip the leading 'module.' left by DataParallel at save time.
    # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in list(checkpoint.items())
    }
    replace_dict = {
        'base_model.classifier.weight': 'new_fc.weight',
        'base_model.classifier.bias': 'new_fc.bias',
    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)

    input_size = net.scale_size if args.full_res else net.input_size
    if args.test_crops == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif args.test_crops == 3:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose(
            [GroupFullResSample(input_size, net.scale_size, flip=False)])
    elif args.test_crops == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose(
            [GroupOverSample(input_size, net.scale_size, flip=False)])
    elif args.test_crops == 10:
        cropping = torchvision.transforms.Compose(
            [GroupOverSample(input_size, net.scale_size)])
    else:
        # BUGFIX: message now lists every branch handled above (3 was missing).
        raise ValueError(
            "Only 1, 3, 5, 10 crops are supported while we got {}".format(
                args.test_crops))

    transform = torchvision.transforms.Compose([
        cropping,
        Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(
            div=(this_arch not in ['BNInception', 'InceptionV3'])),
        GroupNormalize(net.input_mean, net.input_std),
    ])

    if args.gpus is not None:
        devices = [args.gpus[i] for i in range(args.workers)]
    else:
        devices = list(range(args.workers))

    net = torch.nn.DataParallel(net.cuda())
    return is_shift, net, transform