Example #1
import os
import time
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torchvision
from torch.optim.lr_scheduler import MultiStepLR

# NOTE: TBN, TBNDataSet and the Group*/Stack/ToTorchFormatTensor transforms
# are assumed to live in this project's local modules (e.g. models.py,
# dataset.py, transforms.py); `parser`, `train`, `validate`,
# `save_checkpoint`, `training_labels`, `eval_video` and `summaryWriter`
# are defined elsewhere in the original file.
from models import TBN
from dataset import TBNDataSet
from transforms import (GroupCenterCrop, GroupNormalize, GroupOverSample,
                        GroupScale, IdentityTransform, Stack,
                        ToTorchFormatTensor)

def evaluate_model(num_class):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = TBN(num_class,
              1,
              args.modality,
              base_model=args.arch,
              consensus_type=args.crop_fusion_type,
              dropout=args.dropout,
              midfusion=args.midfusion)

    weights = '{weights_dir}/model_best.pth.tar'.format(
        weights_dir=args.weights_dir)
    checkpoint = torch.load(weights, map_location=device)
    print("model epoch {} best prec@1: {}".format(checkpoint['epoch'],
                                                  checkpoint['best_prec1']))

    # DataParallel checkpoints prefix every parameter name with 'module.';
    # strip that prefix so the weights load into the bare model
    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in checkpoint['state_dict'].items()
    }
    net.load_state_dict(base_dict)

    test_transform = {}
    image_tmpl = {}
    for m in args.modality:
        if m != 'Spec':
            if args.test_crops == 1:
                cropping = torchvision.transforms.Compose([
                    GroupScale(net.scale_size[m]),
                    GroupCenterCrop(net.input_size[m]),
                ])
            elif args.test_crops == 10:
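                # 10-crop evaluation: 4 corners + centre crop, each with its
                # horizontal flip (standard TSN-style oversampling)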
                cropping = torchvision.transforms.Compose(
                    [GroupOverSample(net.input_size[m], net.scale_size[m])])
            else:
                raise ValueError("Only 1 and 10 crops are supported, "
                                 "but got {}".format(args.test_crops))

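            # For BNInception the original Caffe weights expect BGR input in
            # [0, 255]; hence Stack(roll=True) and div=False below (assumed
            # from the standard TSN preprocessing conventions)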
            test_transform[m] = torchvision.transforms.Compose([
                cropping,
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                GroupNormalize(net.input_mean[m], net.input_std[m]),
            ])

            # Prepare dictionaries containing image name templates
            # for each modality
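            # e.g. image_tmpl['RGB'].format(30) -> 'img_0000000030.jpg'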
            if m in ['RGB', 'RGBDiff']:
                image_tmpl[m] = "img_{:010d}.jpg"
            elif m == 'Flow':
                image_tmpl[m] = args.flow_prefix + "{}_{:010d}.jpg"
        else:
            test_transform[m] = torchvision.transforms.Compose([
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=False),
            ])

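    # new_length gives the per-modality temporal extent of each sampled
    # snippet (e.g. a stack of consecutive flow frames)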
    data_length = net.new_length

    test_loader = torch.utils.data.DataLoader(TBNDataSet(
        args.dataset,
        pd.read_pickle(args.test_list),
        data_length,
        args.modality,
        image_tmpl,
        visual_path=args.visual_path,
        audio_path=args.audio_path,
        num_segments=args.test_segments,
        mode='test',
        transform=test_transform,
        resampling_rate=args.resampling_rate),
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers * 2)

    net = torch.nn.DataParallel(net, device_ids=args.gpus).to(device)
    with torch.no_grad():
        net.eval()

        results = []
        total_num = len(test_loader.dataset)

        proc_start_time = time.time()
        max_num = args.max_num if args.max_num > 0 else total_num
        for i, (data, label) in enumerate(test_loader):
            if i >= max_num:
                break
            rst = eval_video(data, net, num_class, device)
            if label != -10000:  # label exists
                if 'epic' not in args.dataset:
                    label_ = label.item()
                else:
                    label_ = {k: v.item() for k, v in label.items()}
                results.append((rst, label_))
            else:  # Test set (S1/S2)
                results.append((rst, ))
            cnt_time = time.time() - proc_start_time
            print('video {} done, total {}/{}, average {} sec/video'.format(
                i, i + 1, total_num,
                float(cnt_time) / (i + 1)))

        return results


def main():
    global args, best_prec1, train_list, experiment_dir, best_loss
    args = parser.parse_args()

    if args.dataset == 'ucf101':
        num_class = 101
    elif args.dataset == 'hmdb51':
        num_class = 51
    elif args.dataset == 'kinetics':
        num_class = 400
    elif args.dataset == 'epic':
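        # EPIC-Kitchens: a pair of output heads (125 verb classes, 352 nouns)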
        num_class = (125, 352)
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = TBN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                midfusion=args.midfusion)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    data_length = model.new_length
    # policies = model.get_optim_policies()
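    # get_augmentation() is assumed to return the per-modality training-time
    # augmentation (multi-scale cropping + random horizontal flip in
    # TSN-style codebases)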
    train_augmentation = model.get_augmentation()

    # Resume training from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
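            # Strip the 'module.' prefix added by DataParallel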
            state_dict_new = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                state_dict_new[k.split('.', 1)[1]] = v
            model.load_state_dict(state_dict_new)
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # Load pretrained weights for each stream
    if args.pretrained_flow_weights:
        print('Initialize Flow stream from Kinetics')
        pretrained = os.path.join('pretrained', 'kinetics_tsn_flow.pth.tar')
        state_dict = torch.load(pretrained, map_location=device)
        for k, v in state_dict.items():
            state_dict[k] = torch.squeeze(v, dim=0)
        base_model = getattr(model, 'flow')
        base_model.load_state_dict(state_dict, strict=False)

    # Freeze stream weights (leaves only fusion and classification trainable)
    if args.freeze:
        model.freeze_fn('modalities')

    # Freeze batch normalisation layers except the first
    if args.partialbn:
        model.freeze_fn('partialbn_parameters')

    model = torch.nn.DataParallel(model, device_ids=args.gpus).to(device)

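    # Let cuDNN benchmark and cache the fastest conv algorithms for the
    # (fixed) input sizes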
    cudnn.benchmark = True

    # Data loading code
    normalize = {}
    for m in args.modality:
        if m != 'Spec':
            if m != 'RGBDiff':
                normalize[m] = GroupNormalize(input_mean[m], input_std[m])
            else:
                normalize[m] = IdentityTransform()

    image_tmpl = {}
    train_transform = {}
    val_transform = {}
    for m in args.modality:
        if m != 'Spec':
            # Prepare dictionaries containing image name templates for each modality
            if m in ['RGB', 'RGBDiff']:
                image_tmpl[m] = "img_{:010d}.jpg"
            elif m == 'Flow':
                image_tmpl[m] = args.flow_prefix + "{}_{:010d}.jpg"
            # Prepare train/val dictionaries containing the transformations
            # (augmentation+normalization)
            # for each modality
            train_transform[m] = torchvision.transforms.Compose([
                train_augmentation[m],
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                normalize[m],
            ])

            val_transform[m] = torchvision.transforms.Compose([
                GroupScale(int(scale_size[m])),
                GroupCenterCrop(crop_size[m]),
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=args.arch != 'BNInception'),
                normalize[m],
            ])
        else:
            # Prepare train/val dictionaries containing the transformations
            # (augmentation+normalization)
            # for each modality
            train_transform[m] = torchvision.transforms.Compose([
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=False),
            ])

            val_transform[m] = torchvision.transforms.Compose([
                Stack(roll=args.arch == 'BNInception'),
                ToTorchFormatTensor(div=False),
            ])

    # If no train_list is provided, train on the full (default) training set
    train_labels = (training_labels()
                    if args.train_list is None else args.train_list)
    train_dataset = TBNDataSet(args.dataset,
                               train_labels,
                               data_length,
                               args.modality,
                               image_tmpl,
                               visual_path=args.visual_path,
                               audio_path=args.audio_path,
                               num_segments=args.num_segments,
                               transform=train_transform,
                               fps=args.fps,
                               resampling_rate=args.resampling_rate)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    if args.train_list is not None:
        # A held-out validation split only exists when a custom train_list
        # is given; training on the full training set leaves nothing to
        # validate on
        val_loader = torch.utils.data.DataLoader(TBNDataSet(
            args.dataset,
            args.val_list,
            data_length,
            args.modality,
            image_tmpl,
            visual_path=args.visual_path,
            audio_path=args.audio_path,
            num_segments=args.num_segments,
            mode='val',
            transform=val_transform,
            fps=args.fps,
            resampling_rate=args.resampling_rate),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    if len(args.modality) > 1:
        param_groups = [
            {
                'params':
                filter(lambda p: p.requires_grad,
                       model.module.rgb.parameters())
            },
            {
                'params':
                filter(lambda p: p.requires_grad,
                       model.module.flow.parameters()),
                'lr':
                0.001
            },
            {
                'params':
                filter(lambda p: p.requires_grad,
                       model.module.spec.parameters())
            },
            {
                'params':
                filter(lambda p: p.requires_grad,
                       model.module.fusion_classification_net.parameters())
            },
        ]
    else:
        param_groups = filter(lambda p: p.requires_grad, model.parameters())

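    # Param groups without an explicit 'lr' fall back to the base args.lr
    # passed to SGD; only the Flow stream is pinned to 0.001 above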
    optimizer = torch.optim.SGD(param_groups,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

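    # Decay the learning rate by a factor of 10 at each milestone epoch in
    # args.lr_steps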
    scheduler = MultiStepLR(optimizer, args.lr_steps, gamma=0.1)
    if args.evaluate:
        validate(val_loader, model, criterion, device)
        return
    if args.save_stats:
        if args.dataset != 'epic':
            stats_dict = {
                'train_loss': np.zeros((args.epochs, )),
                'val_loss': np.zeros((args.epochs, )),
                'train_acc': np.zeros((args.epochs, )),
                'val_acc': np.zeros((args.epochs, ))
            }
        else:  # args.dataset == 'epic'
            if args.train_list is not None:
                stats_dict = {
                    'train_loss': np.zeros((args.epochs, )),
                    'train_verb_loss': np.zeros((args.epochs, )),
                    'train_noun_loss': np.zeros((args.epochs, )),
                    'train_acc': np.zeros((args.epochs, )),
                    'train_verb_acc': np.zeros((args.epochs, )),
                    'train_noun_acc': np.zeros((args.epochs, )),
                    'val_loss': np.zeros((args.epochs, )),
                    'val_verb_loss': np.zeros((args.epochs, )),
                    'val_noun_loss': np.zeros((args.epochs, )),
                    'val_acc': np.zeros((args.epochs, )),
                    'val_verb_acc': np.zeros((args.epochs, )),
                    'val_noun_acc': np.zeros((args.epochs, ))
                }
            else:
                stats_dict = {
                    'train_loss': np.zeros((args.epochs, )),
                    'train_verb_loss': np.zeros((args.epochs, )),
                    'train_noun_loss': np.zeros((args.epochs, )),
                    'train_acc': np.zeros((args.epochs, )),
                    'train_verb_acc': np.zeros((args.epochs, )),
                    'train_noun_acc': np.zeros((args.epochs, ))
                }

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        training_metrics = train(train_loader, model, criterion, optimizer,
                                 epoch, device)
        # Step the LR scheduler once per epoch, after the optimizer updates
        # (the required ordering since PyTorch 1.1)
        scheduler.step()
        if args.save_stats:
            for k, v in training_metrics.items():
                stats_dict[k][epoch] = v
        # evaluate on validation set
        if args.train_list is not None:
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                test_metrics = validate(val_loader, model, criterion, device)
                if args.save_stats:
                    for k, v in test_metrics.items():
                        stats_dict[k][epoch] = v
                prec1 = test_metrics['val_acc']
                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best)
        else:  # No validation set
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': training_metrics['train_acc'],
                }, False)

    summaryWriter.close()

    if args.save_stats:
        save_stats_dir = os.path.join('stats', experiment_dir)
        if not os.path.exists(save_stats_dir):
            os.makedirs(save_stats_dir)
        with open(os.path.join(save_stats_dir, 'training_stats.npz'),
                  'wb') as f:
            np.savez(f, **stats_dict)
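

if __name__ == '__main__':
    # Assumed entry point: the original script is presumably invoked
    # directly, e.g. `python train.py --dataset epic --modality RGB Flow Spec ...`
    main()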