Exemplo n.º 1
0
def main():
    """Evaluate a quality runner on a test dataset and optionally plot it.

    Command line (see ``print_usage``)::

        <quality_type> <test_dataset_filepath> [options]

    Options, scanned from argv[3] onward:
        --vmaf-model <path>   model file; only accepted by the VMAF runners
        --cache-result        cache results in a FileSystemResultStore
        --parallelize         forwarded to run_test_on_dataset
        --print-result        print every result after the run
        --suppress-plot       do not import matplotlib or draw anything
        --vmaf-phone-model    enable score transform; VMAF runners only
        --pool <method>       frame-score pooling, one of POOL_METHODS
        --subj-model <name>   sureal SubjectiveModel subclass to use
        --save-plot <dir>     write the figure to <dir> instead of showing it

    Returns:
        0 on success, 1 on a runtime error (unknown runner or subjective
        model, unloadable dataset file), 2 on a command-line usage error.
    """
    argc = len(sys.argv)
    if argc < 3:
        print_usage()
        return 2

    # The length check above guarantees both positional arguments exist.
    quality_type = sys.argv[1]
    test_dataset_filepath = sys.argv[2]

    opt_begin = 3  # options follow the two positional arguments
    vmaf_model_path = get_cmd_option(sys.argv, opt_begin, argc, '--vmaf-model')
    cache_result = cmd_option_exists(sys.argv, opt_begin, argc,
                                     '--cache-result')
    parallelize = cmd_option_exists(sys.argv, opt_begin, argc,
                                    '--parallelize')
    print_result = cmd_option_exists(sys.argv, opt_begin, argc,
                                     '--print-result')
    suppress_plot = cmd_option_exists(sys.argv, opt_begin, argc,
                                      '--suppress-plot')
    vmaf_phone_model = cmd_option_exists(sys.argv, opt_begin, argc,
                                         '--vmaf-phone-model')

    pool_method = get_cmd_option(sys.argv, opt_begin, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    subj_model = get_cmd_option(sys.argv, opt_begin, argc, '--subj-model')

    try:
        if subj_model is not None:
            from sureal.subjective_model import SubjectiveModel
            subj_model_class = SubjectiveModel.find_subclass(subj_model)
        else:
            subj_model_class = None
    except Exception as e:
        print("Error: " + str(e))
        return 1

    save_plot_dir = get_cmd_option(sys.argv, opt_begin, argc, '--save-plot')

    try:
        runner_class = QualityRunner.find_subclass(quality_type)
    except Exception as e:
        print("Error: " + str(e))
        return 1

    vmaf_runner_classes = (VmafQualityRunner, BootstrapVmafQualityRunner)

    if vmaf_model_path is not None and runner_class not in vmaf_runner_classes:
        print("Input error: only quality_type of VMAF accepts --vmaf-model.")
        print_usage()
        return 2

    if vmaf_phone_model and runner_class not in vmaf_runner_classes:
        print("Input error: only quality_type of VMAF accepts --vmaf-phone-model.")
        print_usage()
        return 2

    try:
        test_dataset = import_python_file(test_dataset_filepath)
    except Exception as e:
        print("Error: " + str(e))
        return 1

    result_store = FileSystemResultStore() if cache_result else None

    # Pooling: None and 'mean' both fall back to the arithmetic mean.
    aggregate_method = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }.get(pool_method, np.mean)

    enable_transform_score = True if vmaf_phone_model else None

    # Set up plotting first, so that the dataset run happens exactly once.
    # Errors raised by run_test_on_dataset itself must propagate; they must
    # not be mistaken for "matplotlib missing" or "plot suppressed".
    ax = None
    if not suppress_plot:
        try:
            import matplotlib.pyplot as plt
            _, ax = plt.subplots(figsize=(5, 5), nrows=1, ncols=1)
        except ImportError:
            print_matplotlib_warning()
            ax = None

    _, results = run_test_on_dataset(
        test_dataset,
        runner_class,
        ax,
        result_store,
        vmaf_model_path,
        parallelize=parallelize,
        aggregate_method=aggregate_method,
        subj_model_class=subj_model_class,
        enable_transform_score=enable_transform_score)

    if ax is not None:
        bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 20}
        ax.annotate('Testing Set',
                    xy=(0.1, 0.85),
                    xycoords='axes fraction',
                    bbox=bbox)

        plt.tight_layout()

        if save_plot_dir is None:
            DisplayConfig.show()
        else:
            DisplayConfig.show(write_to_dir=save_plot_dir)

    if print_result:
        for result in results:
            print(result)
            print('')

    return 0
Exemplo n.º 2
0
def main():
    """Compute VMAF for a single reference/distorted YUV pair.

    Positional arguments: fmt width height ref_file dis_file.
    Options, scanned from argv[6] on: --model <path>, --out-fmt
    xml|json|text, --pool <method>, --local-explain, --phone-model.

    Returns 0 on success and 2 on any command-line usage error.
    """
    args = sys.argv
    if len(args) < 6:
        print_usage()
        return 2

    fmt = args[1]
    ref_file, dis_file = args[4], args[5]
    try:
        width, height = int(args[2]), int(args[3])
    except ValueError:
        print_usage()
        return 2

    if min(width, height) < 0:
        print("width and height must be non-negative, but are {w} and {h}".format(
            w=width, h=height))
        print_usage()
        return 2

    if fmt not in FMTS:
        print_usage()
        return 2

    n_args = len(args)

    def option_value(name):
        # Value following `name` among the optional arguments, or None.
        return get_cmd_option(args, 6, n_args, name)

    def has_flag(name):
        # True when the bare flag `name` appears among the optional arguments.
        return cmd_option_exists(args, 6, n_args, name)

    model_path = option_value('--model')

    out_fmt = option_value('--out-fmt')
    if out_fmt not in (None, 'xml', 'json', 'text'):
        print_usage()
        return 2

    pool_method = option_value('--pool')
    if pool_method is not None and pool_method not in POOL_METHODS:
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    show_local_explanation = has_flag('--local-explain')
    phone_model = has_flag('--phone-model')

    # Content and asset ids are both derived from the reference file name.
    name_hash = abs(hash(get_file_name_without_extension(ref_file))) % (10**16)
    assets = [
        Asset(
            dataset="cmd",
            content_id=name_hash,
            asset_id=name_hash,
            workdir_root=VmafConfig.workdir_path(),
            ref_path=ref_file,
            dis_path=dis_file,
            asset_dict={'width': width, 'height': height, 'yuv_type': fmt},
        )
    ]

    if show_local_explanation:
        from vmaf.core.quality_runner_extra import VmafQualityRunnerWithLocalExplainer
        runner_class = VmafQualityRunnerWithLocalExplainer
    else:
        runner_class = VmafQualityRunner

    optional_dict = None if model_path is None else {'model_filepath': model_path}
    if phone_model:
        optional_dict = dict(optional_dict or {}, enable_transform_score=True)

    runner = runner_class(
        assets,
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )
    runner.run()
    result = runner.results[0]

    # None and 'mean' keep the result's default aggregation.
    aggregators = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }
    if pool_method in aggregators:
        result.set_score_aggregate_method(aggregators[pool_method])

    if out_fmt == 'xml':
        rendered = result.to_xml()
    elif out_fmt == 'json':
        rendered = result.to_json()
    else:  # None or 'text'
        rendered = str(result)
    print(rendered)

    if show_local_explanation:
        import matplotlib.pyplot as plt
        runner.show_local_explanations([result])
        plt.show()

    return 0
Exemplo n.º 3
0
def main():
    """Run VMAF on every reference/distorted pair listed in an input file.

    Positional argument: path of a text file with one pair per line, e.g.::

        yuv420p 576 324 ref.yuv dis.yuv

    Lines whose first non-blank character is '#' are reported and skipped;
    blank lines are skipped silently.

    Options, scanned from argv[2] on: --model <path>, --out-fmt
    xml|json|text, --pool <method>, --parallelize, --phone-model, --ci.

    Returns:
        0 on success, 1 when a line in the input file cannot be parsed,
        2 on a command-line usage error.
    """
    if len(sys.argv) < 2:
        print_usage()
        return 2

    input_filepath = sys.argv[1]
    argc = len(sys.argv)

    model_path = get_cmd_option(sys.argv, 2, argc, '--model')

    out_fmt = get_cmd_option(sys.argv, 2, argc, '--out-fmt')
    if out_fmt not in (None, 'xml', 'json', 'text'):
        print_usage()
        return 2

    pool_method = get_cmd_option(sys.argv, 2, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    parallelize = cmd_option_exists(sys.argv, 2, argc, '--parallelize')
    phone_model = cmd_option_exists(sys.argv, 2, argc, '--phone-model')
    enable_conf_interval = cmd_option_exists(sys.argv, 2, argc, '--ci')

    # fmt width height ref dis; any run of whitespace separates the fields.
    line_pattern = re.compile(r"(\S+)\s+([0-9]+)\s+([0-9]+)\s+(\S+)\s+(\S+)")

    assets = []
    with open(input_filepath, "rt") as input_file:
        for line in input_file:
            stripped = line.strip()

            # Only truly blank lines are skipped. The previous check also
            # silently dropped every data line with leading whitespace.
            if not stripped:
                continue

            if stripped.startswith('#'):
                print("Skip commented line: {}".format(line))
                continue

            mo = line_pattern.match(stripped)
            if not mo or mo.group(1) not in FMTS:
                print("Unknown format: {}".format(line))
                print_usage()
                return 1

            fmt = mo.group(1)
            width = int(mo.group(2))
            height = int(mo.group(3))
            ref_file = mo.group(4)
            dis_file = mo.group(5)

            assets.append(Asset(dataset="cmd",
                                content_id=0,
                                asset_id=len(assets),  # 0-based index of parsed pairs
                                workdir_root=VmafConfig.workdir_path(),
                                ref_path=ref_file,
                                dis_path=dis_file,
                                asset_dict={
                                    'width': width,
                                    'height': height,
                                    'yuv_type': fmt
                                }))

    if enable_conf_interval:
        from vmaf.core.quality_runner import BootstrapVmafQualityRunner
        runner_class = BootstrapVmafQualityRunner
    else:
        runner_class = VmafQualityRunner

    optional_dict = None if model_path is None else {'model_filepath': model_path}
    if phone_model:
        if optional_dict is None:
            optional_dict = {}
        optional_dict['enable_transform_score'] = True

    runner = runner_class(
        assets,
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )
    runner.run(parallelize=parallelize)
    results = runner.results

    # Pooling is the same for every result; resolve it once. None and
    # 'mean' keep each result's default aggregation.
    aggregate_method = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }.get(pool_method)

    for result in results:
        if aggregate_method is not None:
            result.set_score_aggregate_method(aggregate_method)

        if out_fmt == 'xml':
            print(result.to_xml())
        elif out_fmt == 'json':
            print(result.to_json())
        else:  # None or 'text'
            print('============================')
            print('Asset {asset_id}:'.format(asset_id=result.asset.asset_id))
            print('============================')
            print(result)

    return 0
Exemplo n.º 4
0
def main():
    """Train a VMAF model on a dataset and optionally plot the training fit.

    Positional arguments::

        <train_dataset_filepath> <feature_param_filepath>
        <model_param_filepath> <output_model_filepath>

    Options, scanned from argv[5] on: --cache-result, --parallelize,
    --suppress-plot, --pool <method>, --subj-model <name>,
    --save-plot <dir>.

    Returns:
        0 on success, 1 on a runtime error (unloadable parameter/dataset
        file, unknown subjective model), 2 on a command-line usage error.
    """
    if len(sys.argv) < 5:
        print_usage()
        return 2

    # The length check above guarantees all four positional arguments exist.
    train_dataset_filepath = sys.argv[1]
    feature_param_filepath = sys.argv[2]
    model_param_filepath = sys.argv[3]
    output_model_filepath = sys.argv[4]

    try:
        train_dataset = import_python_file(train_dataset_filepath)
        feature_param = import_python_file(feature_param_filepath)
        model_param = import_python_file(model_param_filepath)
    except Exception as e:
        print("Error: %s" % e)
        return 1

    # Options follow the four positional arguments. Scanning from argv[3]
    # (as before) would treat the model-param and output paths as options.
    argc = len(sys.argv)
    opt_begin = 5

    cache_result = cmd_option_exists(sys.argv, opt_begin, argc,
                                     '--cache-result')
    parallelize = cmd_option_exists(sys.argv, opt_begin, argc,
                                    '--parallelize')
    suppress_plot = cmd_option_exists(sys.argv, opt_begin, argc,
                                      '--suppress-plot')

    pool_method = get_cmd_option(sys.argv, opt_begin, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    subj_model = get_cmd_option(sys.argv, opt_begin, argc, '--subj-model')

    try:
        if subj_model is not None:
            from sureal.subjective_model import SubjectiveModel
            subj_model_class = SubjectiveModel.find_subclass(subj_model)
        else:
            subj_model_class = None
    except Exception as e:
        print("Error: %s" % e)
        return 1

    save_plot_dir = get_cmd_option(sys.argv, opt_begin, argc, '--save-plot')

    result_store = FileSystemResultStore() if cache_result else None

    # Pooling: None and 'mean' both fall back to the arithmetic mean.
    aggregate_method = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }.get(pool_method, np.mean)

    logger = None

    # Set up plotting first, so that training runs exactly once. Errors
    # raised by train_test_vmaf_on_dataset must propagate; they must not
    # trigger a silent second training run without a plot.
    ax = None
    if not suppress_plot:
        try:
            from vmaf import plt
            _, ax = plt.subplots(figsize=(5, 5), nrows=1, ncols=1)
        except ImportError:
            print_matplotlib_warning()
            ax = None

    train_test_vmaf_on_dataset(
        train_dataset=train_dataset,
        test_dataset=None,
        feature_param=feature_param,
        model_param=model_param,
        train_ax=ax,
        test_ax=None,
        result_store=result_store,
        parallelize=parallelize,
        logger=logger,
        output_model_filepath=output_model_filepath,
        aggregate_method=aggregate_method,
        subj_model_class=subj_model_class,
    )

    if ax is not None:
        bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 20}
        ax.annotate('Training Set',
                    xy=(0.1, 0.85),
                    xycoords='axes fraction',
                    bbox=bbox)

        plt.tight_layout()

        if save_plot_dir is None:
            DisplayConfig.show()
        else:
            DisplayConfig.show(write_to_dir=save_plot_dir)

    return 0
Exemplo n.º 5
0
def main():
    """Compute VMAF for one reference/distorted pair of arbitrary files.

    Positional arguments: quality_width quality_height ref_file dis_file.
    Options, scanned from argv[5] on:
        --model <path>, --out-fmt xml|json|text, --work-dir <dir>,
        --ref-fmt <fmt> (requires --ref-width and --ref-height),
        --dis-fmt <fmt> (requires --dis-width and --dis-height),
        --pool <method>, --local-explain, --phone-model, --ci,
        --save-plot <dir> (used with --local-explain).
    Without --ref-fmt / --dis-fmt the input is marked as 'notyuv'.

    Returns 0 on success and 2 on any command-line usage error.
    """
    if len(sys.argv) < 5:
        print_usage()
        return 2

    try:
        q_width = int(sys.argv[1])
        q_height = int(sys.argv[2])
        ref_file = sys.argv[3]
        dis_file = sys.argv[4]
    except ValueError:
        print_usage()
        return 2

    if q_width < 0 or q_height < 0:
        print(
            "quality_width and quality_height must be non-negative, but are {w} and {h}"
            .format(w=q_width, h=q_height))
        print_usage()
        return 2

    argc = len(sys.argv)

    model_path = get_cmd_option(sys.argv, 5, argc, '--model')

    out_fmt = get_cmd_option(sys.argv, 5, argc, '--out-fmt')
    if out_fmt not in (None, 'xml', 'json', 'text'):
        print_usage()
        return 2

    ref_fmt = get_cmd_option(sys.argv, 5, argc, '--ref-fmt')
    if not (ref_fmt is None or ref_fmt in FMTS):
        print('--ref-fmt can only have option among {}'.format(
            ', '.join(FMTS)))
        return 2  # previously fell through and used the invalid format

    ref_width = get_cmd_option(sys.argv, 5, argc, '--ref-width')
    ref_height = get_cmd_option(sys.argv, 5, argc, '--ref-height')
    dis_width = get_cmd_option(sys.argv, 5, argc, '--dis-width')
    dis_height = get_cmd_option(sys.argv, 5, argc, '--dis-height')

    dis_fmt = get_cmd_option(sys.argv, 5, argc, '--dis-fmt')
    if not (dis_fmt is None or dis_fmt in FMTS):
        print('--dis-fmt can only have option among {}'.format(
            ', '.join(FMTS)))
        return 2  # previously fell through and used the invalid format

    work_dir = get_cmd_option(sys.argv, 5, argc, '--work-dir')

    pool_method = get_cmd_option(sys.argv, 5, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    show_local_explanation = cmd_option_exists(sys.argv, 5, argc,
                                               '--local-explain')
    phone_model = cmd_option_exists(sys.argv, 5, argc, '--phone-model')
    enable_conf_interval = cmd_option_exists(sys.argv, 5, argc, '--ci')

    save_plot_dir = get_cmd_option(sys.argv, 5, argc, '--save-plot')

    if work_dir is None:
        work_dir = VmafConfig.workdir_path()

    asset_dict = {'quality_width': q_width, 'quality_height': q_height}

    if ref_fmt is None:
        asset_dict['ref_yuv_type'] = 'notyuv'
    else:
        if ref_width is None or ref_height is None:
            print(
                'if --ref-fmt is specified, both --ref-width and --ref-height must be specified'
            )
            return 2
        # Command-line values are strings; store ints like the quality size.
        try:
            ref_width, ref_height = int(ref_width), int(ref_height)
        except ValueError:
            print_usage()
            return 2
        asset_dict['ref_yuv_type'] = ref_fmt
        asset_dict['ref_width'] = ref_width
        asset_dict['ref_height'] = ref_height

    if dis_fmt is None:
        asset_dict['dis_yuv_type'] = 'notyuv'
    else:
        if dis_width is None or dis_height is None:
            print(
                'if --dis-fmt is specified, both --dis-width and --dis-height must be specified'
            )
            return 2
        try:
            dis_width, dis_height = int(dis_width), int(dis_height)
        except ValueError:
            print_usage()
            return 2
        asset_dict['dis_yuv_type'] = dis_fmt
        asset_dict['dis_width'] = dis_width
        asset_dict['dis_height'] = dis_height

    if show_local_explanation and enable_conf_interval:
        print('cannot set both --local-explain and --ci flags')
        return 2

    # Content and asset ids are both derived from the reference file name.
    name_hash = abs(hash(get_file_name_without_extension(ref_file))) % (10**16)
    asset = Asset(
        dataset="cmd",
        content_id=name_hash,
        asset_id=name_hash,
        workdir_root=work_dir,
        ref_path=ref_file,
        dis_path=dis_file,
        asset_dict=asset_dict,
    )
    assets = [asset]

    if show_local_explanation:
        from vmaf.core.quality_runner_extra import VmafQualityRunnerWithLocalExplainer
        runner_class = VmafQualityRunnerWithLocalExplainer
    elif enable_conf_interval:
        from vmaf.core.quality_runner import BootstrapVmafQualityRunner
        runner_class = BootstrapVmafQualityRunner
    else:
        runner_class = VmafQualityRunner

    optional_dict = None if model_path is None else {'model_filepath': model_path}
    if phone_model:
        if optional_dict is None:
            optional_dict = {}
        optional_dict['enable_transform_score'] = True

    runner = runner_class(
        assets,
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )

    runner.run()
    result = runner.results[0]

    # None and 'mean' keep the result's default aggregation.
    aggregate_method = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }.get(pool_method)
    if aggregate_method is not None:
        result.set_score_aggregate_method(aggregate_method)

    if out_fmt == 'xml':
        print(result.to_xml())
    elif out_fmt == 'json':
        print(result.to_json())
    else:  # None or 'text'
        print(str(result))

    if show_local_explanation:
        runner.show_local_explanations([result])

        if save_plot_dir is None:
            DisplayConfig.show()
        else:
            DisplayConfig.show(write_to_dir=save_plot_dir)

    return 0
Exemplo n.º 6
0
def main():
    """Compute PSNR for a single reference/distorted YUV pair.

    Positional arguments: fmt width height ref_path dis_path.
    Options, scanned from argv[6] on: --out-fmt xml|json|text and
    --pool <method>.

    Returns 0 on success and 2 on any command-line usage error.
    """
    args = sys.argv
    if len(args) < 6:
        print_usage()
        return 2

    fmt = args[1]
    ref_path, dis_path = args[4], args[5]
    try:
        width = int(args[2])
        height = int(args[3])
    except ValueError:
        print_usage()
        return 2

    if min(width, height) < 0:
        print("width and height must be non-negative, but are {w} and {h}".format(w=width, h=height))
        print_usage()
        return 2

    if fmt not in FMTS:
        print_usage()
        return 2

    num_args = len(args)

    out_fmt = get_cmd_option(args, 6, num_args, '--out-fmt')
    if out_fmt not in (None, 'xml', 'json', 'text'):
        print_usage()
        return 2

    pool_method = get_cmd_option(args, 6, num_args, '--pool')
    if pool_method is not None and pool_method not in POOL_METHODS:
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    the_asset = Asset(
        dataset="cmd",
        content_id=0,
        asset_id=0,
        workdir_root=VmafConfig.workdir_path(),
        ref_path=ref_path,
        dis_path=dis_path,
        asset_dict={'width': width, 'height': height, 'yuv_type': fmt},
    )

    runner = PsnrQualityRunner(
        [the_asset],
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=None,
        optional_dict2=None,
    )
    runner.run()
    result = runner.results[0]

    # None and 'mean' keep the result's default aggregation.
    pooling = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }
    if pool_method in pooling:
        result.set_score_aggregate_method(pooling[pool_method])

    if out_fmt == 'xml':
        print(result.to_xml())
    elif out_fmt == 'json':
        print(result.to_json())
    else:  # None or 'text'
        print(str(result))

    return 0
Exemplo n.º 7
0
def main():
    """Compute VMAF (optionally with local explanations) for a non-YUV pair.

    Positional arguments: quality_width quality_height ref_file dis_file.
    Options, scanned from argv[5] on: --model <path>, --out-fmt
    xml|json|text, --work-dir <dir>, --pool <method>, --local-explain,
    --phone-model.

    Returns 0 on success and 2 on any command-line usage error.
    """
    args = sys.argv
    if len(args) < 5:
        print_usage()
        return 2

    ref_file, dis_file = args[3], args[4]
    try:
        q_width, q_height = int(args[1]), int(args[2])
    except ValueError:
        print_usage()
        return 2

    if min(q_width, q_height) < 0:
        print("quality_width and quality_height must be non-negative, but are {w} and {h}".format(w=q_width, h=q_height))
        print_usage()
        return 2

    n_args = len(args)

    def option_value(name):
        # Value following `name` among the optional arguments, or None.
        return get_cmd_option(args, 5, n_args, name)

    model_path = option_value('--model')

    out_fmt = option_value('--out-fmt')
    if out_fmt not in (None, 'xml', 'json', 'text'):
        print_usage()
        return 2

    work_dir = option_value('--work-dir')

    pool_method = option_value('--pool')
    if pool_method is not None and pool_method not in POOL_METHODS:
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    show_local_explanation = cmd_option_exists(args, 5, n_args, '--local-explain')
    phone_model = cmd_option_exists(args, 5, n_args, '--phone-model')

    if work_dir is None:
        work_dir = VmafConfig.workdir_path()

    # Content and asset ids are both derived from the reference file name.
    name_hash = abs(hash(get_file_name_without_extension(ref_file))) % (10 ** 16)
    assets = [
        Asset(dataset="cmd",
              content_id=name_hash,
              asset_id=name_hash,
              workdir_root=work_dir,
              ref_path=ref_file,
              dis_path=dis_file,
              asset_dict={'quality_width': q_width,
                          'quality_height': q_height,
                          'yuv_type': 'notyuv'})
    ]

    runner_class = (VmafQualityRunnerWithLocalExplainer
                    if show_local_explanation else VmafQualityRunner)

    optional_dict = None if model_path is None else {'model_filepath': model_path}
    if phone_model:
        optional_dict = dict(optional_dict or {}, enable_transform_score=True)

    runner = runner_class(
        assets,
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )
    runner.run()
    result = runner.results[0]

    # None and 'mean' keep the result's default aggregation.
    pooling = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }
    if pool_method in pooling:
        result.set_score_aggregate_method(pooling[pool_method])

    if out_fmt == 'xml':
        output = result.to_xml()
    elif out_fmt == 'json':
        output = result.to_json()
    else:  # None or 'text'
        output = str(result)
    print(output)

    if show_local_explanation:
        import matplotlib.pyplot as plt
        runner.show_local_explanations([result])
        plt.show()

    return 0
Exemplo n.º 8
0
def main():
    """Compute PSNR for one reference/distorted YUV pair from the command line.

    Positional arguments: fmt width height ref_path dis_path.
    Options, scanned from argv[6] on: --out-fmt xml|json|text and
    --pool <method>.

    Returns 0 on success and 2 on any command-line usage error.
    """
    argv = sys.argv
    if len(argv) < 6:
        print_usage()
        return 2

    fmt, ref_path, dis_path = argv[1], argv[4], argv[5]
    try:
        width, height = int(argv[2]), int(argv[3])
    except ValueError:
        print_usage()
        return 2

    if width < 0 or height < 0:
        print("width and height must be non-negative, but are {w} and {h}".format(w=width, h=height))
        print_usage()
        return 2

    if fmt not in FMTS:
        print_usage()
        return 2

    end = len(argv)
    out_fmt = get_cmd_option(argv, 6, end, '--out-fmt')
    pool_method = get_cmd_option(argv, 6, end, '--pool')

    if out_fmt not in (None, 'xml', 'json', 'text'):
        print_usage()
        return 2

    if pool_method is not None and pool_method not in POOL_METHODS:
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    yuv_info = {'width': width, 'height': height, 'yuv_type': fmt}
    runner = PsnrQualityRunner(
        [Asset(dataset="cmd", content_id=0, asset_id=0,
               workdir_root=VmafConfig.workdir_path(),
               ref_path=ref_path,
               dis_path=dis_path,
               asset_dict=yuv_info)],
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=None,
        optional_dict2=None,
    )

    runner.run()
    result = runner.results[0]

    # Leaving the aggregator untouched for None / 'mean' keeps the default.
    chosen = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }.get(pool_method)
    if chosen is not None:
        result.set_score_aggregate_method(chosen)

    renderers = {'xml': result.to_xml, 'json': result.to_json}
    if out_fmt in renderers:
        print(renderers[out_fmt]())
    else:  # None or 'text'
        print(str(result))

    return 0
Exemplo n.º 9
0
def main():
    """Run VMAF on every (reference, distorted) pair listed in an input file.

    argv[1] is the input file path; the remaining arguments are options:
    ``--model <path>``, ``--out-fmt xml|json|text``, ``--pool <method>``,
    ``--parallelize`` and ``--phone-model``.

    Each data line of the input file has the form
    ``<fmt> <width> <height> <ref_path> <dis_path>``
    (e.g. ``yuv420p 576 324 ref.yuv dis.yuv``), where ``fmt`` must be in
    ``FMTS``. Blank lines are ignored and lines whose first non-blank
    character is ``#`` are reported and skipped.

    Returns:
        0 on success, 1 on a malformed input line, 2 on a usage error.
    """
    if len(sys.argv) < 2:
        print_usage()
        return 2

    input_filepath = sys.argv[1]
    argc = len(sys.argv)

    # Options follow the single positional argument, so scan from argv[2].
    model_path = get_cmd_option(sys.argv, 2, argc, '--model')

    out_fmt = get_cmd_option(sys.argv, 2, argc, '--out-fmt')
    if out_fmt not in (None, 'xml', 'json', 'text'):
        print_usage()
        return 2

    pool_method = get_cmd_option(sys.argv, 2, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    parallelize = cmd_option_exists(sys.argv, 2, argc, '--parallelize')
    phone_model = cmd_option_exists(sys.argv, 2, argc, '--phone-model')

    # Fields may be separated by any run of whitespace.
    line_pattern = re.compile(r"(\S+)\s+([0-9]+)\s+([0-9]+)\s+(\S+)\s+(\S+)")

    assets = []
    with open(input_filepath, "rt") as input_file:
        for line in input_file:
            # Strip first so that only genuinely blank lines are skipped;
            # an indented data line must still be parsed.
            stripped = line.strip()
            if not stripped:
                continue

            if stripped.startswith('#'):
                print("Skip commented line: {}".format(line))
                continue

            # example: yuv420p 576 324 ref.yuv dis.yuv
            mo = line_pattern.match(stripped)
            if not mo or mo.group(1) not in FMTS:
                print("Unknown format: {}".format(line))
                print_usage()
                return 1

            fmt, width, height, ref_file, dis_file = mo.groups()

            asset = Asset(dataset="cmd",
                          content_id=0,
                          # asset ids are consecutive over accepted lines only
                          asset_id=len(assets),
                          workdir_root=VmafConfig.workdir_path(),
                          ref_path=ref_file,
                          dis_path=dis_file,
                          asset_dict={'width': int(width),
                                      'height': int(height),
                                      'yuv_type': fmt}
                          )
            assets.append(asset)

    # Build the runner's optional_dict; stays None when no option sets a key.
    optional_dict = {}
    if model_path is not None:
        optional_dict['model_filepath'] = model_path
    if phone_model:
        optional_dict['enable_transform_score'] = True
    optional_dict = optional_dict or None

    runner = VmafQualityRunner(
        assets,
        None, fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )
    runner.run(parallelize=parallelize)

    # Pooling: None (or 'mean', or anything not listed) leaves each result's
    # aggregate method unchanged.
    aggregate_method = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }.get(pool_method)

    # output
    for result in runner.results:
        if aggregate_method is not None:
            result.set_score_aggregate_method(aggregate_method)

        if out_fmt == 'xml':
            print(result.to_xml())
        elif out_fmt == 'json':
            print(result.to_json())
        else:  # None or 'text'
            print('============================')
            print('Asset {asset_id}:'.format(asset_id=result.asset.asset_id))
            print('============================')
            print(str(result))

    return 0
Exemplo n.º 10
0
def main():
    """Train a VMAF model and write it to ``output_model_filepath``.

    Positional arguments (argv[1]..argv[4]): ``train_dataset_filepath``,
    ``feature_param_filepath``, ``model_param_filepath``,
    ``output_model_filepath``. Options after them: ``--cache-result``,
    ``--parallelize``, ``--suppress-plot``, ``--pool <method>`` and
    ``--subj-model <name>``.

    Training runs exactly once. When plotting is enabled and matplotlib is
    importable, the training scatter is drawn on an axis and shown afterwards.

    Returns:
        0 on success, 1 if an input file or the subjective model cannot be
        loaded, 2 on a usage error.
    """
    if len(sys.argv) < 5:
        print_usage()
        return 2

    (train_dataset_filepath, feature_param_filepath,
     model_param_filepath, output_model_filepath) = sys.argv[1:5]

    try:
        train_dataset = import_python_file(train_dataset_filepath)
        feature_param = import_python_file(feature_param_filepath)
        model_param = import_python_file(model_param_filepath)
    except Exception as e:
        print("Error: " + str(e))
        return 1

    # Options come after the four positional arguments. Scanning from argv[3]
    # would also inspect the model-param and output-model paths as if they
    # were option flags.
    opt_begin = 5
    argc = len(sys.argv)
    cache_result = cmd_option_exists(sys.argv, opt_begin, argc, '--cache-result')
    parallelize = cmd_option_exists(sys.argv, opt_begin, argc, '--parallelize')
    suppress_plot = cmd_option_exists(sys.argv, opt_begin, argc, '--suppress-plot')

    pool_method = get_cmd_option(sys.argv, opt_begin, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    subj_model = get_cmd_option(sys.argv, opt_begin, argc, '--subj-model')

    try:
        if subj_model is not None:
            subj_model_class = SubjectiveModel.find_subclass(subj_model)
        else:
            subj_model_class = None
    except Exception as e:
        print("Error: " + str(e))
        return 1

    result_store = FileSystemResultStore() if cache_result else None

    # Pooling: None, 'mean' or any unlisted method falls back to np.mean.
    aggregate_method = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }.get(pool_method, np.mean)

    logger = None

    # Set up plotting up front. Only failures of matplotlib itself degrade
    # to "no plot"; errors raised during training are no longer swallowed
    # and no longer trigger a second training run.
    plt = None
    ax = None
    if not suppress_plot:
        try:
            import matplotlib.pyplot as plt
            _, ax = plt.subplots(figsize=(5, 5), nrows=1, ncols=1)
        except ImportError:
            print_matplotlib_warning()
            plt, ax = None, None

    train_test_vmaf_on_dataset(train_dataset=train_dataset, test_dataset=None,
                               feature_param=feature_param, model_param=model_param,
                               train_ax=ax, test_ax=None,
                               result_store=result_store,
                               parallelize=parallelize,
                               logger=logger,
                               output_model_filepath=output_model_filepath,
                               aggregate_method=aggregate_method,
                               subj_model_class=subj_model_class,
                               )

    if ax is not None:
        try:
            bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 20}
            ax.annotate('Training Set', xy=(0.1, 0.85), xycoords='axes fraction', bbox=bbox)
            plt.tight_layout()
            plt.show()
        except ImportError:
            # A lazily-loaded backend may be missing; the model is already
            # trained, so just warn instead of retraining.
            print_matplotlib_warning()

    return 0
Exemplo n.º 11
0
def main():
    """Run a quality metric over a test dataset, optionally plotting/printing.

    Positional arguments: ``quality_type`` (argv[1]) and
    ``test_dataset_filepath`` (argv[2]). Options after them:
    ``--vmaf-model <path>``, ``--cache-result``, ``--parallelize``,
    ``--print-result``, ``--suppress-plot``, ``--vmaf-phone-model``,
    ``--pool <method>`` and ``--subj-model <name>``. ``--vmaf-model`` and
    ``--vmaf-phone-model`` are accepted only when ``quality_type`` equals
    ``VmafQualityRunner.TYPE``.

    The test run executes exactly once. When plotting is enabled and
    matplotlib is importable, the results are drawn on an axis and shown.

    Returns:
        0 on success, 1 if the subjective model, dataset or quality runner
        cannot be loaded, 2 on a usage error.
    """
    if len(sys.argv) < 3:
        print_usage()
        return 2

    quality_type, test_dataset_filepath = sys.argv[1:3]

    # Options follow the two positional arguments.
    argc = len(sys.argv)
    vmaf_model_path = get_cmd_option(sys.argv, 3, argc, '--vmaf-model')
    cache_result = cmd_option_exists(sys.argv, 3, argc, '--cache-result')
    parallelize = cmd_option_exists(sys.argv, 3, argc, '--parallelize')
    print_result = cmd_option_exists(sys.argv, 3, argc, '--print-result')
    suppress_plot = cmd_option_exists(sys.argv, 3, argc, '--suppress-plot')
    vmaf_phone_model = cmd_option_exists(sys.argv, 3, argc, '--vmaf-phone-model')

    pool_method = get_cmd_option(sys.argv, 3, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    subj_model = get_cmd_option(sys.argv, 3, argc, '--subj-model')

    try:
        if subj_model is not None:
            subj_model_class = SubjectiveModel.find_subclass(subj_model)
        else:
            subj_model_class = None
    except Exception as e:
        print("Error: " + str(e))
        return 1

    if vmaf_model_path is not None and quality_type != VmafQualityRunner.TYPE:
        print("Input error: only quality_type of VMAF accepts --vmaf-model.")
        print_usage()
        return 2

    if vmaf_phone_model and quality_type != VmafQualityRunner.TYPE:
        print("Input error: only quality_type of VMAF accepts --vmaf-phone-model.")
        print_usage()
        return 2

    try:
        test_dataset = import_python_file(test_dataset_filepath)
    except Exception as e:
        print("Error: " + str(e))
        return 1

    try:
        runner_class = QualityRunner.find_subclass(quality_type)
    except Exception as e:
        print("Error: " + str(e))
        return 1

    result_store = FileSystemResultStore() if cache_result else None

    # Pooling: None, 'mean' or any unlisted method falls back to np.mean.
    aggregate_method = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }.get(pool_method, np.mean)

    enable_transform_score = True if vmaf_phone_model else None

    # Set up plotting up front. Only failures of matplotlib itself degrade
    # to "no plot"; errors raised during the test run are no longer
    # swallowed and no longer trigger a second run.
    plt = None
    ax = None
    if not suppress_plot:
        try:
            import matplotlib.pyplot as plt
            _, ax = plt.subplots(figsize=(5, 5), nrows=1, ncols=1)
        except ImportError:
            print_matplotlib_warning()
            plt, ax = None, None

    _, results = run_test_on_dataset(test_dataset, runner_class, ax,
                                     result_store, vmaf_model_path,
                                     parallelize=parallelize,
                                     aggregate_method=aggregate_method,
                                     subj_model_class=subj_model_class,
                                     enable_transform_score=enable_transform_score
                                     )

    if ax is not None:
        try:
            bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 20}
            ax.annotate('Testing Set', xy=(0.1, 0.85), xycoords='axes fraction', bbox=bbox)
            plt.tight_layout()
            plt.show()
        except ImportError:
            # A lazily-loaded backend may be missing; results are already
            # computed, so warn rather than rerun.
            print_matplotlib_warning()

    if print_result:
        for result in results:
            print(result)
            print('')

    return 0