Example #1
    net = TSN(
        num_class,
        this_test_segments if is_shift else 1,
        modality,
        base_model=this_arch,
        consensus_type=args.crop_fusion_type,
        img_feature_dim=args.img_feature_dim,
        pretrain=args.pretrain,
        is_shift=is_shift,
        shift_div=shift_div,
        shift_place=shift_place,
        non_local='_nl' in this_weights,
    )

    if 'tpool' in this_weights:
        from ops.temporal_shift import make_temporal_pool
        make_temporal_pool(net.base_model,
                           this_test_segments)  # since DataParallel

    checkpoint = torch.load(this_weights)
    checkpoint = checkpoint['state_dict']

    # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in list(checkpoint.items())
    }
    replace_dict = {
        'base_model.classifier.weight': 'new_fc.weight',
        'base_model.classifier.bias': 'new_fc.bias',
    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)
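
The base_dict comprehension above strips the leading 'module.' prefix that
torch.nn.DataParallel adds to every parameter name when a model is saved. A
minimal sketch of the idiom on a toy dict (hypothetical keys, not from a real
checkpoint):

    checkpoint = {'module.base_model.conv1.weight': 0,
                  'module.new_fc.bias': 1}
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in checkpoint.items()}
    # base_dict == {'base_model.conv1.weight': 0, 'new_fc.bias': 1}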
Example #2
def main():
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
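
save_checkpoint is not defined in this snippet; a minimal sketch consistent
with how it is called here (the layout under args.root_model and the file
names are assumptions):

    import os
    import shutil
    import torch

    def save_checkpoint(state, is_best, filename='ckpt'):
        path = os.path.join(args.root_model, args.store_name,
                            filename + '.pth.tar')
        torch.save(state, path)
        if is_best:
            # keep a separate copy of the best-scoring checkpoint
            shutil.copyfile(path, path.replace('.pth.tar', '.best.pth.tar'))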
Example #3
def main():
    # options
    parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
    parser.add_argument('dataset', type=str)

    # may contain splits
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--test_segments', type=str, default='25')
    parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
    parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
    parser.add_argument('--full_res', default=False, action="store_true",
                        help='use full resolution 256x256 for test as in Non-local I3D')

    parser.add_argument('--test_crops', type=int, default=1)
    parser.add_argument('--coeff', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')

    # for true test
    parser.add_argument('--test_list', type=str, default=None)
    parser.add_argument('--csv_file', type=str, default=None)

    parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')

    parser.add_argument('--max_num', type=int, default=-1)
    parser.add_argument('--input_size', type=int, default=224)
    parser.add_argument('--crop_fusion_type', type=str, default='avg')
    parser.add_argument('--gpus', nargs='+', type=int, default=None)
    parser.add_argument('--img_feature_dim', type=int, default=256)
    parser.add_argument('--num_set_segments', type=int, default=1,
                        help='TODO: select multiple sets of n frames from a video')
    parser.add_argument('--pretrain', type=str, default='imagenet')

    args = parser.parse_args()





    def accuracy(output, target, topk=(1,)):
        """Computes the precision@k for the specified values of k"""
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
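
    # Worked example (a quick sketch): with
    #   output = torch.tensor([[0.1, 0.9], [0.8, 0.2]]) and
    #   target = torch.tensor([1, 1]),
    # accuracy(output, target, topk=(1,)) returns [tensor(50.)], since only
    # the first of the two top-1 predictions matches its label.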


    def parse_shift_option_from_log_name(log_name):
        if 'shift' in log_name:
            strings = log_name.split('_')
            for i, s in enumerate(strings):
                if 'shift' in s:
                    break
            return True, int(strings[i].replace('shift', '')), strings[i + 1]
        else:
            return False, None, None
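
    # Sanity check (sketch, assuming a typical TSM checkpoint name):
    # parse_shift_option_from_log_name(
    #     'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth')
    # returns (True, 8, 'blockres'); a name without 'shift' in it yields
    # (False, None, None).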


    weights_list = args.weights.split(',')
    test_segments_list = [int(s) for s in args.test_segments.split(',')]
    assert len(weights_list) == len(test_segments_list)
    if args.coeff is None:
        coeff_list = [1] * len(weights_list)
    else:
        coeff_list = [float(c) for c in args.coeff.split(',')]

    if args.test_list is not None:
        test_file_list = args.test_list.split(',')
    else:
        test_file_list = [None] * len(weights_list)


    data_iter_list = []
    net_list = []
    modality_list = []

    total_num = None
    for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
        is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
        if 'RGB' in this_weights:
            modality = 'RGB'
        else:
            modality = 'Flow'
        this_arch = this_weights.split('TSM_')[1].split('_')[2]
        modality_list.append(modality)
        num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                modality)
        print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
        net = TSN(num_class, this_test_segments if is_shift else 1, modality,
                base_model=this_arch,
                consensus_type=args.crop_fusion_type,
                img_feature_dim=args.img_feature_dim,
                pretrain=args.pretrain,
                is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
                non_local='_nl' in this_weights,
                )

        if 'tpool' in this_weights:
            from ops.temporal_shift import make_temporal_pool
            make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel

        checkpoint = torch.load(this_weights)
        checkpoint = checkpoint['state_dict']

        # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias',
                        }
        for k, v in replace_dict.items():
            if k in base_dict:
                base_dict[v] = base_dict.pop(k)

        net.load_state_dict(base_dict)

        input_size = net.scale_size if args.full_res else net.input_size
        if args.test_crops == 1:
            cropping = torchvision.transforms.Compose([
                GroupScale(net.scale_size),
                GroupCenterCrop(input_size),
            ])
        elif args.test_crops == 3:  # do not flip, so only 3 crops
            cropping = torchvision.transforms.Compose([
                GroupFullResSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 5:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 10:
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size)
            ])
        else:
            raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops))

        data_loader = torch.utils.data.DataLoader(
                TSNDataSet(root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments,
                        new_length=1 if modality == "RGB" else 5,
                        modality=modality,
                        image_tmpl=prefix,
                        test_mode=True,
                        remove_missing=len(weights_list) == 1,
                        transform=torchvision.transforms.Compose([
                            cropping,
                            Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                            ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                            GroupNormalize(net.input_mean, net.input_std),
                        ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True,
        )

        # NOTE: devices is sized by args.workers (not the GPU count) and is
        # never used below; DataParallel falls back to all visible GPUs.
        if args.gpus is not None:
            devices = [args.gpus[i] for i in range(args.workers)]
        else:
            devices = list(range(args.workers))

        net = torch.nn.DataParallel(net.cuda())
        net.eval()

        data_gen = enumerate(data_loader)

        if total_num is None:
            total_num = len(data_loader.dataset)
        else:
            assert total_num == len(data_loader.dataset)

        data_iter_list.append(data_gen)
        net_list.append(net)


    output = []


    def eval_video(video_data, net, this_test_segments, modality):
        net.eval()
        with torch.no_grad():
            i, data, label = video_data
            batch_size = label.numel()
            num_crop = args.test_crops
            if args.dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample

            if args.twice_sample:
                num_crop *= 2

            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality "+ modality)

            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3))
            rst = net(data_in)
            rst = rst.reshape(batch_size, num_crop, -1).mean(1)

            if args.softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)

            rst = rst.data.cpu().numpy().copy()

            if net.module.is_shift:
                rst = rst.reshape(batch_size, num_class)
            else:
                rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

            return i, rst, label
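
    # Shape walkthrough (sketch): with batch size B, num_crop C, segments T
    # and RGB length 3, the loader yields data of shape (B, C*T*3, H, W). The
    # first view flattens it to (B*C*T, 3, H, W); for shift models the second
    # view regroups it to (B*C, T, 3, H, W) so the network sees whole clips.
    # The final reshape(batch_size, num_crop, -1).mean(1) fuses the crops.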


    proc_start_time = time.time()
    max_num = args.max_num if args.max_num > 0 else total_num

    top1 = AverageMeter()
    top5 = AverageMeter()

    for i, data_label_pairs in enumerate(zip(*data_iter_list)):
        with torch.no_grad():
            if i >= max_num:
                break
            this_rst_list = []
            this_label = None
            for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list):
                rst = eval_video((i, data, label), net, n_seg, modality)
                this_rst_list.append(rst[1])
                this_label = label
            assert len(this_rst_list) == len(coeff_list)
            for i_coeff in range(len(this_rst_list)):
                this_rst_list[i_coeff] *= coeff_list[i_coeff]
            ensembled_predict = sum(this_rst_list) / len(this_rst_list)

            for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
                output.append([p[None, ...], g])
            cnt_time = time.time() - proc_start_time
            prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
            top1.update(prec1.item(), this_label.numel())
            top5.update(prec5.item(), this_label.numel())
            if i % 20 == 0:
                print('video {} done, total {}/{}, average {:.3f} sec/video, '
                    'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num,
                                                                float(cnt_time) / (i+1) / args.batch_size, top1.avg, top5.avg))

    video_pred = [np.argmax(x[0]) for x in output]
    video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]

    video_labels = [x[1] for x in output]


    if args.csv_file is not None:
        print('=> Writing result to csv file: {}'.format(args.csv_file))
        with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f:
            categories = f.readlines()
        categories = [f.strip() for f in categories]
        with open(test_file_list[0]) as f:
            vid_names = f.readlines()
        vid_names = [n.split(' ')[0] for n in vid_names]
        print(vid_names)
        assert len(vid_names) == len(video_pred)
        if args.dataset != 'somethingv2':  # only output top1
            with open(args.csv_file, 'w') as f:
                for n, pred in zip(vid_names, video_pred):
                    f.write('{};{}\n'.format(n, categories[pred]))
        else:
            with open(args.csv_file, 'w') as f:
                for n, pred5 in zip(vid_names, video_pred_top5):
                    fill = [n]
                    for p in list(pred5):
                        fill.append(p)
                    f.write('{};{};{};{};{};{}\n'.format(*fill))


    cf = confusion_matrix(video_labels, video_pred).astype(float)

    np.save('cm.npy', cf)
    cls_cnt = cf.sum(axis=1)
    cls_hit = np.diag(cf)

    cls_acc = cls_hit / cls_cnt
    print(cls_acc)
    upper = np.mean(np.max(cf, axis=1) / cls_cnt)
    print('upper bound: {}'.format(upper))

    print('-----Evaluation is finished------')
    print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
    print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))
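
A typical invocation of this test script (a sketch; the file name
test_models.py and the checkpoint name are assumptions, though the weights
file must follow the 'TSM_<dataset>_<modality>_<arch>_...' naming pattern the
parsing code expects):

    python test_models.py kinetics \
        --weights=TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth \
        --test_segments=8 --test_crops=1 --batch_size=32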
Example #4
def main():
    global args, best_prec1
    global crop_size
    args = parser.parse_args()

    num_class, train_list, val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    num_class = 1  # NOTE: overrides the dataset's class count with a single output
    if args.train_list == "":
        args.train_list = train_list
    if args.val_list == "":
        args.val_list = val_list

    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.concat != "":
        full_arch_name += '_concat{}'.format(args.concat)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name += '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'lr%.5f' % args.lr,
        'dropout%.2f' % args.dropout,
        'wd%.5f' % args.weight_decay,
        'batch%d' % args.batch_size,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.data_fuse:
        args.store_name += '_fuse'
    if args.extra_temporal_modeling:
        args.store_name += '_extra'
    if args.tune_from is not None:
        args.store_name += '_finetune'
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.clipnums:
        args.store_name += "_clip{}".format(args.clipnums)
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    if args.prune in ['input', 'inout'] and args.tune_from:
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        sd = input_dim_L2distance(sd, args.shift_div)

    model = TSN(
        num_class,
        args.num_segments,
        args.modality,
        base_model=args.arch,
        new_length=2 if args.data_fuse else None,
        consensus_type=args.consensus_type,
        dropout=args.dropout,
        img_feature_dim=args.img_feature_dim,
        partial_bn=not args.no_partialbn,
        pretrain=args.pretrain,
        is_shift=args.shift,
        shift_div=args.shift_div,
        shift_place=args.shift_place,
        fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
        temporal_pool=args.temporal_pool,
        non_local=args.non_local,
        concat=args.concat,
        extra_temporal_modeling=args.extra_temporal_modeling,
        prune_list=[prune_conv1in_list, prune_conv1out_list],
        is_prune=args.prune,
    )

    print(model)
    #summary(model, torch.zeros((16, 24, 224, 224)))
    #exit(1)
    if args.dataset == 'ucf101':  #twice sample & full resolution
        twice_sample = True
        crop_size = model.scale_size  #256 x 256
    else:
        twice_sample = False
        crop_size = model.crop_size  #224 x 224
    crop_size = 256  # NOTE: hard-coded override; the branch above is effectively dead
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies(args.concat)
    #print(type(policies))
    #print(policies)
    #exit()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset
        or 'nvgesture' in args.dataset else True)

    model = torch.nn.DataParallel(model).cuda()

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # NOTE: optimizer is only created further down, after the decoder
            # policies are appended, so resuming here raises NameError
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        tune_from_list = args.tune_from.split(',')
        sd = torch.load(tune_from_list[0])
        sd = sd['state_dict']

        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.prune', '') in sd:
                print('=> Load after adding .prune: ', k)
                replace_dict.append((k.replace('.prune', ''), k))

        if args.prune in ['input', 'inout']:
            sd = adjust_para_shape_prunein(sd, model_dict)
        if args.prune in ['output', 'inout']:
            sd = adjust_para_shape_pruneout(sd, model_dict)

        if args.concat != "" and "concat" not in tune_from_list[0]:
            sd = adjust_para_shape_concat(sd, model_dict)

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in tune_from_list[0]:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality != 'Flow' and 'Flow' in tune_from_list[0]:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        #print(sd.keys())
        #print("*"*50)
        #print(model_dict.keys())
        for k, v in list(sd.items()):
            if k not in model_dict:
                sd.pop(k)
        sd.pop("module.base_model.embedding.weight")

        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    decoder = TransformerModel().cuda()
    if args.decoder_resume:
        decoder_chkpoint = torch.load(args.decoder_resume)

        decoder.load_state_dict(decoder_chkpoint["state_dict"])
    print("decoder parameters = ", decoder.parameters())
    policies.append({
        "params": decoder.parameters(),
        "lr_mult": 5,
        "decay_mult": 1,
        "name": "Attndecoder_weight"
    })
    cudnn.benchmark = True
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality in ['RGB']:
        data_length = 1
    elif args.modality in ['Depth']:
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    '''
	dataRoot = r"/home/share/YouCook/downloadVideo"
	for dirPath, dirnames, filenames in os.walk(dataRoot):
		for filename in filenames:
			print(os.path.join(dirPath,filename) +"is {}".format("exist" if os.path.isfile(os.path.join(dirPath,filename))else "NON"))
			train_data = torchvision.io.read_video(os.path.join(dirPath,filename),start_pts=0,end_pts=1001, )
			tmp = torchvision.io.read_video_timestamps(os.path.join(dirPath,filename),)
			print(tmp)
			print(len(tmp[0]))
			print(train_data[0].size())
			exit()
	exit()
	'''
    '''
	train_loader = torch.utils.data.DataLoader(
		TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
				   new_length=data_length,
				   modality=args.modality,
				   image_tmpl=prefix,
				   transform=torchvision.transforms.Compose([
					   train_augmentation,
					   Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
					   ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
					   normalize,
				   ]), dense_sample=args.dense_sample, data_fuse = args.data_fuse),
		batch_size=args.batch_size, shuffle=True,
		num_workers=args.workers, pin_memory=True,
		drop_last=True)  # prevent something not % n_GPU
	

	
	val_loader = torch.utils.data.DataLoader(
		TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
				   new_length=data_length,
				   modality=args.modality,
				   image_tmpl=prefix,
				   random_shift=False,
				   transform=torchvision.transforms.Compose([
					   GroupScale(int(scale_size)),
					   GroupCenterCrop(crop_size),
					   Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
					   ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
					   normalize,
				   ]), dense_sample=args.dense_sample, twice_sample=twice_sample, data_fuse = args.data_fuse),
		batch_size=args.batch_size, shuffle=False,
		num_workers=args.workers, pin_memory=True)
	'''
    #global trainDataloader, valDataloader, train_loader, val_loader
    trainDataloader = YouCookDataSetRcg(args.root_path, args.train_list,
                                        train=True,
                                        inputsize=crop_size,
                                        hasPreprocess=False,
                                        clipnums=args.clipnums,
                                        hasWordIndex=True)
    valDataloader = YouCookDataSetRcg(args.root_path, args.val_list,
                                      val=True,
                                      inputsize=crop_size,
                                      hasPreprocess=False,
                                      # clipnums=args.clipnums,
                                      hasWordIndex=True)

    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #exit()
    # NOTE: batch_size defaults to 1; shuffle=True is left commented out
    train_loader = torch.utils.data.DataLoader(trainDataloader)
    val_loader = torch.utils.data.DataLoader(valDataloader)
    index2wordDict = trainDataloader.getIndex2wordDict()
    #print(train_loader is val_loader)
    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())

    #print(trainDataloader._getMode())
    #print(valDataloader._getMode())
    #print(len(train_loader))

    #exit()
    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.NLLLoss().cuda()
    elif args.loss_type == "MSELoss":
        criterion = torch.nn.MSELoss().cuda()
    elif args.loss_type == "BCELoss":
        #print("BCELoss")
        criterion = torch.nn.BCELoss().cuda()
    elif args.loss_type == "CrossEntropyLoss":
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    #print(os.path.join(args.root_log, args.store_name, 'args.txt'))
    #exit()
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
        #print("265")
        # train for one epoch
        ######
        #print(trainDataloader._getMode())
        #print(valDataloader._getMode())
        train(train_loader, model, decoder, criterion, optimizer, epoch,
              log_training, tf_writer, index2wordDict)
        ######
        #print("268")
        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader,
                             model,
                             decoder,
                             criterion,
                             epoch,
                             log_training,
                             tf_writer,
                             index2wordDict=index2wordDict)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            #print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': decoder.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename="decoder")
        else:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, False)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': decoder.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                },
                False,  # no validation ran this epoch, so it cannot be "best"
                filename="decoder")
        #break
        print("test pass")
Example #5
def main():
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'I3D', args.dataset, full_arch_name, 'batch{}'.format(args.batch_size),
        'wd{}'.format(args.weight_decay), args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs),
        'dropout{}'.format(args.dropout), args.pretrain,
        'lr{}'.format(args.lr), 'warmup{}'.format(args.warmup)
    ])
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    else:
        step_str = [str(int(x)) for x in args.lr_steps]
        args.store_name += '_step' + '_'.join(step_str)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.spatial_dropout:
        sigmoid_layer_str = '_'.join(args.sigmoid_layer)
        args.store_name += '_spatial_drop3d_{}_group{}_layer{}'.format(
            args.sigmoid_thres, args.sigmoid_group, sigmoid_layer_str)
        if args.sigmoid_random:
            args.store_name += '_RandomSigmoid'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = i3d(num_class,
                args.num_segments,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                spatial_dropout=args.spatial_dropout,
                sigmoid_thres=args.sigmoid_thres,
                sigmoid_group=args.sigmoid_group,
                sigmoid_random=args.sigmoid_random,
                sigmoid_layer=args.sigmoid_layer,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model,
                                  device_ids=list(range(args.gpus))).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path,
                   args.train_list,
                   num_segments=args.num_segments,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       GroupScale((256, 340)),
                       train_augmentation,
                       Stack('3D'),
                       ToTorchFormatTensor(),
                       normalize,
                   ]),
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack('3D'),
            ToTorchFormatTensor(),
            normalize,
        ]),
        dense_sample=args.dense_sample),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        # multi-label training: BCE-with-logits pairs with the mAP metric logged below
        criterion = torch.nn.BCEWithLogitsLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        # adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer, args.warmup, args.lr_type, args.lr_steps)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_mAP_best', best_prec1, epoch)

            output_best = 'Best mAP: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Example #6
def main():
    global args, best_prec1
    args = parser.parse_args()
    num_class = opts.num_class

    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.non_local > 0:
        args.store_name += '_nl'
    print('storing name: ' + args.store_name)

    # check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_size=args.img_size,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=True,
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    model.apply(weights_init)

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # NOTE: the label maps are parsed with eval(); see the safer
    # ast.literal_eval sketch after this example.
    with open(opts.NUM_LABEL_R) as fr_r:
        w2n = eval(fr_r.read())

    with open(opts.NUM_LABEL) as fr:
        n2w = eval(fr.read())

    train_loader, val_loader, test_loader = None, None, None
    if args.mode != 'test':
        lip_dict, video_list = opts.file_deal(opts.TRAIN_DATA, w2n)
        train_num = int(len(video_list) * 0.95)
        train_loader = torch.utils.data.DataLoader(TSNDataSet(
            opts.TRAIN_DATA,
            args.mode,
            num_segments=args.num_segments,
            img_size=args.img_size,
            lip_dict=lip_dict,
            video_list=video_list[:train_num],
            transform=torchvision.transforms.Compose([
                train_augmentation,
                GroupMultiScaleCrop(args.img_size, [1, .875, .75, .66]),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            opts.TRAIN_DATA,
            args.mode,
            num_segments=args.num_segments,
            img_size=args.img_size,
            lip_dict=lip_dict,
            video_list=video_list[train_num:],
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
    else:
        lip_dict, video_list = opts.file_deal(opts.TEST_DATA, w2n)
        test_loader = torch.utils.data.DataLoader(TSNDataSet_infer(
            opts.TEST_DATA,
            num_segments=args.num_segments,
            img_size=args.img_size,
            lip_dict=lip_dict,
            video_list=video_list,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.mode == 'test':
        if args.sub == 'sub':
            inference(test_loader, model, n2w)
        else:
            inferencefusion(test_loader, model, n2w)
        return

    # start training
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, n2w)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
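
Example #6 reads its label maps with eval(); a safer, behavior-equivalent
sketch using the standard library (assuming the files contain plain dict
literals):

    import ast

    with open(opts.NUM_LABEL_R) as fr_r:
        w2n = ast.literal_eval(fr_r.read())
    with open(opts.NUM_LABEL) as fr:
        n2w = ast.literal_eval(fr.read())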
Example #7
def v_train(train_loader, val_loader, model, num_class, vnet, criterion,
            valcriterion, optimizer, epoch, log, tf_writer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    val_loader_iter = iter(val_loader)

    if args.no_partialbn:
        model.module.partialBN(False)
    else:
        model.module.partialBN(True)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda()
        # Variable is a no-op wrapper in modern PyTorch; plain tensors suffice
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        vnet_temp = VNet(1, 100, 1).cuda()
        optimizer_vnet_temp = torch.optim.Adam(vnet_temp.params(),
                                               1e-3,
                                               weight_decay=1e-4)
        vnet_temp.load_state_dict(vnet.state_dict())

        v_model = v_TSN(
            num_class,
            args.num_segments,
            args.modality,
            base_model=args.arch,
            consensus_type=args.consensus_type,
            dropout=args.dropout,
            img_feature_dim=args.img_feature_dim,
            partial_bn=not args.no_partialbn,
            pretrain=args.pretrain,
            is_shift=args.shift,
            shift_div=args.shift_div,
            shift_place=args.shift_place,
            fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
            temporal_pool=args.temporal_pool,
            non_local=args.non_local,
            print_spec=False)
        v_model = torch.nn.DataParallel(v_model, device_ids=args.gpus).cuda()
        if args.temporal_pool and not args.resume:
            make_temporal_pool(v_model.module.base_model, args.num_segments)
        v_model.load_state_dict(model.state_dict())

        # compute output
        output = v_model(input_var)
        # loss = criterion(output, target_var)
        cost = criterion(output, target_var)
        cost_v = torch.reshape(cost, (-1, 1))
        v_lambda = vnet_temp(cost_v.data)
        l_f_v = torch.sum(cost_v * v_lambda) / len(cost_v)
        v_model.zero_grad()
        grads = torch.autograd.grad(l_f_v, (v_model.module.params()),
                                    create_graph=True)
        # to be modified
        v_lr = args.lr * ((0.1**int(epoch >= 80)) * (0.1**int(epoch >= 100)))
        v_model.module.update_params(lr_inner=v_lr, source_params=grads)
        del grads

        # phase 2. pixel weights step
        try:
            inputs_val, targets_val = next(val_loader_iter)  # fetch one val-set batch
        except StopIteration:
            val_loader_iter = iter(val_loader)
            inputs_val, targets_val = next(val_loader_iter)
        # inputs_val, targets_val = sample_val['image'], sample_val['label']
        inputs_val, targets_val = inputs_val.cuda(), targets_val.cuda()
        y_g_hat = v_model(inputs_val)
        l_g_meta = valcriterion(y_g_hat, targets_val)  # val loss
        optimizer_vnet_temp.zero_grad()
        l_g_meta.backward()
        optimizer_vnet_temp.step()
        vnet.load_state_dict(vnet_temp.state_dict())

        # phase 1. network weight step (w)
        output = model(input_var)
        cost = criterion(output, target)
        cost_v = torch.reshape(cost, (-1, 1))
        with torch.no_grad():
            v_new = vnet(cost_v)
        loss = torch.sum(cost_v * v_new) / len(cost_v)

        # compute gradient and do SGD step; clip between backward() and step()
        # so the clipping actually takes effect
        optimizer.zero_grad()
        loss.backward()
        if args.clip_gradient is not None:
            total_norm = clip_grad_norm_(model.parameters(),
                                         args.clip_gradient)
        optimizer.step()

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch,
                          i,
                          len(train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          top1=top1,
                          top5=top5,
                          lr=optimizer.param_groups[-1]['lr'] * 0.1))  # TODO: presumably undoes the fc group's 10x lr multiplier for display
            print(output, end=" ")
            # log one vnet parameter so the meta update's progress is visible
            for n, p in vnet.named_params(vnet):
                print("vnet param: ", n, p[0].item())
                break
            log.write(output + '\n')
            log.flush()

    tf_writer.add_scalar('loss/train', losses.avg, epoch)
    tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
    tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
    tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
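Note: VNet is constructed above as VNet(1, 100, 1), but its definition is not part of this snippet. Below is a minimal sketch of the loss-to-weight MLP in the spirit of Meta-Weight-Net; the class name VNetSketch is hypothetical, and the original almost certainly subclasses a MetaModule base that supplies the params()/named_params()/update_params() hooks used in the loop above.

import torch
import torch.nn as nn

class VNetSketch(nn.Module):
    """Hedged sketch: maps a per-sample loss to a weight in (0, 1)."""

    def __init__(self, input_dim=1, hidden_dim=100, output_dim=1):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x has shape (batch, 1): the detached per-sample losses (cost_v.data above)
        h = torch.relu(self.linear1(x))
        return torch.sigmoid(self.linear2(h))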
Example #8
0
    def __init__(self, weightPath, segments, crops, fullSize=True):
        self.weightPath = weightPath
        self.segments = segments
        self.crops = crops
        self.fullSize = fullSize

        self.is_shift, shift_div, shift_place = parse_shift_option_from_log_name(self.weightPath)
        if 'RGB' in self.weightPath:
            self.modality = 'RGB'
        else:
            self.modality = 'Flow'
        this_arch = self.weightPath.split('TSM_')[1].split('_')[2]

        self.num_class = 400
        print('=> shift: {}, shift_div: {}, shift_place: {}'.format(self.is_shift, shift_div, shift_place))
        self.net = TSN(self.num_class, self.segments if self.is_shift else 1, self.modality,
                       base_model=this_arch,
                       consensus_type="avg",
                       img_feature_dim=256,
                       pretrain="imagenet",
                       is_shift=self.is_shift, shift_div=shift_div, shift_place=shift_place,
                       non_local='_nl' in self.weightPath,
                       )

        if 'tpool' in self.weightPath:
            from ops.temporal_shift import make_temporal_pool
            make_temporal_pool(self.net.base_model, self.segments)  # since DataParallel

        checkpoint = torch.load(self.weightPath)
        checkpoint = checkpoint['state_dict']

        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias',
                        }

        for k, v in replace_dict.items():
            if k in base_dict:
                base_dict[v] = base_dict.pop(k)

        self.net.load_state_dict(base_dict)

        input_size = self.net.scale_size if self.fullSize else self.net.input_size
        if self.crops == 1:
            cropping = torchvision.transforms.Compose([
                GroupScale(self.net.scale_size),
                GroupCenterCrop(input_size),
            ])
        elif self.crops == 3:  # full-resolution sampling: 3 crops, no flip
            cropping = torchvision.transforms.Compose([
                GroupFullResSample(input_size, self.net.scale_size, flip=False)
            ])
        elif self.crops == 5:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, self.net.scale_size, flip=False)
            ])
        elif self.crops == 10:
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, self.net.scale_size)
            ])
        else:
            raise ValueError("Only 1, 3, 5, 10 crops are supported but got {}".format(self.crops))

        self.transform = torchvision.transforms.Compose([
            cropping,
            Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
            GroupNormalize(self.net.input_mean, self.net.input_std)])

        if self.modality == 'RGB':
            self.length = 3
        elif self.modality == 'Flow':
            self.length = 10
        elif self.modality == 'RGBDiff':
            self.length = 18

        self.net = self.net.cuda()
        self.net.eval()
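The class this __init__ belongs to is not named in the snippet. Assuming it were called TSMWrapper, usage might look like the sketch below; the checkpoint path and the load_frames() helper are hypothetical stand-ins.

# Hedged usage sketch: TSMWrapper, the path, and load_frames() are assumptions.
wrapper = TSMWrapper(
    weightPath='checkpoints/TSM_kinetics_RGB_resnet50_shift8_blockres.pth',
    segments=8, crops=1, fullSize=False)

frames = load_frames('video_frames/')    # one PIL image per sampled segment
clip = wrapper.transform(frames)         # the Stack/ToTensor/Normalize pipeline from __init__
with torch.no_grad():
    scores = wrapper.net(clip.unsqueeze(0).cuda())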
Example #9
0
def main():
    global args, best_prec1
    args = parser.parse_args()

    # dataset_config.return_dataset is bypassed here in favor of a hard-coded movie dataset
    num_class = 21
    args.train_list = "/home/jzwang/code/Video_3D/movienet/data/movie/movie_train.txt"
    args.val_list = "/home/jzwang/code/Video_3D/movienet/data/movie/movie_test.txt"
    args.root_path = ""
    prefix = "frame_{:04d}.jpg"
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type, 'segment%d' % args.num_segments,
         'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    #check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            #best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSetMovie("", args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="frame_{:04d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSetMovie("", args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="frame_{:04d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=int(args.batch_size/2), shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.BCEWithLogitsLoss().cuda()
    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    zero_time = time.time()
    best_map = 0
    print('Start training...')
    for epoch in range(args.start_epoch, args.epochs):
        valloss, mAP, wAP, output_mtx = validate(val_loader, model, criterion)
        adjust_learning_rate(optimizer, epoch, args.lr_steps)
        np.save("testnew.npy", output_mtx)
        print("saved validation predictions to testnew.npy")
        # train for one epoch
        start_time = time.time()
        trainloss = train(train_loader, model, criterion, optimizer, epoch)

        print('Training loss %.4f Epoch %d' % (trainloss, epoch))
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            valloss, mAP, wAP, output_mtx = validate(val_loader, model, criterion)
            end_time = time.time()
            epoch_time = end_time - start_time
            total_time = end_time - zero_time
            print('Total time used: %s Epoch %d time used: %s' % (
                str(datetime.timedelta(seconds=int(total_time))),
                epoch, str(datetime.timedelta(seconds=int(epoch_time)))))
            print('Train loss: {0:.4f} val loss: {1:.4f} mAP: {2:.4f} wAP: {3:.4f}'.format(
                trainloss, valloss, mAP, wAP))
            # evaluate on validation set and track the best mAP so far
            is_best = mAP > best_map
            if is_best:
                best_map = mAP
            checkpoint_name = "best_checkpoint.pth.tar"
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, is_best, epoch)
            np.save("testnew.npy", output_mtx)
            print("saved validation predictions to testnew.npy")
        with open(args.record_path, 'a') as file:
            file.write('Epoch: [{0}] '
                       'Train loss: {1:.4f} val loss: {2:.4f} mAP: {3:.4f}\n'.format(
                           epoch + 1, trainloss, valloss, mAP))


    print('************ Done!... ************')
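validate() and the mAP/wAP computation are not shown in this example. Purely as a sketch of what the metric could look like, assuming output_mtx holds per-clip scores of shape [num_clips, 21] and a binary label matrix of the same shape (sklearn here is an assumption, not necessarily what the original uses):

import numpy as np
from sklearn.metrics import average_precision_score

def mean_average_precision(labels, scores):
    # labels, scores: arrays of shape [num_clips, num_classes]
    aps = [average_precision_score(labels[:, c], scores[:, c])
           for c in range(labels.shape[1])
           if labels[:, c].any()]  # skip classes with no positive clips
    return float(np.mean(aps))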
Example #10
0
def main():
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                      args.modality)
    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local,
                cca3d=args.cca3d)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = model.cuda()

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            # the model is not wrapped in DataParallel here, so no .module indirection
            make_temporal_pool(model.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            parallel_state_dict = checkpoint['state_dict']
            # strip the 'module.' prefix that DataParallel added when the checkpoint was saved
            cpu_state_dict = {}
            for k, v in parallel_state_dict.items():
                cpu_state_dict[k[len('module.'):]] = v
            model.load_state_dict(cpu_state_dict)
            print(("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    preprocess = torchvision.transforms.Compose([
        GroupScale(int(scale_size)),
        GroupCenterCrop(crop_size),
        Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
        normalize,
    ])
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=preprocess,
                   dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)
    norm_param = (input_mean, input_std)
    cam_process(val_loader, model, norm_param)
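The loop above strips the 'module.' prefix that torch.nn.DataParallel prepends when a checkpoint is saved from a wrapped model. A small reusable helper, sketched here under the same assumption, makes that intent explicit:

def strip_module_prefix(state_dict, prefix='module.'):
    # Hedged sketch: return a copy with DataParallel's key prefix removed.
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}

# e.g. model.load_state_dict(strip_module_prefix(checkpoint['state_dict']))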
Example #11
0
def load_model(weights):
    global num_class
    is_shift, shift_div, shift_place = parse_shift_option_from_log_name(
        weights)
    if 'RGB' in weights:
        modality = 'RGB'
    elif 'Depth' in weights:
        modality = 'Depth'
    else:
        modality = 'Flow'

    if 'concatAll' in weights:
        concat = "All"
    elif "concatFirst" in weights:
        concat = "First"
    else:
        concat = ""

    # default to False so the flag is always defined before the TSN call below
    extra_temporal_modeling = False
    if 'extra' in weights:
        extra_temporal_modeling = True

    args.prune = ""

    if 'conv1d' in weights:
        args.crop_fusion_type = "conv1d"
    else:
        args.crop_fusion_type = "avg"

    this_arch = weights.split('TSM_')[1].split('_')[2]
    modality_list.append(modality)
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(
        args.dataset, modality)

    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(
        is_shift, shift_div, shift_place))
    net = TSN(num_class,
              int(args.test_segments) if is_shift else 1,
              modality,
              base_model=this_arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              pretrain=args.pretrain,
              is_shift=is_shift,
              shift_div=shift_div,
              shift_place=shift_place,
              non_local='_nl' in weights,
              concat=concat,
              extra_temporal_modeling=extra_temporal_modeling,
              prune_list=[prune_conv1in_list, prune_conv1out_list],
              is_prune=args.prune)

    if 'tpool' in weights:
        from ops.temporal_shift import make_temporal_pool
        make_temporal_pool(net.base_model,
                           args.test_segments)  # since DataParallel

    checkpoint = torch.load(weights)
    checkpoint = checkpoint['state_dict']

    # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in list(checkpoint.items())
    }
    replace_dict = {
        'base_model.classifier.weight': 'new_fc.weight',
        'base_model.classifier.bias': 'new_fc.bias',
    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)

    input_size = net.scale_size if args.full_res else net.input_size
    if args.test_crops == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif args.test_crops == 3:  # full-resolution sampling: 3 crops, no flip
        cropping = torchvision.transforms.Compose(
            [GroupFullResSample(input_size, net.scale_size, flip=False)])
    elif args.test_crops == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose(
            [GroupOverSample(input_size, net.scale_size, flip=False)])
    elif args.test_crops == 10:
        cropping = torchvision.transforms.Compose(
            [GroupOverSample(input_size, net.scale_size)])
    else:
        raise ValueError(
            "Only 1, 3, 5, 10 crops are supported but got {}".format(
                args.test_crops))

    transform = torchvision.transforms.Compose([
        cropping,
        Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(
            div=(this_arch not in ['BNInception', 'InceptionV3'])),
        GroupNormalize(net.input_mean, net.input_std),
    ])

    if args.gpus is not None:
        devices = [args.gpus[i] for i in range(args.workers)]
    else:
        devices = list(range(args.workers))

    net = torch.nn.DataParallel(net.cuda())
    return is_shift, net, transform
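A rough usage sketch of load_model(); the checkpoint filename is made up, but it must follow the TSM_<dataset>_<modality>_<arch>_... naming convention that the parsing above relies on (modality and architecture are read straight out of the filename):

# Hedged usage sketch; the checkpoint path is hypothetical.
is_shift, net, transform = load_model(
    'checkpoints/TSM_somethingv2_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth')
net.eval()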