Пример #1
0
def explain_model_on_dataset(
    model,
    test_assets_selected_indexs,
    test_dataset_filepath,
    result_store_dir=VmafConfig.file_result_store_path()):
    """Print and plot local explanations of ``model`` for chosen dataset assets.

    Loads the dataset described by the Python file at ``test_dataset_filepath``
    and prints an indexed list of all its assets, so indices can be mapped to
    file names. It then selects the assets at ``test_assets_selected_indexs``,
    in the order given. For those assets it:

    - assembles the model's features using a ``FileSystemResultStore`` rooted
      at ``result_store_dir``;
    - predicts scores;
    - runs a ``LocalExplainer`` (1000 neighbor samples);
    - prints and plots the explanations, then calls ``DisplayConfig.show()``.

    Args:
        model: trained model; this function uses ``model.model_dict
            ['feature_dict']``, ``get_xs_from_results``,
            ``get_ys_from_results`` and ``predict``.
        test_assets_selected_indexs: indices into the dataset's asset list.
        test_dataset_filepath: path of the Python dataset description file.
        result_store_dir: root directory for the result store. The default is
            evaluated once, when this function is defined.
    """
    dataset_module = import_python_file(test_dataset_filepath)
    all_assets = read_dataset(dataset_module)

    listing = [
        "Asset {i}: {name}".format(
            i=idx, name=get_file_name_without_extension(asset.dis_path))
        for idx, asset in enumerate(all_assets)
    ]
    print('\n'.join(listing))
    print("Assets selected for local explanation: {}".format(
        test_assets_selected_indexs))

    store = FileSystemResultStore(result_store_dir)
    chosen_assets = [all_assets[idx] for idx in test_assets_selected_indexs]

    assembler = FeatureAssembler(
        feature_dict=model.model_dict['feature_dict'],
        feature_option_dict=None,
        assets=chosen_assets,
        logger=None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=store,
        optional_dict=None,
        optional_dict2=None,
        parallelize=True,
    )
    assembler.run()
    feature_results = assembler.results

    xs = model.get_xs_from_results(feature_results)
    ys = model.get_ys_from_results(feature_results)
    ys_pred = model.predict(xs)['ys_label_pred']

    explainer = LocalExplainer(neighbor_samples=1000)
    explanations = explainer.explain(model, xs)

    # Printing and plotting share the same per-asset context.
    report_kwargs = dict(assets=chosen_assets, ys=ys, ys_pred=ys_pred)
    explainer.print_explanations(explanations, **report_kwargs)
    explainer.plot_explanations(explanations, **report_kwargs)
    DisplayConfig.show()
Пример #2
0
def main():
    """Command-line entry point: run one subjective model on one dataset.

    Positional arguments: <subjective_model> <dataset_filepath>.

    The model name is resolved with ``SubjectiveModel.find_subclass``.
    ``run_subjective_models`` is called with error-bar plots of raw, quality,
    subject and content scores, and the figures are shown.

    Returns:
        0 on success, 1 if the subjective model name cannot be resolved,
        2 on a usage error.
    """
    if len(sys.argv) < 3:
        print_usage()
        return 2

    # Plain string positionals; the length check above guarantees they
    # exist, so nothing here can raise a conversion error.
    subjective_model = sys.argv[1]
    dataset_filepath = sys.argv[2]

    try:
        subjective_model_class = SubjectiveModel.find_subclass(
            subjective_model)
    except Exception as e:
        print("Error: " + str(e))
        return 1

    # Single-argument print() behaves identically on Python 2 and 3.
    print("Run model {} on dataset {}".format(
        subjective_model_class.__name__,
        get_file_name_with_extension(dataset_filepath)))

    run_subjective_models(
        dataset_filepath=dataset_filepath,
        subjective_model_classes=[
            subjective_model_class,
        ],
        normalize_final=False,  # True or False
        do_plot=[
            'raw_scores',
            'quality_scores',
            'subject_scores',
            'content_scores',
        ],
        plot_type='errorbar',
        gradient_method='simplified',
    )

    DisplayConfig.show()

    return 0
Пример #3
0
def main():
    """Benchmark vmafossexec speed and accuracy across subsample factors.

    For each factor in 1, 2, 4, ..., 128, runs
    ``run_vmafossexec_with_subsample`` on the public NFLX dataset and prints
    the SRCC/PCC/RMSE it reports plus the elapsed time. It then plots
    processing speed (a fixed frame count divided by elapsed time) against the
    subsample factor and shows the figure.
    """
    subsample_factors = [1, 2, 4, 8, 16, 32, 64, 128]
    timings = []
    pcc_values = []  # collected alongside timings; not used in the plot

    for factor in subsample_factors:
        dataset_path = VmafConfig.resource_path('dataset', 'NFLX_dataset_public.py')
        elapsed, srcc, pcc, rmse = run_vmafossexec_with_subsample(dataset_path, factor)
        timings.append(elapsed)
        pcc_values.append(pcc)
        print("SRCC: {}, PCC: {}, RMSE: {}, time: {}".format(srcc, pcc, rmse, elapsed))

    # NOTE(review): presumably the total number of frames processed per run
    # (the y-axis is labelled frames/sec) — TODO confirm the 6*24*79 factors.
    total_frames = 6 * 24 * 79

    fig, ax = plt.subplots(1, 1, figsize=[8, 5])
    speeds = total_frames / np.array(timings)
    ax.plot(subsample_factors, speeds, 'x-')
    ax.set_xlabel("Subsample")
    ax.set_ylabel("Processing Speed (Frms/Sec)")
    ax.grid(True)

    plt.tight_layout()

    DisplayConfig.show()
Пример #4
0
    def test_test_on_dataset_plot_per_content(self):
        """Aggregate + per-content plotting should write three PNG files.

        Runs ``VmafQualityRunner`` with the vmaf_float_v0.6.1 model over the
        sample test dataset with ``do_plot=['aggregate', 'per_content']``.
        It writes the figures to a scratch directory via
        ``DisplayConfig.show(write_to_dir=...)`` and counts the PNGs there.

        The scratch directory is emptied before the run and removed through
        ``addCleanup``. A failing assertion or an exception while plotting
        therefore cannot leave stale PNGs behind, which would skew the count
        of a later run.
        """
        from vmaf.routine import run_test_on_dataset
        import matplotlib.pyplot as plt

        output_dir = VmafConfig.workspace_path("output", "test_output")
        # Leftovers from an interrupted earlier run would be counted below.
        shutil.rmtree(output_dir, ignore_errors=True)
        # Registered before any work, so it runs even if the test fails.
        self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)

        test_dataset = import_python_file(
            VmafConfig.test_resource_path('dataset_sample.py'))
        _fig, ax = plt.subplots(1, 1, figsize=[20, 20])
        run_test_on_dataset(test_dataset, VmafQualityRunner, ax,
                            None, VmafConfig.model_path("vmaf_float_v0.6.1.json"),
                            parallelize=False,
                            fifo_mode=False,
                            aggregate_method=None,
                            point_label='asset_id',
                            do_plot=['aggregate',  # plots all contents in one figure
                                     'per_content'  # plots a separate figure per content
                                     ],
                            plot_linear_fit=True  # adds linear fit line to each plot
                            )

        DisplayConfig.show(write_to_dir=output_dir)
        # NOTE(review): 3 is presumably one aggregate figure plus one figure
        # per content in dataset_sample.py; update if that dataset changes.
        self.assertEqual(len(glob.glob(os.path.join(output_dir, '*.png'))), 3)
Пример #5
0
def main():
    """Command-line entry point: compute VMAF for one reference/distorted pair.

    Positional arguments: <fmt> <width> <height> <ref_file> <dis_file>.
    Optional flags are looked up with ``get_cmd_option`` /
    ``cmd_option_exists`` starting at argv index 6:

    - ``--model <path>``, ``--attention <obj_path>``
    - ``--out-fmt {xml,json,text}``, ``--pool <method>``
    - ``--local-explain``, ``--phone-model``, ``--ci``
    - ``--save-plot <dir>``

    ``--local-explain`` and ``--ci`` are mutually exclusive.

    Returns:
        0 on success, 2 on any usage error.
    """
    argv = sys.argv
    argc = len(argv)
    if argc < 6:
        print_usage()
        return 2

    try:
        fmt = argv[1]
        width, height = int(argv[2]), int(argv[3])
        ref_file, dis_file = argv[4], argv[5]
    except ValueError:
        print_usage()
        return 2

    if width < 0 or height < 0:
        print("width and height must be non-negative, but are {w} and {h}".format(w=width, h=height))
        print_usage()
        return 2

    if fmt not in FMTS:
        print_usage()
        return 2

    def option(name):
        # Value following ``name`` on the command line, or None.
        return get_cmd_option(argv, 6, argc, name)

    def flag(name):
        # Whether ``name`` appears on the command line.
        return cmd_option_exists(argv, 6, argc, name)

    model_path = option('--model')
    obj_file = option('--attention')

    out_fmt = option('--out-fmt')
    if out_fmt not in (None, 'xml', 'json', 'text'):
        print_usage()
        return 2

    pool_method = option('--pool')
    if pool_method is not None and pool_method not in POOL_METHODS:
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    show_local_explanation = flag('--local-explain')
    phone_model = flag('--phone-model')
    enable_conf_interval = flag('--ci')
    save_plot_dir = option('--save-plot')

    if show_local_explanation and enable_conf_interval:
        print('cannot set both --local-explain and --ci flags')
        return 2

    # content_id and asset_id are both derived from the reference file name.
    ref_id = abs(hash(get_file_name_without_extension(ref_file))) % (10 ** 16)
    asset = Asset(dataset="cmd",
                  content_id=ref_id,
                  asset_id=ref_id,
                  workdir_root=VmafConfig.workdir_path(),
                  ref_path=ref_file,
                  dis_path=dis_file,
                  obj_path=obj_file,
                  asset_dict={'width': width, 'height': height, 'yuv_type': fmt})

    if show_local_explanation:
        from vmaf.core.quality_runner_extra import VmafQualityRunnerWithLocalExplainer as runner_class
    elif enable_conf_interval:
        from vmaf.core.quality_runner import BootstrapVmafQualityRunner as runner_class
    else:
        runner_class = VmafQualityRunner

    optional_dict = None if model_path is None else {'model_filepath': model_path}
    if phone_model:
        optional_dict = dict(optional_dict or {}, enable_transform_score=True)

    runner = runner_class(
        [asset],
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )

    runner.run()
    result = runner.results[0]

    # Pick a pooling function; None (no --pool, or 'mean') keeps the default.
    if pool_method == 'harmonic_mean':
        aggregate = ListStats.harmonic_mean
    elif pool_method == 'min':
        aggregate = np.min
    elif pool_method == 'median':
        aggregate = np.median
    elif pool_method == 'perc5':
        aggregate = ListStats.perc5
    elif pool_method == 'perc10':
        aggregate = ListStats.perc10
    elif pool_method == 'perc20':
        aggregate = ListStats.perc20
    else:
        aggregate = None
    if aggregate is not None:
        result.set_score_aggregate_method(aggregate)

    if out_fmt == 'xml':
        rendered = result.to_xml()
    elif out_fmt == 'json':
        rendered = result.to_json()
    else:  # None or 'text'
        rendered = str(result)
    print(rendered)

    if show_local_explanation:
        runner.show_local_explanations([result])
        show_kwargs = {} if save_plot_dir is None else {'write_to_dir': save_plot_dir}
        DisplayConfig.show(**show_kwargs)

    return 0
Пример #6
0
def main():
    """Command-line entry point: run a quality metric over a test dataset.

    Positional arguments: <quality_type> <test_dataset_filepath>.
    Optional flags are looked up starting at argv index 3:

    - ``--vmaf-model <path>``, ``--vmaf-phone-model`` (VMAF runners only)
    - ``--cache-result``, ``--parallelize``, ``--print-result``
    - ``--suppress-plot``, ``--save-plot <dir>``
    - ``--pool <method>``, ``--subj-model <name>``

    The dataset is evaluated exactly once. When plotting is enabled and
    matplotlib imports, results are drawn on one axis and shown (or written
    to ``--save-plot``). Otherwise the run happens without an axis.

    Errors raised by ``run_test_on_dataset`` itself, including
    ``AssertionError`` and ``ImportError``, propagate. Previously they were
    caught by the plot-fallback handlers and silently triggered a second full
    run.

    Returns:
        0 on success, 1 if the runner, subjective model or dataset cannot be
        loaded, 2 on a usage error.
    """
    if len(sys.argv) < 3:
        print_usage()
        return 2

    # Plain string positionals; the length check above guarantees they exist.
    quality_type = sys.argv[1]
    test_dataset_filepath = sys.argv[2]

    argc = len(sys.argv)
    vmaf_model_path = get_cmd_option(sys.argv, 3, argc, '--vmaf-model')
    cache_result = cmd_option_exists(sys.argv, 3, argc, '--cache-result')
    parallelize = cmd_option_exists(sys.argv, 3, argc, '--parallelize')
    print_result = cmd_option_exists(sys.argv, 3, argc, '--print-result')
    suppress_plot = cmd_option_exists(sys.argv, 3, argc, '--suppress-plot')
    vmaf_phone_model = cmd_option_exists(sys.argv, 3, argc,
                                         '--vmaf-phone-model')

    pool_method = get_cmd_option(sys.argv, 3, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    subj_model = get_cmd_option(sys.argv, 3, argc, '--subj-model')

    try:
        if subj_model is not None:
            from sureal.subjective_model import SubjectiveModel
            subj_model_class = SubjectiveModel.find_subclass(subj_model)
        else:
            subj_model_class = None
    except Exception as e:
        print("Error: " + str(e))
        return 1

    save_plot_dir = get_cmd_option(sys.argv, 3, argc, '--save-plot')

    try:
        runner_class = QualityRunner.find_subclass(quality_type)
    except Exception as e:
        print("Error: " + str(e))
        return 1

    # The VMAF-specific options are only accepted by the VMAF runners.
    vmaf_runner_classes = (VmafQualityRunner, BootstrapVmafQualityRunner)

    if vmaf_model_path is not None and runner_class not in vmaf_runner_classes:
        print("Input error: only quality_type of VMAF accepts --vmaf-model.")
        print_usage()
        return 2

    if vmaf_phone_model and runner_class not in vmaf_runner_classes:
        print("Input error: only quality_type of VMAF accepts --vmaf-phone-model.")
        print_usage()
        return 2

    try:
        test_dataset = import_python_file(test_dataset_filepath)
    except Exception as e:
        print("Error: " + str(e))
        return 1

    result_store = FileSystemResultStore() if cache_result else None

    # pooling
    if pool_method == 'harmonic_mean':
        aggregate_method = ListStats.harmonic_mean
    elif pool_method == 'min':
        aggregate_method = np.min
    elif pool_method == 'median':
        aggregate_method = np.median
    elif pool_method == 'perc5':
        aggregate_method = ListStats.perc5
    elif pool_method == 'perc10':
        aggregate_method = ListStats.perc10
    elif pool_method == 'perc20':
        aggregate_method = ListStats.perc20
    else:  # None or 'mean'
        aggregate_method = np.mean

    enable_transform_score = True if vmaf_phone_model else None

    # Decide up front whether we can plot, so the dataset is run only once.
    plt = None
    ax = None
    if not suppress_plot:
        try:
            import matplotlib.pyplot as plt
            _fig, ax = plt.subplots(figsize=(5, 5), nrows=1, ncols=1)
        except ImportError:
            print_matplotlib_warning()
            ax = None

    _assets, results = run_test_on_dataset(
        test_dataset,
        runner_class,
        ax,
        result_store,
        vmaf_model_path,
        parallelize=parallelize,
        aggregate_method=aggregate_method,
        subj_model_class=subj_model_class,
        enable_transform_score=enable_transform_score)

    if ax is not None:
        bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 20}
        ax.annotate('Testing Set',
                    xy=(0.1, 0.85),
                    xycoords='axes fraction',
                    bbox=bbox)

        plt.tight_layout()

        if save_plot_dir is None:
            DisplayConfig.show()
        else:
            DisplayConfig.show(write_to_dir=save_plot_dir)

    if print_result:
        # Single-argument print() behaves identically on Python 2 and 3.
        for result in results:
            print(result)
            print('')

    return 0
Пример #7
0
def main():
    """Command-line entry point: train a VMAF model on a dataset.

    Positional arguments: <train_dataset_filepath> <feature_param_filepath>
    <model_param_filepath> <output_model_filepath>.
    Optional flags are looked up starting at argv index 5:

    - ``--cache-result``, ``--parallelize``
    - ``--suppress-plot``, ``--save-plot <dir>``
    - ``--pool <method>``, ``--subj-model <name>``

    Training runs exactly once. When plotting is enabled and ``vmaf.plt``
    imports, the training scatter plot is drawn and shown (or written to
    ``--save-plot``). Errors raised by ``train_test_vmaf_on_dataset`` itself
    now propagate. Previously an ``AssertionError`` or ``ImportError`` from
    inside training silently re-ran training without a plot.

    Returns:
        0 on success, 1 if an input file or subjective model cannot be
        loaded, 2 on a usage error.
    """
    if len(sys.argv) < 5:
        print_usage()
        return 2

    # Plain string positionals; the length check above guarantees they exist.
    train_dataset_filepath = sys.argv[1]
    feature_param_filepath = sys.argv[2]
    model_param_filepath = sys.argv[3]
    output_model_filepath = sys.argv[4]

    try:
        train_dataset = import_python_file(train_dataset_filepath)
        feature_param = import_python_file(feature_param_filepath)
        model_param = import_python_file(model_param_filepath)
    except Exception as e:
        print("Error: %s" % e)
        return 1

    # Flags follow the four positionals, so the lookup starts at index 5.
    # Starting at 3 would also scan the model-param and output-model paths.
    first_opt = 5
    argc = len(sys.argv)
    cache_result = cmd_option_exists(sys.argv, first_opt, argc,
                                     '--cache-result')
    parallelize = cmd_option_exists(sys.argv, first_opt, argc,
                                    '--parallelize')
    suppress_plot = cmd_option_exists(sys.argv, first_opt, argc,
                                      '--suppress-plot')

    pool_method = get_cmd_option(sys.argv, first_opt, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    subj_model = get_cmd_option(sys.argv, first_opt, argc, '--subj-model')

    try:
        if subj_model is not None:
            from sureal.subjective_model import SubjectiveModel
            subj_model_class = SubjectiveModel.find_subclass(subj_model)
        else:
            subj_model_class = None
    except Exception as e:
        print("Error: %s" % e)
        return 1

    save_plot_dir = get_cmd_option(sys.argv, first_opt, argc, '--save-plot')

    result_store = FileSystemResultStore() if cache_result else None

    # pooling
    if pool_method == 'harmonic_mean':
        aggregate_method = ListStats.harmonic_mean
    elif pool_method == 'min':
        aggregate_method = np.min
    elif pool_method == 'median':
        aggregate_method = np.median
    elif pool_method == 'perc5':
        aggregate_method = ListStats.perc5
    elif pool_method == 'perc10':
        aggregate_method = ListStats.perc10
    elif pool_method == 'perc20':
        aggregate_method = ListStats.perc20
    else:  # None or 'mean'
        aggregate_method = np.mean

    logger = None

    # Decide up front whether we can plot, so training runs only once.
    plt = None
    train_ax = None
    if not suppress_plot:
        try:
            from vmaf import plt
            _fig, train_ax = plt.subplots(figsize=(5, 5), nrows=1, ncols=1)
        except ImportError:
            print_matplotlib_warning()
            train_ax = None

    train_test_vmaf_on_dataset(
        train_dataset=train_dataset,
        test_dataset=None,
        feature_param=feature_param,
        model_param=model_param,
        train_ax=train_ax,
        test_ax=None,
        result_store=result_store,
        parallelize=parallelize,
        logger=logger,
        output_model_filepath=output_model_filepath,
        aggregate_method=aggregate_method,
        subj_model_class=subj_model_class,
    )

    if train_ax is not None:
        bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 20}
        train_ax.annotate('Training Set',
                          xy=(0.1, 0.85),
                          xycoords='axes fraction',
                          bbox=bbox)

        plt.tight_layout()

        if save_plot_dir is None:
            DisplayConfig.show()
        else:
            DisplayConfig.show(write_to_dir=save_plot_dir)

    return 0
Пример #8
0
def main():
    """Command-line entry point: compute VMAF for one reference/distorted pair.

    Positional arguments: <quality_width> <quality_height> <ref_file>
    <dis_file>. Optional flags are looked up starting at argv index 5:

    - ``--model <path>``, ``--out-fmt {xml,json,text}``
    - ``--ref-fmt <fmt>`` with ``--ref-width``/``--ref-height``
    - ``--dis-fmt <fmt>`` with ``--dis-width``/``--dis-height``
    - ``--work-dir <dir>``, ``--pool <method>``
    - ``--local-explain``, ``--phone-model``

    A side (ref/dis) without a format is marked ``'notyuv'``. An invalid
    ``--ref-fmt``/``--dis-fmt`` is now rejected, like an invalid ``--pool``,
    instead of being used anyway. Explicit widths/heights are stored as ints,
    like the quality size.

    Returns:
        0 on success, 2 on any usage error.
    """
    if len(sys.argv) < 5:
        print_usage()
        return 2

    try:
        q_width = int(sys.argv[1])
        q_height = int(sys.argv[2])
        ref_file = sys.argv[3]
        dis_file = sys.argv[4]
    except ValueError:
        print_usage()
        return 2

    # Single-argument print() behaves identically on Python 2 and 3.
    if q_width < 0 or q_height < 0:
        print("quality_width and quality_height must be non-negative, but are {w} and {h}".format(
            w=q_width, h=q_height))
        print_usage()
        return 2

    argc = len(sys.argv)
    model_path = get_cmd_option(sys.argv, 5, argc, '--model')

    out_fmt = get_cmd_option(sys.argv, 5, argc, '--out-fmt')
    if out_fmt not in (None, 'xml', 'json', 'text'):
        print_usage()
        return 2

    ref_fmt = get_cmd_option(sys.argv, 5, argc, '--ref-fmt')
    if not (ref_fmt is None or ref_fmt in FMTS):
        print('--ref-fmt can only have option among {}'.format(', '.join(FMTS)))
        return 2

    ref_width = get_cmd_option(sys.argv, 5, argc, '--ref-width')
    ref_height = get_cmd_option(sys.argv, 5, argc, '--ref-height')
    dis_width = get_cmd_option(sys.argv, 5, argc, '--dis-width')
    dis_height = get_cmd_option(sys.argv, 5, argc, '--dis-height')

    dis_fmt = get_cmd_option(sys.argv, 5, argc, '--dis-fmt')
    if not (dis_fmt is None or dis_fmt in FMTS):
        print('--dis-fmt can only have option among {}'.format(', '.join(FMTS)))
        return 2

    work_dir = get_cmd_option(sys.argv, 5, argc, '--work-dir')

    pool_method = get_cmd_option(sys.argv, 5, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    show_local_explanation = cmd_option_exists(sys.argv, 5, argc,
                                               '--local-explain')

    phone_model = cmd_option_exists(sys.argv, 5, argc, '--phone-model')

    if work_dir is None:
        work_dir = VmafConfig.workdir_path()

    asset_dict = {'quality_width': q_width, 'quality_height': q_height}

    # Reference and distorted sides are configured the same way.
    for side, side_fmt, side_width, side_height in (
            ('ref', ref_fmt, ref_width, ref_height),
            ('dis', dis_fmt, dis_width, dis_height)):
        if side_fmt is None:
            asset_dict[side + '_yuv_type'] = 'notyuv'
            continue
        if side_width is None or side_height is None:
            print('if --{s}-fmt is specified, both --{s}-width and '
                  '--{s}-height must be specified'.format(s=side))
            return 2
        try:
            parsed_width = int(side_width)
            parsed_height = int(side_height)
        except ValueError:
            print('--{s}-width and --{s}-height must be integers'.format(s=side))
            return 2
        asset_dict[side + '_yuv_type'] = side_fmt
        asset_dict[side + '_width'] = parsed_width
        asset_dict[side + '_height'] = parsed_height

    # content_id and asset_id are both derived from the reference file name.
    ref_id = abs(hash(get_file_name_without_extension(ref_file))) % (10**16)
    asset = Asset(
        dataset="cmd",
        content_id=ref_id,
        asset_id=ref_id,
        workdir_root=work_dir,
        ref_path=ref_file,
        dis_path=dis_file,
        asset_dict=asset_dict,
    )
    assets = [asset]

    if not show_local_explanation:
        runner_class = VmafQualityRunner
    else:
        runner_class = VmafQualityRunnerWithLocalExplainer

    if model_path is None:
        optional_dict = None
    else:
        optional_dict = {'model_filepath': model_path}

    if phone_model:
        if optional_dict is None:
            optional_dict = {}
        optional_dict['enable_transform_score'] = True

    runner = runner_class(
        assets,
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )

    # run
    runner.run()
    result = runner.results[0]

    # pooling
    if pool_method == 'harmonic_mean':
        result.set_score_aggregate_method(ListStats.harmonic_mean)
    elif pool_method == 'min':
        result.set_score_aggregate_method(np.min)
    elif pool_method == 'median':
        result.set_score_aggregate_method(np.median)
    elif pool_method == 'perc5':
        result.set_score_aggregate_method(ListStats.perc5)
    elif pool_method == 'perc10':
        result.set_score_aggregate_method(ListStats.perc10)
    elif pool_method == 'perc20':
        result.set_score_aggregate_method(ListStats.perc20)
    # else: None or 'mean' keeps the result's default aggregation.

    # output
    if out_fmt == 'xml':
        print(result.to_xml())
    elif out_fmt == 'json':
        print(result.to_json())
    else:  # None or 'text'
        print(str(result))

    # local explanation
    if show_local_explanation:
        import matplotlib.pyplot as plt
        runner.show_local_explanations([result])
        DisplayConfig.show()

    return 0
Пример #9
0
        output_model_filepath=VmafConfig.workspace_path(
            'model', 'test_model1.pkl'),
    )

    # ==== Run cross validation across genres (tough test) ====

    nflx_dataset_path = VmafConfig.resource_path('dataset',
                                                 'NFLX_dataset_public.py')
    contentid_groups = [
        [0, 5],  # cartoon: BigBuckBunny, FoxBird
        [1],  # CG: BirdsInCage
        [2, 6, 7],  # complex: CrowdRun, OldTownCross, Seeking
        [3, 4],  # ElFuente: ElFuente1, ElFuente2
        [8],  # sports: Tennis
    ]
    param_filepath = VmafConfig.resource_path('param', 'vmaf_v3.py')

    aggregate_method = np.mean
    # aggregate_method = ListStats.harmonic_mean
    # aggregate_method = partial(ListStats.lp_norm, p=2.0)

    run_vmaf_kfold_cv(
        dataset_filepath=nflx_dataset_path,
        contentid_groups=contentid_groups,
        param_filepath=param_filepath,
        aggregate_method=aggregate_method,
    )

    DisplayConfig.show()

    print 'Done.'
Пример #10
0
def main():
    """Command-line entry point: compute VMAF for one pair (vmafossexec by default).

    Positional arguments: <quality_width> <quality_height> <ref_file>
    <dis_file>. Optional flags are looked up starting at argv index 5:

    - ``--model <path>``, ``--thread <0..36>``, ``--out-fmt {xml,json,text}``
    - ``--ref-fmt``/``--dis-fmt <fmt>`` with matching
      ``--*-width``/``--*-height``
    - ``--work-dir <dir>``, ``--pool <method>``
    - ``--local-explain``, ``--phone-model``, ``--ci``, ``--save-plot <dir>``

    ``VmafossExecQualityRunner`` is used unless ``--local-explain`` or
    ``--ci`` selects the explainer or bootstrap runner; those two flags are
    mutually exclusive. An invalid ``--ref-fmt``/``--dis-fmt`` is now
    rejected, like an invalid ``--pool``. Explicit widths/heights are stored
    as ints, like the quality size.

    Returns:
        0 on success, 2 on any usage error.
    """
    if len(sys.argv) < 5:
        print_usage()
        return 2

    try:
        q_width = int(sys.argv[1])
        q_height = int(sys.argv[2])
        ref_file = sys.argv[3]
        dis_file = sys.argv[4]
    except ValueError:
        print_usage()
        return 2

    # Single-argument print() behaves identically on Python 2 and 3.
    if q_width < 0 or q_height < 0:
        print("quality_width and quality_height must be non-negative, but are {w} and {h}".format(
            w=q_width, h=q_height))
        print_usage()
        return 2

    argc = len(sys.argv)
    model_path = get_cmd_option(sys.argv, 5, argc, '--model')

    thread_no = get_cmd_option(sys.argv, 5, argc, '--thread')
    if thread_no is None:
        print("Using maximum VMAF threads")
    else:
        print("Number of VMAF threads " + thread_no)
        try:
            thread_no = int(thread_no)
            thread_ok = 0 <= thread_no <= 36
        except ValueError:
            thread_ok = False
        if not thread_ok:
            print("Error in number of threads, make sure it's an integer between 0 and 36")
            return 2

    out_fmt = get_cmd_option(sys.argv, 5, argc, '--out-fmt')
    if out_fmt not in (None, 'xml', 'json', 'text'):
        print_usage()
        return 2

    ref_fmt = get_cmd_option(sys.argv, 5, argc, '--ref-fmt')
    if not (ref_fmt is None or ref_fmt in FMTS):
        print('--ref-fmt can only have option among {}'.format(', '.join(FMTS)))
        return 2

    ref_width = get_cmd_option(sys.argv, 5, argc, '--ref-width')
    ref_height = get_cmd_option(sys.argv, 5, argc, '--ref-height')
    dis_width = get_cmd_option(sys.argv, 5, argc, '--dis-width')
    dis_height = get_cmd_option(sys.argv, 5, argc, '--dis-height')

    dis_fmt = get_cmd_option(sys.argv, 5, argc, '--dis-fmt')
    if not (dis_fmt is None or dis_fmt in FMTS):
        print('--dis-fmt can only have option among {}'.format(', '.join(FMTS)))
        return 2

    work_dir = get_cmd_option(sys.argv, 5, argc, '--work-dir')

    pool_method = get_cmd_option(sys.argv, 5, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    show_local_explanation = cmd_option_exists(sys.argv, 5, argc,
                                               '--local-explain')

    phone_model = cmd_option_exists(sys.argv, 5, argc, '--phone-model')

    enable_conf_interval = cmd_option_exists(sys.argv, 5, argc, '--ci')

    save_plot_dir = get_cmd_option(sys.argv, 5, argc, '--save-plot')

    if work_dir is None:
        work_dir = VmafConfig.workdir_path()

    asset_dict = {'quality_width': q_width, 'quality_height': q_height}

    # Reference and distorted sides are configured the same way.
    for side, side_fmt, side_width, side_height in (
            ('ref', ref_fmt, ref_width, ref_height),
            ('dis', dis_fmt, dis_width, dis_height)):
        if side_fmt is None:
            asset_dict[side + '_yuv_type'] = 'notyuv'
            continue
        if side_width is None or side_height is None:
            print('if --{s}-fmt is specified, both --{s}-width and '
                  '--{s}-height must be specified'.format(s=side))
            return 2
        try:
            parsed_width = int(side_width)
            parsed_height = int(side_height)
        except ValueError:
            print('--{s}-width and --{s}-height must be integers'.format(s=side))
            return 2
        asset_dict[side + '_yuv_type'] = side_fmt
        asset_dict[side + '_width'] = parsed_width
        asset_dict[side + '_height'] = parsed_height

    if show_local_explanation and enable_conf_interval:
        print('cannot set both --local-explain and --ci flags')
        return 2

    # content_id and asset_id are both derived from the reference file name.
    ref_id = abs(hash(get_file_name_without_extension(ref_file))) % (10**16)
    asset = Asset(
        dataset="cmd",
        content_id=ref_id,
        asset_id=ref_id,
        workdir_root=work_dir,
        ref_path=ref_file,
        dis_path=dis_file,
        asset_dict=asset_dict,
    )
    assets = [asset]

    if show_local_explanation:
        from vmaf.core.quality_runner_extra import VmafQualityRunnerWithLocalExplainer
        runner_class = VmafQualityRunnerWithLocalExplainer
    elif enable_conf_interval:
        from vmaf.core.quality_runner import BootstrapVmafQualityRunner
        runner_class = BootstrapVmafQualityRunner
    else:
        runner_class = VmafossExecQualityRunner

    if model_path is None:
        optional_dict = None
    else:
        optional_dict = {'model_filepath': model_path}

    if phone_model:
        if optional_dict is None:
            optional_dict = {}
        optional_dict['enable_transform_score'] = True

    # NOTE(review): --thread 0 passes validation but is falsy, so it is not
    # forwarded (same as omitting --thread) — confirm that is intended.
    if thread_no:
        if optional_dict is None:
            optional_dict = {}
        optional_dict['thread'] = thread_no

    runner = runner_class(
        assets,
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )

    # run
    runner.run()
    result = runner.results[0]

    # pooling
    if pool_method == 'harmonic_mean':
        result.set_score_aggregate_method(ListStats.harmonic_mean)
    elif pool_method == 'min':
        result.set_score_aggregate_method(np.min)
    elif pool_method == 'median':
        result.set_score_aggregate_method(np.median)
    elif pool_method == 'perc5':
        result.set_score_aggregate_method(ListStats.perc5)
    elif pool_method == 'perc10':
        result.set_score_aggregate_method(ListStats.perc10)
    elif pool_method == 'perc20':
        result.set_score_aggregate_method(ListStats.perc20)
    # else: None or 'mean' keeps the result's default aggregation.

    # output
    if out_fmt == 'xml':
        print(result.to_xml())
    elif out_fmt == 'json':
        print(result.to_json())
    else:  # None or 'text'
        print(str(result))

    # local explanation
    if show_local_explanation:
        import matplotlib.pyplot as plt
        runner.show_local_explanations([result])

        if save_plot_dir is None:
            DisplayConfig.show()
        else:
            DisplayConfig.show(write_to_dir=save_plot_dir)

    return 0