Example #1
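# A minimal import sketch so these snippets are self-contained. The caffe2
# imports below are standard; everything else referenced in the snippets
# (model_builder, model_helper, model_loader, reader_utils, metric, and the
# helpers PredictionAggregation, AddMomentumParameterUpdate,
# GetCheckpointParams, RunEpoch, SaveModel) is assumed to come from the
# surrounding project and is not shown here.
import logging
import math
import os
import pickle

import h5py
import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import cnn, core, data_parallel_model, experiment_util, \
    workspace

log = logging.getLogger(__name__)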
def Test(args):
    assert args.batch_size == 1  # large-scale testing assumes a batch size of one
    if args.gpus is not None:
        gpus = [int(x) for x in args.gpus.split(',')]
        num_gpus = len(gpus)
    else:
        gpus = range(args.num_gpus)
        num_gpus = args.num_gpus

    if num_gpus > 0:
        total_batch_size = args.batch_size * num_gpus
        log.info("Running on GPUs: {}".format(gpus))
        log.info("total_batch_size: {}".format(total_batch_size))
    else:
        total_batch_size = args.batch_size
        log.info("Running on CPU")
        log.info("total_batch_size: {}".format(total_batch_size))

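    # Reader/decoder configuration. Judging from the get_* flags below,
    # input_type appears to encode the modality mix (0 = RGB only, 1 = optical
    # flow only, larger values combine modalities); this is an inference from
    # the snippet, not documented behavior.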
    video_input_args = dict(
        batch_size=args.batch_size,
        clip_per_video=args.clip_per_video,
        decode_type=1,
        length_rgb=args.clip_length_rgb,
        sampling_rate_rgb=args.sampling_rate_rgb,
        scale_h=args.scale_h,
        scale_w=args.scale_w,
        crop_size=args.crop_size,
        video_res_type=args.video_res_type,
        short_edge=min(args.scale_h, args.scale_w),
        num_decode_threads=args.num_decode_threads,
        do_multi_label=args.multi_label,
        num_of_class=args.num_labels,
        random_mirror=False,
        random_crop=False,
        input_type=args.input_type,
        length_of=args.clip_length_of,
        sampling_rate_of=args.sampling_rate_of,
        frame_gap_of=args.frame_gap_of,
        do_flow_aggregation=args.do_flow_aggregation,
        flow_data_type=args.flow_data_type,
        get_rgb=(args.input_type == 0 or args.input_type >= 3),
        get_optical_flow=(args.input_type == 1 or args.input_type >= 4),
        use_local_file=args.use_local_file,
        crop_per_clip=args.crop_per_clip,
    )

    reader_args = dict(
        name="test_reader",
        input_data=args.test_data,
    )

    # Model building functions
    def create_model_ops(model, loss_scale):
        return model_builder.build_model(
            model=model,
            model_name=args.model_name,
            model_depth=args.model_depth,
            num_labels=args.num_labels,
            batch_size=args.batch_size * args.clip_per_video,
            num_channels=args.num_channels,
            crop_size=args.crop_size,
            clip_length=(args.clip_length_of
                         if args.input_type == 1 else args.clip_length_rgb),
            loss_scale=loss_scale,
            is_test=1,
            pred_layer_name=args.pred_layer_name,
            multi_label=args.multi_label,
            channel_multiplier=args.channel_multiplier,
            bottleneck_multiplier=args.bottleneck_multiplier,
            use_dropout=args.use_dropout,
            conv1_temporal_stride=args.conv1_temporal_stride,
            conv1_temporal_kernel=args.conv1_temporal_kernel,
            use_convolutional_pred=args.use_convolutional_pred,
            use_pool1=args.use_pool1,
        )

    def empty_function(model, loss_scale=1):
        # intentionally a no-op: this builder role is handled by the other net
        return

    test_data_loader = cnn.CNNModelHelper(
        order="NCHW",
        name="data_loader",
    )
    test_model = cnn.CNNModelHelper(
        order="NCHW",
        name="video_model",
        use_cudnn=(args.use_cudnn == 1),
        cudnn_exhaustive_search=True,
    )

    test_reader, number_of_examples = reader_utils.create_data_reader(
        test_data_loader, **reader_args)

    if args.num_iter <= 0:
        num_iter = int(math.ceil(number_of_examples / total_batch_size))
    else:
        num_iter = args.num_iter

    def test_input_fn(model):
        model_helper.AddVideoInput(test_data_loader, test_reader,
                                   **video_input_args)

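    # Two nets are parallelized separately: the data-loader net decodes whole
    # videos (all clips and crops at once), while the model net runs on the
    # smaller slices fed back in below. Splitting them lets one decoded video
    # be evaluated in several inference-sized chunks.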
    if num_gpus > 0:
        data_parallel_model.Parallelize_GPU(
            test_data_loader,
            input_builder_fun=test_input_fn,
            forward_pass_builder_fun=empty_function,
            param_update_builder_fun=None,
            devices=gpus,
            optimize_gradient_memory=True,
        )
        data_parallel_model.Parallelize_GPU(
            test_model,
            input_builder_fun=empty_function,
            forward_pass_builder_fun=create_model_ops,
            param_update_builder_fun=None,
            devices=gpus,
            optimize_gradient_memory=True,
        )
    else:
        test_model._device_type = caffe2_pb2.CPU
        test_model._devices = [0]
        device_opt = core.DeviceOption(test_model._device_type, 0)
        with core.DeviceScope(device_opt):
            # Because our loaded models are named with "gpu_x",
            # keep the naming for now.
            # TODO: Save model using `data_parallel_model.ExtractPredictorNet`
            # to extract the model for "gpu_0". It also renames
            # the input and output blobs by stripping the "gpu_x/" prefix
            with core.NameScope("{}_{}".format("gpu", 0)):
                test_input_fn(test_data_loader)
                create_model_ops(test_model, 1.0)

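    # Materialize both nets: run each param-init net once, then register the
    # run-time nets so RunNet can execute them by name.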
    workspace.RunNetOnce(test_data_loader.param_init_net)
    workspace.CreateNet(test_data_loader.net)
    workspace.RunNetOnce(test_model.param_init_net)
    workspace.CreateNet(test_model.net)

    if args.db_type == 'minidb':
        if num_gpus > 0:
            model_helper.LoadModel(args.load_model_path, args.db_type)
        else:
            with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 0)):
                model_helper.LoadModel(args.load_model_path, args.db_type)
    elif args.db_type == 'pickle':
        if num_gpus > 0:
            model_loader.LoadModelFromPickleFile(test_model,
                                                 args.load_model_path,
                                                 use_gpu=True,
                                                 root_gpu_id=gpus[0])
        else:
            model_loader.LoadModelFromPickleFile(test_model,
                                                 args.load_model_path,
                                                 use_gpu=False)
    else:
        log.warning("Unsupported db_type: {}".format(args.db_type))

    data_parallel_model.FinalizeAfterCheckpoint(test_model)

    # metric counters for multi-label classification
    all_prob_for_map = np.empty(shape=[0, args.num_labels], dtype=np.float64)
    all_label_for_map = np.empty(shape=[0, args.num_labels], dtype=np.int32)

    # metric counters for closed-world classification
    clip_acc = 0
    video_top1 = 0
    video_topk = 0
    video_count = 0
    clip_count = 0

    num_devices = 1  # default for cpu
    if num_gpus > 0:
        num_devices = num_gpus
    else:
        gpus = [0]  # the CPU path still names its blobs under the "gpu_0" scope
    # actual batch size fed to the model per inference run
    inference_batch_size = args.crop_per_inference
    num_crop_per_bag = args.clip_per_video * args.crop_per_clip
    # the crops per video must split evenly into inference-sized slices
    assert num_crop_per_bag % inference_batch_size == 0
    num_slice = num_crop_per_bag // inference_batch_size
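    # Worked example (hypothetical numbers): clip_per_video=10 and
    # crop_per_clip=3 give a bag of 30 crops per video; with
    # crop_per_inference=6 that bag is evaluated in num_slice=5 passes.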

    for i in range(num_iter):
        # load one batch of data; with batch_size == 1 this is one video,
        # which is (#clips x #crops) x 3 x crop_size x crop_size
        workspace.RunNet(test_data_loader.net.Proto().name)

        # get all data into a list, each per device (gpu)
        video_data = []
        label_data = []
        all_predicts = []
        for g in range(num_devices):
            data = workspace.FetchBlob("gpu_{}".format(gpus[g]) + '/data')
            video_data.append(data)
            label = workspace.FetchBlob("gpu_{}".format(gpus[g]) + '/label')
            label_data.append(label)
            all_predicts.append([])

        for slice_idx in range(num_slice):  # 'slice_idx' avoids shadowing the builtin
            for g in range(num_devices):
                begin = slice_idx * inference_batch_size
                end = (slice_idx + 1) * inference_batch_size
                data = video_data[g][begin:end]
                if args.multi_label:
                    label = label_data[g][begin:end, :]
                else:
                    label = label_data[g][begin:end]
                workspace.FeedBlob("gpu_{}".format(gpus[g]) + '/data', data)
                workspace.FeedBlob("gpu_{}".format(gpus[g]) + '/label', label)

            # do one iteration of inference over one slice across devices
            workspace.RunNet(test_model.net.Proto().name)

            for g in range(num_devices):
                # get predictions
                if args.multi_label:
                    predicts = workspace.FetchBlob("gpu_{}".format(gpus[g]) +
                                                   '/prob')
                else:
                    predicts = workspace.FetchBlob("gpu_{}".format(gpus[g]) +
                                                   '/softmax')

                assert predicts.shape[0] == inference_batch_size

                # accumulate predictions
                if len(all_predicts[g]) == 0:
                    all_predicts[g] = predicts
                else:
                    all_predicts[g] = np.concatenate(
                        (all_predicts[g], predicts), axis=0)

        for g in range(num_devices):
            # get clip accuracy
            predicts = all_predicts[g]
            if args.multi_label:
                sample_label = label_data[g][0, :]
            else:
                sample_label = label_data[g][0]
            for k in range(num_crop_per_bag):
                sorted_preds = np.argsort(predicts[k, :])
                sorted_preds[:] = sorted_preds[::-1]
                if sorted_preds[0] == sample_label:
                    clip_acc = clip_acc + 1

            # since batch_size == 1
            all_clips = predicts
            # aggregate predictions into one
            video_pred = PredictionAggregation(all_clips, args.aggregation)
            if args.multi_label:
                video_pred = np.expand_dims(video_pred, axis=0)
                sample_label = np.expand_dims(sample_label, axis=0)
                all_prob_for_map = np.concatenate(
                    (all_prob_for_map, video_pred), axis=0)
                all_label_for_map = np.concatenate(
                    (all_label_for_map, sample_label), axis=0)
            else:
                sorted_video_pred = np.argsort(video_pred)
                sorted_video_pred[:] = sorted_video_pred[::-1]
                if sorted_video_pred[0] == sample_label:
                    video_top1 = video_top1 + 1
                if sample_label in sorted_video_pred[0:args.top_k]:
                    video_topk = video_topk + 1

        video_count = video_count + num_devices
        clip_count = clip_count + num_devices * num_crop_per_bag

        if i > 0 and i % args.display_iter == 0:
            if args.multi_label:
                # mAP
                auc, ap, wap, aps = metric.mean_ap_metric(
                    all_prob_for_map, all_label_for_map)
                log.info(
                    'Iter {}/{}: mAUC: {}, mAP: {}, mWAP: {}, mAP_all: {}'.
                    format(i, num_iter, auc, ap, wap, np.mean(aps)))
            else:
                # accuracy
                log.info('Iter {}/{}: clip: {}, top1: {}, top{}: {}'.format(
                    i, num_iter, clip_acc / clip_count,
                    video_top1 / video_count, args.top_k,
                    video_topk / video_count))

    if args.multi_label:
        # mAP
        auc, ap, wap, aps = metric.mean_ap_metric(all_prob_for_map,
                                                  all_label_for_map)
        log.info("Test mAUC: {}, mAP: {}, mWAP: {}, mAP_all: {}".format(
            auc, ap, wap, np.mean(aps)))
        if args.print_per_class_metrics:
            log.info("Test mAP per class: {}".format(aps))
    else:
        # accuracy
        log.info("Test accuracy: clip: {}, top1: {}, top{}: {}".format(
            clip_acc / clip_count, video_top1 / video_count, args.top_k,
            video_topk / video_count))

    if num_gpus > 0:
        flops, params, inters = model_helper.GetFlopsAndParams(
            test_model, gpus[0])
    else:
        flops, params, inters = model_helper.GetFlopsAndParams(test_model)
    log.info('FLOPs: {}, params: {}, inters: {}'.format(flops, params, inters))
Example #2
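# Unlike Test above, this training path assumes at least one GPU: the batch
# math multiplies by num_gpus and Parallelize_GPU is called unconditionally.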
def Train(args):
    if args.gpus is not None:
        gpus = [int(x) for x in args.gpus.split(',')]
        num_gpus = len(gpus)
    else:
        gpus = range(args.num_gpus)
        num_gpus = args.num_gpus

    log.info("Running on GPUs: {}".format(gpus))

    # Modify to make it consistent with the distributed trainer
    total_batch_size = args.batch_size * num_gpus
    batch_per_device = args.batch_size

    # Round down epoch size to closest multiple of batch size across machines
    epoch_iters = int(args.epoch_size / total_batch_size)
    args.epoch_size = epoch_iters * total_batch_size
    log.info("Using epoch size: {}".format(args.epoch_size))

    # Create CNNModelHelper object
    train_model = cnn.CNNModelHelper(
        order="NCHW",
        name='{}_train'.format(args.model_name),
        use_cudnn=(args.use_cudnn == 1),
        cudnn_exhaustive_search=True,
        ws_nbytes_limit=(args.cudnn_workspace_limit_mb * 1024 * 1024),
    )

    # Model building functions
    def create_model_ops(model, loss_scale):
        return model_builder.build_model(
            model=model,
            model_name=args.model_name,
            model_depth=args.model_depth,
            num_labels=args.num_labels,
            batch_size=args.batch_size,
            num_channels=args.num_channels,
            crop_size=args.crop_size,
            clip_length=(args.clip_length_of
                         if args.input_type == 1 else args.clip_length_rgb),
            loss_scale=loss_scale,
            pred_layer_name=args.pred_layer_name,
            multi_label=args.multi_label,
            channel_multiplier=args.channel_multiplier,
            bottleneck_multiplier=args.bottleneck_multiplier,
            use_dropout=args.use_dropout,
            conv1_temporal_stride=args.conv1_temporal_stride,
            conv1_temporal_kernel=args.conv1_temporal_kernel,
            use_pool1=args.use_pool1,
            audio_input_3d=args.audio_input_3d,
            g_blend=args.g_blend,
            audio_weight=args.audio_weight,
            visual_weight=args.visual_weight,
            av_weight=args.av_weight,
        )

    # SGD
    def add_parameter_update_ops(model):
        model.AddWeightDecay(args.weight_decay)
        ITER = model.Iter("ITER")
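        # Convert the LR-decay interval from epochs to iterations: each
        # iteration consumes batch_size * num_gpus examples, so this equals
        # step_epoch * epoch_iters. base_lr below is scaled linearly with the
        # GPU count (the usual linear-scaling heuristic).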
        stepsz = args.step_epoch * args.epoch_size / args.batch_size / num_gpus
        LR = model.net.LearningRate(
            [ITER],
            "LR",
            base_lr=args.base_learning_rate * num_gpus,
            policy="step",
            stepsize=int(stepsz),
            gamma=args.gamma,
        )
        AddMomentumParameterUpdate(model, LR)

    # Input. Note that the reader must be shared with all GPUs.
    train_reader, train_examples = reader_utils.create_data_reader(
        train_model,
        name="train_reader",
        input_data=args.train_data,
    )
    log.info("Training set has {} examples".format(train_examples))

    def add_video_input(model):
        model_helper.AddVideoInput(
            model,
            train_reader,
            batch_size=batch_per_device,
            length_rgb=args.clip_length_rgb,
            clip_per_video=1,
            random_mirror=True,
            decode_type=0,
            sampling_rate_rgb=args.sampling_rate_rgb,
            scale_h=args.scale_h,
            scale_w=args.scale_w,
            crop_size=args.crop_size,
            video_res_type=args.video_res_type,
            short_edge=min(args.scale_h, args.scale_w),
            num_decode_threads=args.num_decode_threads,
            do_multi_label=args.multi_label,
            num_of_class=args.num_labels,
            random_crop=True,
            input_type=args.input_type,
            length_of=args.clip_length_of,
            sampling_rate_of=args.sampling_rate_of,
            frame_gap_of=args.frame_gap_of,
            do_flow_aggregation=args.do_flow_aggregation,
            flow_data_type=args.flow_data_type,
            get_rgb=(args.input_type == 0 or args.input_type >= 3),
            get_optical_flow=(args.input_type == 1 or args.input_type >= 4),
            get_logmels=(args.input_type >= 2),
            get_video_id=args.get_video_id,
            jitter_scales=[int(n) for n in args.jitter_scales.split(',')],
            use_local_file=args.use_local_file,
        )

    # Create parallelized model
    data_parallel_model.Parallelize_GPU(
        train_model,
        input_builder_fun=add_video_input,
        forward_pass_builder_fun=create_model_ops,
        param_update_builder_fun=add_parameter_update_ops,
        devices=gpus,
        rendezvous=None,
        net_type=('prof_dag' if args.profiling == 1 else 'dag'),
        optimize_gradient_memory=True,
    )

    # Add test model, if specified
    test_model = None
    if args.test_data is not None:
        log.info("----- Create test net -----")
        test_model = cnn.CNNModelHelper(
            order="NCHW",
            name='{}_test'.format(args.model_name),
            use_cudnn=(args.use_cudnn == 1),
            cudnn_exhaustive_search=True)

        test_reader, test_examples = reader_utils.create_data_reader(
            test_model,
            name="test_reader",
            input_data=args.test_data,
        )

        log.info("Testing set has {} examples".format(test_examples))

        def test_input_fn(model):
            model_helper.AddVideoInput(
                model,
                test_reader,
                batch_size=batch_per_device,
                length_rgb=args.clip_length_rgb,
                clip_per_video=1,
                decode_type=0,
                random_mirror=False,
                random_crop=False,
                sampling_rate_rgb=args.sampling_rate_rgb,
                scale_h=args.scale_h,
                scale_w=args.scale_w,
                crop_size=args.crop_size,
                video_res_type=args.video_res_type,
                short_edge=min(args.scale_h, args.scale_w),
                num_decode_threads=args.num_decode_threads,
                do_multi_label=args.multi_label,
                num_of_class=args.num_labels,
                input_type=args.input_type,
                length_of=args.clip_length_of,
                sampling_rate_of=args.sampling_rate_of,
                frame_gap_of=args.frame_gap_of,
                do_flow_aggregation=args.do_flow_aggregation,
                flow_data_type=args.flow_data_type,
                get_rgb=(args.input_type == 0),
                get_optical_flow=(args.input_type == 1),
                get_video_id=args.get_video_id,
                use_local_file=args.use_local_file,
            )

        data_parallel_model.Parallelize_GPU(
            test_model,
            input_builder_fun=test_input_fn,
            forward_pass_builder_fun=create_model_ops,
            param_update_builder_fun=None,
            devices=gpus,
            optimize_gradient_memory=True,
        )
        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net)

    workspace.RunNetOnce(train_model.param_init_net)
    workspace.CreateNet(train_model.net)

    epoch = 0
    # load the pre-trained model and reset epoch
    if args.load_model_path is not None:
        if args.db_type == 'pickle':
            model_loader.LoadModelFromPickleFile(train_model,
                                                 args.load_model_path,
                                                 use_gpu=True,
                                                 root_gpu_id=gpus[0])
        else:
            model_helper.LoadModel(args.load_model_path, args.db_type)
        # Sync the model params
        data_parallel_model.FinalizeAfterCheckpoint(
            train_model,
            GetCheckpointParams(train_model),
        )

        if args.is_checkpoint:
            # reset epoch. load_model_path should end with *_X.mdl,
            # where X is the epoch number
            last_str = args.load_model_path.split('_')[-1]
            if last_str.endswith('.mdl'):
                epoch = int(last_str[:-4])
                log.info("Reset epoch to {}".format(epoch))
            else:
                log.warning("The format of load_model_path doesn't match!")

    expname = "%s_gpu%d_b%d_L%d_lr%.2f" % (
        args.model_name,
        args.num_gpus,
        total_batch_size,
        args.num_labels,
        args.base_learning_rate,
    )
    explog = experiment_util.ModelTrainerLog(expname, args)

    # Run the training one epoch at a time
    while epoch < args.num_epochs:
        epoch = RunEpoch(args, epoch, train_model, test_model,
                         total_batch_size, 1, expname, explog)

        # Save the model for each epoch
        SaveModel(args, train_model, epoch)
Example #3
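# Same setup pattern as Test above, but instead of computing metrics this
# pass dumps the requested intermediate blobs (features) to disk.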
def ExtractFeatures(args):
    if args.gpus is not None:
        gpus = [int(x) for x in args.gpus.split(',')]
        num_gpus = len(gpus)
    else:
        gpus = range(args.num_gpus)
        num_gpus = args.num_gpus

    if num_gpus > 0:
        log.info("Running on GPUs: {}".format(gpus))
    else:
        log.info("Running on CPU")

    my_arg_scope = {
        'order': 'NCHW',
        'use_cudnn': True,
        'cudnn_exhaustive_search': True
    }

    model = cnn.CNNModelHelper(name="Extract Features", **my_arg_scope)

    video_input_args = dict(
        batch_size=args.batch_size,
        clip_per_video=args.clip_per_video,
        decode_type=args.decode_type,
        length_rgb=args.clip_length_rgb,
        sampling_rate_rgb=args.sampling_rate_rgb,
        scale_h=args.scale_h,
        scale_w=args.scale_w,
        crop_size=args.crop_size,
        video_res_type=args.video_res_type,
        short_edge=min(args.scale_h, args.scale_w),
        num_decode_threads=args.num_decode_threads,
        do_multi_label=args.multi_label,
        num_of_class=args.num_labels,
        random_mirror=False,
        random_crop=False,
        input_type=args.input_type,
        length_of=args.clip_length_of,
        sampling_rate_of=args.sampling_rate_of,
        frame_gap_of=args.frame_gap_of,
        do_flow_aggregation=args.do_flow_aggregation,
        flow_data_type=args.flow_data_type,
        get_rgb=(args.input_type == 0),
        get_optical_flow=(args.input_type == 1),
        get_video_id=args.get_video_id,
        get_start_frame=args.get_start_frame,
        use_local_file=args.use_local_file,
        crop_per_clip=args.crop_per_clip,
    )

    reader_args = dict(
        name="extract_features_reader",
        input_data=args.test_data,
    )

    reader, num_examples = reader_utils.create_data_reader(
        model, **reader_args)

    def input_fn(model):
        model_helper.AddVideoInput(model, reader, **video_input_args)

    def create_model_ops(model, loss_scale):
        return model_builder.build_model(
            model=model,
            model_name=args.model_name,
            model_depth=args.model_depth,
            num_labels=args.num_labels,
            batch_size=args.batch_size,
            num_channels=args.num_channels,
            crop_size=args.crop_size,
            clip_length=(args.clip_length_of
                         if args.input_type == 1 else args.clip_length_rgb),
            loss_scale=loss_scale,
            is_test=1,
            multi_label=args.multi_label,
            channel_multiplier=args.channel_multiplier,
            bottleneck_multiplier=args.bottleneck_multiplier,
            use_dropout=args.use_dropout,
            use_convolutional_pred=args.use_convolutional_pred,
            use_pool1=args.use_pool1,
        )

    if num_gpus > 0:
        data_parallel_model.Parallelize_GPU(
            model,
            input_builder_fun=input_fn,
            forward_pass_builder_fun=create_model_ops,
            param_update_builder_fun=None,  # 'None' since we aren't training
            devices=gpus,
            optimize_gradient_memory=True,
        )
    else:
        model._device_type = caffe2_pb2.CPU
        model._devices = [0]
        device_opt = core.DeviceOption(model._device_type, 0)
        with core.DeviceScope(device_opt):
            # Because our loaded models are named with "gpu_x", keep the naming for now.
            # TODO: Save model using `data_parallel_model.ExtractPredictorNet`
            # to extract the model for "gpu_0". It also renames
            # the input and output blobs by stripping the "gpu_x/" prefix
            with core.NameScope("{}_{}".format("gpu", 0)):
                input_fn(model)
                create_model_ops(model, 1.0)

    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)

    if args.db_type == 'pickle':
        model_loader.LoadModelFromPickleFile(model, args.load_model_path)
    elif args.db_type == 'minidb':
        if num_gpus > 0:
            model_helper.LoadModel(args.load_model_path, args.db_type)
        else:
            with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 0)):
                model_helper.LoadModel(args.load_model_path, args.db_type)
    else:
        log.warning("Unsupported db_type: {}".format(args.db_type))

    data_parallel_model.FinalizeAfterCheckpoint(model)

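    # Run the net num_iterations times and collect every requested output
    # blob from every device, keyed by blob name.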
    def fetchActivations(model, outputs, num_iterations):
        all_activations = {}
        for counter in range(num_iterations):
            workspace.RunNet(model.net.Proto().name)

            num_devices = 1  # default for cpu
            if num_gpus > 0:
                num_devices = num_gpus
            for g in range(num_devices):
                # map the device index to the actual GPU id used by the name
                # scope; the CPU path also names its blobs under "gpu_0"
                device_id = gpus[g] if num_gpus > 0 else 0
                for output_name in outputs:
                    blob_name = 'gpu_{}/'.format(device_id) + output_name
                    activations = workspace.FetchBlob(blob_name)
                    if output_name not in all_activations:
                        all_activations[output_name] = []
                    all_activations[output_name].append(activations)

            if counter % 20 == 0:
                log.info('{}/{} iterations'.format(counter, num_iterations))

        # each key holds a list of activations obtained from each minibatch.
        # we now concatenate these lists to get the final arrays.
        # concatenating during the loop requires a realloc and can get slow.
        for key in all_activations:
            all_activations[key] = np.concatenate(all_activations[key])

        return all_activations

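    # parse the comma-separated list of blob names requested via --features
    # (the exact flag name is an assumption; `args.features` is what the
    # snippet reads)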
    outputs = [name.strip() for name in args.features.split(',')]
    assert len(outputs) > 0

    if args.num_iterations > 0:
        num_iterations = args.num_iterations
    else:
        if num_gpus > 0:
            examples_per_iteration = args.batch_size * num_gpus
        else:
            examples_per_iteration = args.batch_size
        num_iterations = int(num_examples / examples_per_iteration)

    activations = fetchActivations(model, outputs, num_iterations)

    # log the shape of each extracted feature blob
    for index in range(len(outputs)):
        log.info("Read '{}' with shape {}".format(
            outputs[index], activations[outputs[index]].shape))

    # save the extracted features
    if args.output_path:
        output_path = args.output_path
    else:
        output_path = os.path.dirname(args.test_data) + '/features.pickle'

    log.info('Writing to {}'.format(output_path))
    if args.save_h5:
        with h5py.File(output_path, 'w') as handle:
            for name, activation in activations.items():
                handle.create_dataset(name, data=activation)
    else:
        with open(output_path, 'wb') as handle:
            pickle.dump(activations, handle)

    # perform sanity check
    if args.sanity_check == 1:  # check clip accuracy
        assert args.multi_label == 0
        clip_acc = 0
        softmax = activations['softmax']
        label = activations['label']
        for i in range(len(softmax)):
            sorted_preds = np.argsort(softmax[i])
            sorted_preds[:] = sorted_preds[::-1]
            if sorted_preds[0] == label[i]:
                clip_acc += 1
        log.info('Sanity check --- clip accuracy: {}'.format(clip_acc /
                                                             len(softmax)))
    elif args.sanity_check == 2:  # check auc
        assert args.multi_label == 1
        prob = activations['prob']
        label = activations['label']
        mean_auc, mean_ap, mean_wap, _ = metric.mean_ap_metric(prob, label)
        log.info('Sanity check --- AUC: {}, mAP: {}, mWAP: {}'.format(
            mean_auc, mean_ap, mean_wap))