Example #1
    def print_explanations(cls, exps, assets=None, ys=None, ys_pred=None):

        # asserts
        N = cls.assert_explanations(exps, assets, ys, ys_pred)

        print("Features: {}".format(exps['feature_names']))

        for n in range(N):
            weights = exps['feature_weights'][n]
            features = exps['features_normalized'][n]

            asset = assets[n] if assets is not None else None
            y = ys['label'][n] if ys is not None else None
            y_pred = ys_pred[n] if ys_pred is not None else None

            print("{ref}".format(
                ref=get_file_name_without_extension(asset.ref_path)
                if asset is not None else "Asset {}".format(n)))
            if asset is not None:
                print("\tDistorted: {dis}".format(
                    dis=get_file_name_without_extension(asset.dis_path)))
            if y is not None:
                print("\tground truth: {y:.3f}".format(y=y))
            if y_pred is not None:
                print("\tpredicted: {y_pred:.3f}".format(y_pred=y_pred))
            print("\tfeature value: {}".format(features))
            print("\tfeature weight: {}".format(weights))
Example #2
    def print_explanations(cls, exps, assets=None, ys=None, ys_pred=None):

        # asserts
        N = cls.assert_explanations(exps, assets, ys, ys_pred)

        print "Features: {}".format(exps['feature_names'])

        for n in range(N):
            weights = exps['feature_weights'][n]
            features = exps['features_normalized'][n]

            asset = assets[n] if assets is not None else None
            y = ys['label'][n] if ys is not None else None
            y_pred = ys_pred[n] if ys_pred is not None else None

            print "{ref}".format(
                ref=get_file_name_without_extension(asset.ref_path) if
                asset is not None else "Asset {}".format(n))
            if asset is not None:
                print "\tDistorted: {dis}".format(
                    dis=get_file_name_without_extension(asset.dis_path))
            if y is not None:
                print "\tground truth: {y:.3f}".format(y=y)
            if y_pred is not None:
                print "\tpredicted: {y_pred:.3f}".format(y_pred=y_pred)
            print "\tfeature value: {}".format(features)
            print "\tfeature weight: {}".format(weights)
Example #3
    def ref_str(self):
        """
        String representation for reference video.
        :return:
        """
        s = ""

        path = get_file_name_without_extension(self.ref_path)
        s += "{path}".format(path=path)

        if self.ref_width_height:
            w, h = self.ref_width_height
            s += "_{w}x{h}".format(w=w, h=h)

        if self.ref_yuv_type != self.DEFAULT_YUV_TYPE:
            s += "_{}".format(self.ref_yuv_type)

        if self.ref_start_end_frame:
            start, end = self.ref_start_end_frame
            s += "_{start}to{end}".format(start=start, end=end)

        for key in self.ORDERED_FILTER_LIST:
            if self.get_filter_cmd(key, 'ref') is not None:
                if s != "":
                    s += "_"
                s += "{}{}".format(key, self.get_filter_cmd(key, 'ref'))

        return s
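ref_str composes a unique, human-readable work-file stem from the reference's properties. For intuition, a standalone sketch of the naming scheme with hypothetical values (the real values come from the Asset's attributes and ORDERED_FILTER_LIST):

# Hypothetical reference: 1920x1080, default YUV type, frames 0..47,
# plus a crop filter from ORDERED_FILTER_LIST.
s = "checkerboard_src"                        # file name without extension
s += "_{w}x{h}".format(w=1920, h=1080)
s += "_{start}to{end}".format(start=0, end=47)
s += "_crop{}".format("720:480:0:0")
print(s)  # checkerboard_src_1920x1080_0to47_crop720:480:0:0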
Example #4
    def dis_str(self):
        """
        String representation for distorted video.
        :return:
        """
        s = ""

        path = get_file_name_without_extension(self.dis_path)
        s += "{path}".format(path=path)

        if self.dis_width_height:
            w, h = self.dis_width_height
            s += "_{w}x{h}".format(w=w, h=h)

        if self.dis_yuv_type != self.DEFAULT_YUV_TYPE:
            s += "_{}".format(self.dis_yuv_type)

        if self.dis_start_end_frame:
            start, end = self.dis_start_end_frame
            s += "_{start}to{end}".format(start=start, end=end)

        for key in self.ORDERED_FILTER_LIST:
            if self.get_filter_cmd(key, 'dis') is not None:
                if s != "":
                    s += "_"
                s += "{}{}".format(key, self.get_filter_cmd(key, 'dis'))

        if self.dis_proc_callback_str:
            s += f'_{self.dis_proc_callback_str}'

        return s
Example #5
    def dis_str(self):
        """
        String representation for distorted video.
        :return:
        """
        s = ""

        path = get_file_name_without_extension(self.dis_path)
        s += "{path}".format(path=path)

        if self.dis_width_height:
            w, h = self.dis_width_height
            s += "_{w}x{h}".format(w=w, h=h)

        if self.dis_yuv_type != self.DEFAULT_YUV_TYPE:
            s += "_{}".format(self.dis_yuv_type)

        if self.dis_start_end_frame:
            start, end = self.dis_start_end_frame
            s += "_{start}to{end}".format(start=start, end=end)

        if self.dis_crop_cmd is not None:
            if s != "":
                s += "_"
            s += "crop{}".format(self.dis_crop_cmd)

        if self.dis_pad_cmd is not None:
            if s != "":
                s += "_"
            s += "pad{}".format(self.dis_pad_cmd)

        return s
Example #6
 def print_assets(test_assets):
     print('\n'.join(
         map(
             lambda tasset: "Asset {i}: {name}".format(
                 i=tasset[0],
                 name=get_file_name_without_extension(tasset[1].dis_path)),
             enumerate(test_assets))))
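A quick usage sketch for print_assets, assuming the function above is defined at module level; SimpleNamespace and the local get_file_name_without_extension below are stand-ins that keep the example self-contained (vmaf provides its own helper of that name):

import os
from types import SimpleNamespace

def get_file_name_without_extension(path):
    # Stand-in for vmaf's helper of the same name.
    return os.path.splitext(os.path.basename(path))[0]

test_assets = [SimpleNamespace(dis_path='/data/dis_720p.yuv'),
               SimpleNamespace(dis_path='/data/dis_1080p.yuv')]
print_assets(test_assets)
# Asset 0: dis_720p
# Asset 1: dis_1080p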
Example #7
File: asset.py  Project: rayning0/vmaf
    def dis_str(self):
        """
        String representation for distorted video.
        :return:
        """
        s = ""

        path = get_file_name_without_extension(self.dis_path)
        s += "{path}".format(path=path)

        if self.dis_width_height:
            w, h = self.dis_width_height
            s += "_{w}x{h}".format(w=w, h=h)

        if self.dis_encode_width_height != self.dis_width_height:
            w, h = self.dis_encode_width_height
            s += "_e_{w}x{h}".format(w=w, h=h)

        if self.dis_yuv_type != self.DEFAULT_YUV_TYPE:
            s += "_{}".format(self.dis_yuv_type)

        # if resolutions are consistent, no resampling is taking place, so
        # specifying resampling type should be ignored
        if self.dis_resampling_type != self.DEFAULT_RESAMPLING_TYPE and \
                not self.dis_width_height == self.quality_width_height:
            if s != "":
                s += "_"
            s += "{}".format(self.dis_resampling_type)

        if self.dis_start_end_frame:
            start, end = self.dis_start_end_frame
            s += "_{start}to{end}".format(start=start, end=end)

        for key in self.ORDERED_FILTER_LIST:
            if self.get_filter_cmd(key, 'dis') is not None:
                if s != "":
                    s += "_"
                s += "{}{}".format(key, self.get_filter_cmd(key, 'dis'))

        if self.dis_proc_callback_str:
            s += f'_{self.dis_proc_callback_str}'

        return slugify(s, separator='_')
Example #8
File: asset.py  Project: xjsXjtu/vmaf
    def ref_str(self):
        """
        String representation for reference video.
        :return:
        """
        s = ""

        path = get_file_name_without_extension(self.ref_path)
        s += "{path}".format(path=path)

        if self.ref_width_height:
            w, h = self.ref_width_height
            s += "_{w}x{h}".format(w=w, h=h)

        if self.ref_yuv_type != self.DEFAULT_YUV_TYPE:
            s += "_{}".format(self.ref_yuv_type)

        if self.ref_start_end_frame:
            start, end = self.ref_start_end_frame
            s += "_{start}to{end}".format(start=start, end=end)

        return s
Example #9
 def print_assets(test_assets):
     print('\n'.join(
         map(
             lambda i_asset: "Asset {i}: {name}".format(
                 i=i_asset[0],
                 name=get_file_name_without_extension(i_asset[1].dis_path)),
             enumerate(test_assets))))
Example #10
def main():
    if len(sys.argv) < 6:
        print_usage()
        return 2

    try:
        fmt = sys.argv[1]
        width = int(sys.argv[2])
        height = int(sys.argv[3])
        ref_file = sys.argv[4]
        dis_file = sys.argv[5]
    except ValueError:
        print_usage()
        return 2

    if width < 0 or height < 0:
        print "width and height must be non-negative, but are {w} and {h}".format(
            w=width, h=height)
        print_usage()
        return 2

    if fmt not in FMTS:
        print_usage()
        return 2

    model_path = get_cmd_option(sys.argv, 6, len(sys.argv), '--model')

    out_fmt = get_cmd_option(sys.argv, 6, len(sys.argv), '--out-fmt')
    if not (out_fmt is None or out_fmt == 'xml' or out_fmt == 'json'
            or out_fmt == 'text'):
        print_usage()
        return 2

    pool_method = get_cmd_option(sys.argv, 6, len(sys.argv), '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    show_local_explanation = cmd_option_exists(sys.argv, 6, len(sys.argv),
                                               '--local-explain')

    phone_model = cmd_option_exists(sys.argv, 6, len(sys.argv),
                                    '--phone-model')

    asset = Asset(
        dataset="cmd",
        content_id=abs(hash(get_file_name_without_extension(ref_file))) %
        (10**16),
        asset_id=abs(hash(get_file_name_without_extension(ref_file))) %
        (10**16),
        workdir_root=VmafConfig.workdir_path(),
        ref_path=ref_file,
        dis_path=dis_file,
        asset_dict={
            'width': width,
            'height': height,
            'yuv_type': fmt
        })
    assets = [asset]

    if not show_local_explanation:
        runner_class = VmafQualityRunner
    else:
        from vmaf.core.quality_runner_extra import VmafQualityRunnerWithLocalExplainer
        runner_class = VmafQualityRunnerWithLocalExplainer

    if model_path is None:
        optional_dict = None
    else:
        optional_dict = {'model_filepath': model_path}

    if phone_model:
        if optional_dict is None:
            optional_dict = {}
        optional_dict['enable_transform_score'] = True

    runner = runner_class(
        assets,
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )

    # run
    runner.run()
    result = runner.results[0]

    # pooling
    if pool_method == 'harmonic_mean':
        result.set_score_aggregate_method(ListStats.harmonic_mean)
    elif pool_method == 'min':
        result.set_score_aggregate_method(np.min)
    elif pool_method == 'median':
        result.set_score_aggregate_method(np.median)
    elif pool_method == 'perc5':
        result.set_score_aggregate_method(ListStats.perc5)
    elif pool_method == 'perc10':
        result.set_score_aggregate_method(ListStats.perc10)
    elif pool_method == 'perc20':
        result.set_score_aggregate_method(ListStats.perc20)
    else:  # None or 'mean'
        pass

    # output
    if out_fmt == 'xml':
        print(result.to_xml())
    elif out_fmt == 'json':
        print(result.to_json())
    else:  # None or 'text'
        print(str(result))

    # local explanation
    if show_local_explanation:
        import matplotlib.pyplot as plt
        runner.show_local_explanations([result])
        plt.show()

    return 0
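main() leans on two small argv helpers, get_cmd_option and cmd_option_exists, that this result does not show. Below is a plausible sketch consistent with every call site above; the actual implementations live in the vmaf CLI sources and may differ in detail:

def get_cmd_option(argv, begin, end, option):
    # Return the token following `option` within argv[begin:end], else None.
    for i in range(begin, end):
        if argv[i] == option and i + 1 < end:
            return argv[i + 1]
    return None

def cmd_option_exists(argv, begin, end, option):
    # Return True if `option` appears within argv[begin:end].
    return option in argv[begin:end]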
Example #11
    def plot_explanations(cls, exps, assets=None, ys=None, ys_pred=None):

        # asserts
        N = cls.assert_explanations(exps, assets, ys, ys_pred)

        figs = []
        for n in range(N):
            weights = exps['feature_weights'][n]
            features = exps['features'][n]
            normalized = exps['features_normalized'][n]

            asset = assets[n] if assets is not None else None
            y = ys['label'][n] if ys is not None else None
            y_pred = ys_pred[n] if ys_pred is not None else None

            img = None
            if asset is not None:
                w, h = asset.dis_width_height
                with YuvReader(filepath=asset.dis_path, width=w, height=h,
                               yuv_type=asset.dis_yuv_type) as yuv_reader:
                    for yuv in yuv_reader:
                        img, _, _ = yuv
                        break
                assert img is not None

            title = ""
            if asset is not None:
                title += "{}\n".format(get_file_name_without_extension(asset.ref_path))
            if y is not None:
                title += "ground truth: {:.3f}\n".format(y)
            if y_pred is not None:
                title += "predicted: {:.3f}\n".format(y_pred)
            if title != "" and title[-1] == '\n':
                title = title[:-1]

            assert len(weights) == len(features)
            M = len(weights)

            fig = plt.figure()

            ax_top = plt.subplot(2, 1, 1)
            ax_left = plt.subplot(2, 3, 4)
            ax_mid = plt.subplot(2, 3, 5, sharey=ax_left)
            ax_right = plt.subplot(2, 3, 6, sharey=ax_left)

            if img is not None:
                ax_top.imshow(img, cmap='Greys_r')
            ax_top.get_xaxis().set_visible(False)
            ax_top.get_yaxis().set_visible(False)
            ax_top.set_title(title)

            pos = np.arange(M) + 0.1
            ax_left.barh(pos, features, color='b', label='feature')
            ax_left.set_xticks(np.arange(0, 1.1, 0.2))
            ax_left.set_yticks(pos + 0.35)
            ax_left.set_yticklabels(exps['feature_names'])
            ax_left.set_title('feature')

            ax_mid.barh(pos, normalized, color='g', label='fnormal')
            ax_mid.get_yaxis().set_visible(False)
            ax_mid.set_title('fnormal')

            ax_right.barh(pos, weights, color='r', label='weight')
            ax_right.get_yaxis().set_visible(False)
            ax_right.set_title('weight')

            plt.tight_layout()

            figs.append(fig)

        return figs
Example #12
def run_test_on_dataset(test_dataset,
                        runner_class,
                        ax,
                        result_store,
                        model_filepath,
                        parallelize=True,
                        fifo_mode=True,
                        aggregate_method=np.mean,
                        type='regressor',
                        **kwargs):
    """
    TODO: move this function under test/
    """

    if type == 'regressor':
        model_type = RegressorMixin
    elif type == 'classifier':
        model_type = ClassifierMixin
    else:
        assert False

    test_assets = read_dataset(test_dataset, **kwargs)
    test_raw_assets = None
    try:
        for test_asset in test_assets:
            assert test_asset.groundtruth is not None
    except AssertionError:
        # no groundtruth, try to do subjective modeling
        subj_model_class = kwargs[
            'subj_model_class'] if 'subj_model_class' in kwargs and kwargs[
                'subj_model_class'] is not None else DmosModel
        subjective_model = subj_model_class(RawDatasetReader(test_dataset))
        subjective_model.run_modeling(**kwargs)
        test_dataset_aggregate = subjective_model.to_aggregated_dataset(
            **kwargs)
        test_raw_assets = test_assets
        test_assets = read_dataset(test_dataset_aggregate, **kwargs)

    if model_filepath is not None:
        optional_dict = {'model_filepath': model_filepath}
        if 'model_720_filepath' in kwargs and kwargs[
                'model_720_filepath'] is not None:
            optional_dict['720model_filepath'] = kwargs['model_720_filepath']
        if 'model_480_filepath' in kwargs and kwargs[
                'model_480_filepath'] is not None:
            optional_dict['480model_filepath'] = kwargs['model_480_filepath']
    else:
        optional_dict = None

    if 'enable_transform_score' in kwargs and kwargs[
            'enable_transform_score'] is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['enable_transform_score'] = kwargs[
            'enable_transform_score']

    # run
    runner = runner_class(
        test_assets,
        None,
        fifo_mode=fifo_mode,
        delete_workdir=True,
        result_store=result_store,
        optional_dict=optional_dict,
        optional_dict2=None,
    )
    runner.run(parallelize=parallelize)
    results = runner.results

    for result in results:
        result.set_score_aggregate_method(aggregate_method)

    # plot
    groundtruths = map(lambda asset: asset.groundtruth, test_assets)
    predictions = map(lambda result: result[runner_class.get_score_key()],
                      results)
    raw_grountruths = None if test_raw_assets is None else \
        map(lambda asset: asset.raw_groundtruth, test_raw_assets)
    stats = model_type.get_stats(groundtruths,
                                 predictions,
                                 ys_label_raw=raw_grountruths)

    print('Stats on testing data: {}'.format(model_type.format_stats(stats)))

    if ax is not None:
        content_ids = map(lambda asset: asset.content_id, test_assets)

        if 'point_label' in kwargs:
            if kwargs['point_label'] == 'asset_id':
                point_labels = map(lambda asset: asset.asset_id, test_assets)
            elif kwargs['point_label'] == 'dis_path':
                point_labels = map(
                    lambda asset: get_file_name_without_extension(
                        asset.dis_path), test_assets)
            else:
                raise AssertionError("Unknown point_label {}".format(
                    kwargs['point_label']))
        else:
            point_labels = None

        model_type.plot_scatter(ax,
                                stats,
                                content_ids,
                                point_labels=point_labels)
        ax.set_xlabel('True Score')
        ax.set_ylabel("Predicted Score")
        ax.grid()
        ax.set_title("{runner}\n{stats}".format(
            dataset=test_assets[0].dataset,
            runner=runner_class.TYPE,
            stats=model_type.format_stats(stats),
        ))

    return test_assets, results
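A minimal sketch of driving run_test_on_dataset as defined above (the dataset file name is hypothetical; import_python_file is vmaf's loader for dataset definition files):

import matplotlib.pyplot as plt
import numpy as np
from vmaf.core.quality_runner import VmafQualityRunner
from vmaf.tools.misc import import_python_file

# Load a dataset definition file (hypothetical path) and run the test.
test_dataset = import_python_file('example_dataset.py')
fig, ax = plt.subplots()
test_assets, results = run_test_on_dataset(
    test_dataset, VmafQualityRunner, ax,
    result_store=None, model_filepath=None,
    parallelize=False, aggregate_method=np.mean)
plt.show()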
Example #13
 def print_assets(test_assets):
     print('\n'.join(map(
         lambda i_asset: "Asset {i}: {name}".format(
             i=i_asset[0], name=get_file_name_without_extension(i_asset[1].dis_path)),
         enumerate(test_assets)
     )))
Example #14
def main():
    if len(sys.argv) < 5:
        print_usage()
        return 2

    try:
        q_width = int(sys.argv[1])
        q_height = int(sys.argv[2])
        ref_file = sys.argv[3]
        dis_file = sys.argv[4]
    except ValueError:
        print_usage()
        return 2

    if q_width < 0 or q_height < 0:
        print(
            "quality_width and quality_height must be non-negative, but are {w} and {h}"
            .format(w=q_width, h=q_height))
        print_usage()
        return 2

    model_path = get_cmd_option(sys.argv, 5, len(sys.argv), '--model')

    out_fmt = get_cmd_option(sys.argv, 5, len(sys.argv), '--out-fmt')
    if not (out_fmt is None or out_fmt == 'xml' or out_fmt == 'json'
            or out_fmt == 'text'):
        print_usage()
        return 2

    ref_fmt = get_cmd_option(sys.argv, 5, len(sys.argv), '--ref-fmt')
    if not (ref_fmt is None or ref_fmt in FMTS):
        print('--ref-fmt can only have option among {}'.format(
            ', '.join(FMTS)))

    ref_width = get_cmd_option(sys.argv, 5, len(sys.argv), '--ref-width')
    ref_height = get_cmd_option(sys.argv, 5, len(sys.argv), '--ref-height')
    dis_width = get_cmd_option(sys.argv, 5, len(sys.argv), '--dis-width')
    dis_height = get_cmd_option(sys.argv, 5, len(sys.argv), '--dis-height')

    dis_fmt = get_cmd_option(sys.argv, 5, len(sys.argv), '--dis-fmt')
    if not (dis_fmt is None or dis_fmt in FMTS):
        print('--dis-fmt can only have option among {}'.format(
            ', '.join(FMTS)))

    work_dir = get_cmd_option(sys.argv, 5, len(sys.argv), '--work-dir')

    pool_method = get_cmd_option(sys.argv, 5, len(sys.argv), '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    show_local_explanation = cmd_option_exists(sys.argv, 5, len(sys.argv),
                                               '--local-explain')

    phone_model = cmd_option_exists(sys.argv, 5, len(sys.argv),
                                    '--phone-model')

    enable_conf_interval = cmd_option_exists(sys.argv, 5, len(sys.argv),
                                             '--ci')

    save_plot_dir = get_cmd_option(sys.argv, 5, len(sys.argv), '--save-plot')

    if work_dir is None:
        work_dir = VmafConfig.workdir_path()

    asset_dict = {'quality_width': q_width, 'quality_height': q_height}

    if ref_fmt is None:
        asset_dict['ref_yuv_type'] = 'notyuv'
    else:
        if ref_width is None or ref_height is None:
            print(
                'if --ref-fmt is specified, both --ref-width and --ref-height must be specified'
            )
            return 2
        else:
            asset_dict['ref_yuv_type'] = ref_fmt
            asset_dict['ref_width'] = ref_width
            asset_dict['ref_height'] = ref_height

    if dis_fmt is None:
        asset_dict['dis_yuv_type'] = 'notyuv'
    else:
        if dis_width is None or dis_height is None:
            print(
                'if --dis-fmt is specified, both --dis-width and --dis-height must be specified'
            )
            return 2
        else:
            asset_dict['dis_yuv_type'] = dis_fmt
            asset_dict['dis_width'] = dis_width
            asset_dict['dis_height'] = dis_height

    if show_local_explanation and enable_conf_interval:
        print('cannot set both --local-explain and --ci flags')
        return 2

    asset = Asset(
        dataset="cmd",
        content_id=abs(hash(get_file_name_without_extension(ref_file))) %
        (10**16),
        asset_id=abs(hash(get_file_name_without_extension(ref_file))) %
        (10**16),
        workdir_root=work_dir,
        ref_path=ref_file,
        dis_path=dis_file,
        asset_dict=asset_dict,
    )
    assets = [asset]

    if show_local_explanation:
        from vmaf.core.quality_runner_extra import VmafQualityRunnerWithLocalExplainer
        runner_class = VmafQualityRunnerWithLocalExplainer
    elif enable_conf_interval:
        from vmaf.core.quality_runner import BootstrapVmafQualityRunner
        runner_class = BootstrapVmafQualityRunner
    else:
        runner_class = VmafQualityRunner

    if model_path is None:
        optional_dict = None
    else:
        optional_dict = {'model_filepath': model_path}

    if phone_model:
        if optional_dict is None:
            optional_dict = {}
        optional_dict['enable_transform_score'] = True

    runner = runner_class(
        assets,
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )

    # run
    runner.run()
    result = runner.results[0]

    # pooling
    if pool_method == 'harmonic_mean':
        result.set_score_aggregate_method(ListStats.harmonic_mean)
    elif pool_method == 'min':
        result.set_score_aggregate_method(np.min)
    elif pool_method == 'median':
        result.set_score_aggregate_method(np.median)
    elif pool_method == 'perc5':
        result.set_score_aggregate_method(ListStats.perc5)
    elif pool_method == 'perc10':
        result.set_score_aggregate_method(ListStats.perc10)
    elif pool_method == 'perc20':
        result.set_score_aggregate_method(ListStats.perc20)
    else:  # None or 'mean'
        pass

    # output
    if out_fmt == 'xml':
        print(result.to_xml())
    elif out_fmt == 'json':
        print(result.to_json())
    else:  # None or 'text'
        print(str(result))

    # local explanation
    if show_local_explanation:
        runner.show_local_explanations([result])

        if save_plot_dir is None:
            DisplayConfig.show()
        else:
            DisplayConfig.show(write_to_dir=save_plot_dir)

    return 0
Example #15
def run_test_on_dataset(test_dataset, runner_class, ax,
                    result_store, model_filepath,
                    parallelize=True, fifo_mode=True,
                    aggregate_method=np.mean,
                    type='regressor',
                    **kwargs):
    """
    TODO: move this function under test/
    """

    if type == 'regressor':
        model_type = RegressorMixin
    elif type == 'classifier':
        model_type = ClassifierMixin
    else:
        assert False

    test_assets = read_dataset(test_dataset, **kwargs)
    test_raw_assets = None
    try:
        for test_asset in test_assets:
            assert test_asset.groundtruth is not None
    except AssertionError:
        # no groundtruth, try to do subjective modeling
        subj_model_class = kwargs['subj_model_class'] if 'subj_model_class' in kwargs and kwargs['subj_model_class'] is not None else DmosModel
        subjective_model = subj_model_class(RawDatasetReader(test_dataset))
        subjective_model.run_modeling(**kwargs)
        test_dataset_aggregate = subjective_model.to_aggregated_dataset(**kwargs)
        test_raw_assets = test_assets
        test_assets = read_dataset(test_dataset_aggregate, **kwargs)

    if model_filepath is not None:
        optional_dict = {'model_filepath': model_filepath}
        if 'model_720_filepath' in kwargs and kwargs['model_720_filepath'] is not None:
            optional_dict['720model_filepath'] = kwargs['model_720_filepath']
        if 'model_480_filepath' in kwargs and kwargs['model_480_filepath'] is not None:
            optional_dict['480model_filepath'] = kwargs['model_480_filepath']
    else:
        optional_dict = None

    if 'enable_transform_score' in kwargs and kwargs['enable_transform_score'] is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['enable_transform_score'] = kwargs['enable_transform_score']

    # run
    runner = runner_class(
        test_assets,
        None, fifo_mode=fifo_mode,
        delete_workdir=True,
        result_store=result_store,
        optional_dict=optional_dict,
        optional_dict2=None,
    )
    runner.run(parallelize=parallelize)
    results = runner.results

    for result in results:
        result.set_score_aggregate_method(aggregate_method)

    # plot
    groundtruths = map(lambda asset: asset.groundtruth, test_assets)
    predictions = map(lambda result: result[runner_class.get_score_key()], results)
    raw_grountruths = None if test_raw_assets is None else \
        map(lambda asset: asset.raw_groundtruth, test_raw_assets)
    stats = model_type.get_stats(groundtruths, predictions, ys_label_raw=raw_grountruths)

    print('Stats on testing data: {}'.format(model_type.format_stats(stats)))

    if ax is not None:
        content_ids = map(lambda asset: asset.content_id, test_assets)

        if 'point_label' in kwargs:
            if kwargs['point_label'] == 'asset_id':
                point_labels = map(lambda asset: asset.asset_id, test_assets)
            elif kwargs['point_label'] == 'dis_path':
                point_labels = map(lambda asset: get_file_name_without_extension(asset.dis_path), test_assets)
            else:
                raise AssertionError("Unknown point_label {}".format(kwargs['point_label']))
        else:
            point_labels = None

        model_type.plot_scatter(ax, stats, content_ids, point_labels=point_labels)
        ax.set_xlabel('True Score')
        ax.set_ylabel("Predicted Score")
        ax.grid()
        ax.set_title("{runner}\n{stats}".format(
            dataset=test_assets[0].dataset,
            runner=runner_class.TYPE,
            stats=model_type.format_stats(stats),
        ))

    return test_assets, results
Example #16
    def plot_explanations(cls, exps, assets=None, ys=None, ys_pred=None):

        # asserts
        N = cls.assert_explanations(exps, assets, ys, ys_pred)

        figs = []
        for n in range(N):
            weights = exps['feature_weights'][n]
            features = exps['features'][n]
            normalized = exps['features_normalized'][n]

            asset = assets[n] if assets is not None else None
            y = ys['label'][n] if ys is not None else None
            y_pred = ys_pred[n] if ys_pred is not None else None

            img = None
            if asset is not None:
                w, h = asset.dis_width_height
                with YuvReader(filepath=asset.dis_path,
                               width=w,
                               height=h,
                               yuv_type=asset.dis_yuv_type) as yuv_reader:
                    for yuv in yuv_reader:
                        img, _, _ = yuv
                        break
                assert img is not None

            title = ""
            if asset is not None:
                title += "{}\n".format(
                    get_file_name_without_extension(asset.ref_path))
            if y is not None:
                title += "ground truth: {:.3f}\n".format(y)
            if y_pred is not None:
                title += "predicted: {:.3f}\n".format(y_pred)
            if title != "" and title[-1] == '\n':
                title = title[:-1]

            assert len(weights) == len(features)
            M = len(weights)

            fig = plt.figure()

            ax_top = plt.subplot(2, 1, 1)
            ax_left = plt.subplot(2, 3, 4)
            ax_mid = plt.subplot(2, 3, 5, sharey=ax_left)
            ax_right = plt.subplot(2, 3, 6, sharey=ax_left)

            if img is not None:
                ax_top.imshow(img, cmap='Greys_r')
            ax_top.get_xaxis().set_visible(False)
            ax_top.get_yaxis().set_visible(False)
            ax_top.set_title(title)

            pos = np.arange(M) + 0.1
            ax_left.barh(pos, features, color='b', label='feature')
            ax_left.set_xticks(np.arange(0, 1.1, 0.2))
            ax_left.set_yticks(pos + 0.35)
            ax_left.set_yticklabels(exps['feature_names'])
            ax_left.set_title('feature')

            ax_mid.barh(pos, normalized, color='g', label='fnormal')
            ax_mid.get_yaxis().set_visible(False)
            ax_mid.set_title('fnormal')

            ax_right.barh(pos, weights, color='r', label='weight')
            ax_right.get_yaxis().set_visible(False)
            ax_right.set_title('weight')

            plt.tight_layout()

            figs.append(fig)

        return figs
Example #17
def run_test_on_dataset(test_dataset,
                        runner_class,
                        ax,
                        result_store,
                        model_filepath,
                        parallelize=True,
                        fifo_mode=True,
                        aggregate_method=np.mean,
                        type='regressor',
                        **kwargs):

    test_assets = read_dataset(test_dataset, **kwargs)
    test_raw_assets = None
    try:
        for test_asset in test_assets:
            assert test_asset.groundtruth is not None
    except AssertionError:
        # no groundtruth, try to do subjective modeling
        from sureal.dataset_reader import RawDatasetReader
        from sureal.subjective_model import DmosModel
        subj_model_class = kwargs[
            'subj_model_class'] if 'subj_model_class' in kwargs and kwargs[
                'subj_model_class'] is not None else DmosModel
        dataset_reader_class = kwargs[
            'dataset_reader_class'] if 'dataset_reader_class' in kwargs else RawDatasetReader
        subjective_model = subj_model_class(dataset_reader_class(test_dataset))
        subjective_model.run_modeling(**kwargs)
        test_dataset_aggregate = subjective_model.to_aggregated_dataset(
            **kwargs)
        test_raw_assets = test_assets
        test_assets = read_dataset(test_dataset_aggregate, **kwargs)

    if model_filepath is not None:
        optional_dict = {'model_filepath': model_filepath}
        if 'model_720_filepath' in kwargs and kwargs[
                'model_720_filepath'] is not None:
            optional_dict['720model_filepath'] = kwargs['model_720_filepath']
        if 'model_480_filepath' in kwargs and kwargs[
                'model_480_filepath'] is not None:
            optional_dict['480model_filepath'] = kwargs['model_480_filepath']
    else:
        optional_dict = None

    if 'enable_transform_score' in kwargs and kwargs[
            'enable_transform_score'] is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['enable_transform_score'] = kwargs[
            'enable_transform_score']

    if 'disable_clip_score' in kwargs and kwargs[
            'disable_clip_score'] is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['disable_clip_score'] = kwargs['disable_clip_score']

    if 'subsample' in kwargs and kwargs['subsample'] is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['subsample'] = kwargs['subsample']

    # run
    runner = runner_class(
        test_assets,
        None,
        fifo_mode=fifo_mode,
        delete_workdir=True,
        result_store=result_store,
        optional_dict=optional_dict,
        optional_dict2=None,
    )
    runner.run(parallelize=parallelize)
    results = runner.results

    for result in results:
        result.set_score_aggregate_method(aggregate_method)

    try:
        model_type = runner.get_train_test_model_class()
    except:
        if type == 'regressor':
            model_type = RegressorMixin
        elif type == 'classifier':
            model_type = ClassifierMixin
        else:
            assert False

    # plot
    groundtruths = list(map(lambda asset: asset.groundtruth, test_assets))
    predictions = list(
        map(lambda result: result[runner_class.get_score_key()], results))
    raw_grountruths = None if test_raw_assets is None else \
        list(map(lambda asset: asset.raw_groundtruth, test_raw_assets))
    groundtruths_std = None if test_assets is None else \
        list(map(lambda asset: asset.groundtruth_std, test_assets))
    try:
        predictions_bagging = list(
            map(lambda result: result[runner_class.get_bagging_score_key()],
                results))
        predictions_stddev = list(
            map(lambda result: result[runner_class.get_stddev_score_key()],
                results))
        predictions_ci95_low = list(
            map(lambda result: result[runner_class.get_ci95_low_score_key()],
                results))
        predictions_ci95_high = list(
            map(lambda result: result[runner_class.get_ci95_high_score_key()],
                results))
        predictions_all_models = list(
            map(lambda result: result[runner_class.get_all_models_score_key()],
                results))

        # need to transpose the list of lists, so that the outer list has the predictions for each model separately
        predictions_all_models = np.array(predictions_all_models).T.tolist()
        num_models = np.shape(predictions_all_models)[0]

        stats = model_type.get_stats(
            groundtruths,
            predictions,
            ys_label_raw=raw_grountruths,
            ys_label_pred_bagging=predictions_bagging,
            ys_label_pred_stddev=predictions_stddev,
            ys_label_pred_ci95_low=predictions_ci95_low,
            ys_label_pred_ci95_high=predictions_ci95_high,
            ys_label_pred_all_models=predictions_all_models,
            ys_label_stddev=groundtruths_std)
    except Exception as e:
        print(
            'Stats calculation failed, using default stats calculation. Error cause: '
        )
        print(e)
        stats = model_type.get_stats(groundtruths,
                                     predictions,
                                     ys_label_raw=raw_grountruths,
                                     ys_label_stddev=groundtruths_std)
        num_models = 1

    print('Stats on testing data: {}'.format(
        model_type.format_stats_for_print(stats)))

    # printing stats if multiple models are present
    if 'SRCC_across_model_distribution' in stats \
            and 'PCC_across_model_distribution' in stats \
            and 'RMSE_across_model_distribution' in stats:
        print(
            'Stats on testing data (across multiple models, using all test indices): {}'
            .format(
                model_type.format_across_model_stats_for_print(
                    model_type.extract_across_model_stats(stats))))

    if ax is not None:
        content_ids = list(map(lambda asset: asset.content_id, test_assets))

        if 'point_label' in kwargs:
            if kwargs['point_label'] == 'asset_id':
                point_labels = list(
                    map(lambda asset: asset.asset_id, test_assets))
            elif kwargs['point_label'] == 'dis_path':
                point_labels = list(
                    map(
                        lambda asset: get_file_name_without_extension(
                            asset.dis_path), test_assets))
            else:
                raise AssertionError("Unknown point_label {}".format(
                    kwargs['point_label']))
        else:
            point_labels = None

        model_type.plot_scatter(ax,
                                stats,
                                content_ids=content_ids,
                                point_labels=point_labels,
                                **kwargs)
        ax.set_xlabel('True Score')
        ax.set_ylabel("Predicted Score")
        ax.grid()
        ax.set_title("{runner}{num_models}\n{stats}".format(
            dataset=test_assets[0].dataset,
            runner=runner_class.TYPE,
            stats=model_type.format_stats_for_plot(stats),
            num_models=", {} models".format(num_models)
            if num_models > 1 else "",
        ))

    return test_assets, results
示例#18
0
def main():
    if len(sys.argv) < 5:
        print_usage()
        return 2

    try:
        q_width = int(sys.argv[1])
        q_height = int(sys.argv[2])
        ref_file = sys.argv[3]
        dis_file = sys.argv[4]
    except ValueError:
        print_usage()
        return 2

    if q_width < 0 or q_height < 0:
        print "quality_width and quality_height must be non-negative, but are {w} and {h}".format(w=q_width, h=q_height)
        print_usage()
        return 2

    model_path = get_cmd_option(sys.argv, 5, len(sys.argv), '--model')

    out_fmt = get_cmd_option(sys.argv, 5, len(sys.argv), '--out-fmt')
    if not (out_fmt is None
            or out_fmt == 'xml'
            or out_fmt == 'json'
            or out_fmt == 'text'):
        print_usage()
        return 2

    work_dir = get_cmd_option(sys.argv, 5, len(sys.argv), '--work-dir')

    pool_method = get_cmd_option(sys.argv, 5, len(sys.argv), '--pool')
    if not (pool_method is None
            or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    show_local_explanation = cmd_option_exists(sys.argv, 5, len(sys.argv), '--local-explain')

    phone_model = cmd_option_exists(sys.argv, 5, len(sys.argv), '--phone-model')

    if work_dir is None:
        work_dir = VmafConfig.workdir_path()

    asset = Asset(dataset="cmd",
                  content_id=abs(hash(get_file_name_without_extension(ref_file))) % (10 ** 16),
                  asset_id=abs(hash(get_file_name_without_extension(ref_file))) % (10 ** 16),
                  workdir_root=work_dir,
                  ref_path=ref_file,
                  dis_path=dis_file,
                  asset_dict={'quality_width':q_width, 'quality_height':q_height, 'yuv_type': 'notyuv'}
                  )
    assets = [asset]

    if not show_local_explanation:
        runner_class = VmafQualityRunner
    else:
        runner_class = VmafQualityRunnerWithLocalExplainer

    if model_path is None:
        optional_dict = None
    else:
        optional_dict = {'model_filepath':model_path}

    if phone_model:
        if optional_dict is None:
            optional_dict = {}
        optional_dict['enable_transform_score'] = True

    runner = runner_class(
        assets, None, fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )

    # run
    runner.run()
    result = runner.results[0]

    # pooling
    if pool_method == 'harmonic_mean':
        result.set_score_aggregate_method(ListStats.harmonic_mean)
    elif pool_method == 'min':
        result.set_score_aggregate_method(np.min)
    elif pool_method == 'median':
        result.set_score_aggregate_method(np.median)
    elif pool_method == 'perc5':
        result.set_score_aggregate_method(ListStats.perc5)
    elif pool_method == 'perc10':
        result.set_score_aggregate_method(ListStats.perc10)
    elif pool_method == 'perc20':
        result.set_score_aggregate_method(ListStats.perc20)
    else: # None or 'mean'
        pass

    # output
    if out_fmt == 'xml':
        print(result.to_xml())
    elif out_fmt == 'json':
        print(result.to_json())
    else:  # None or 'text'
        print(str(result))

    # local explanation
    if show_local_explanation:
        import matplotlib.pyplot as plt
        runner.show_local_explanations([result])
        plt.show()

    return 0
Example #19
def vmaf_score(ref_path,
               dis_path,
               width=1920,
               height=1080,
               fmt='yuv420p',
               pool_method='mean'):
    if width < 0 or height < 0:
        raise ValueError(
            "width and height must be non-negative, but are {w} and {h}".
            format(w=width, h=height))

    if fmt not in FMTS:
        raise ValueError("不支持的类型!")

    if not (pool_method is None or pool_method in POOL_METHODS):
        raise ValueError('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))

    show_local_explanation = False

    enable_conf_interval = False

    if show_local_explanation and enable_conf_interval:
        print('cannot set both --local-explain and --ci flags')
        return 2

    asset = Asset(
        dataset="cmd",
        content_id=abs(hash(get_file_name_without_extension(ref_path))) %
        (10**16),
        asset_id=abs(hash(get_file_name_without_extension(ref_path))) %
        (10**16),
        workdir_root=VmafConfig.workdir_path(),
        ref_path=ref_path,
        dis_path=dis_path,
        asset_dict={
            'width': width,
            'height': height,
            'yuv_type': fmt
        })
    assets = [asset]

    if show_local_explanation:
        from vmaf.core.quality_runner_extra import VmafQualityRunnerWithLocalExplainer
        runner_class = VmafQualityRunnerWithLocalExplainer
    elif enable_conf_interval:
        from vmaf.core.quality_runner import BootstrapVmafQualityRunner
        runner_class = BootstrapVmafQualityRunner
    else:
        runner_class = VmafQualityRunner

    optional_dict = {'model_filepath': MODEL_PATH}

    runner = runner_class(
        assets,
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )

    # run
    runner.run()
    result = runner.results[0]

    # pooling
    if pool_method == 'harmonic_mean':
        result.set_score_aggregate_method(ListStats.harmonic_mean)
    elif pool_method == 'min':
        result.set_score_aggregate_method(np.min)
    elif pool_method == 'median':
        result.set_score_aggregate_method(np.median)
    elif pool_method == 'perc5':
        result.set_score_aggregate_method(ListStats.perc5)
    elif pool_method == 'perc10':
        result.set_score_aggregate_method(ListStats.perc10)
    elif pool_method == 'perc20':
        result.set_score_aggregate_method(ListStats.perc20)
    else:  # None or 'mean'
        pass

    # local explanation
    if show_local_explanation:
        runner.show_local_explanations([result])

        # if save_plot_dir is None:
        #     DisplayConfig.show()
        # else:
        #     DisplayConfig.show(write_to_dir=save_plot_dir)

    return result.to_dict()
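A hedged usage sketch, assuming the module-level MODEL_PATH, FMTS, and POOL_METHODS referenced by the function are defined, and assuming the dict returned by result.to_dict() carries an 'aggregate' entry with the pooled scores (file names hypothetical):

scores = vmaf_score('ref_1920x1080.yuv', 'dis_1920x1080.yuv',
                    width=1920, height=1080,
                    fmt='yuv420p', pool_method='mean')
print(scores['aggregate'])  # pooled scores; the dict layout is assumed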
Example #20
def vmaf_json(fmt,
              width,
              height,
              ref_file,
              dis_file,
              json_save_path,
              model_path=None):
    # if len(sys.argv) < 6:
    #     print_usage()
    #     return 2

    # try:
    #     fmt = sys.argv[1]
    #     width = int(sys.argv[2])
    #     height = int(sys.argv[3])
    #     ref_file = sys.argv[4]
    #     dis_file = sys.argv[5]
    # except ValueError:
    #     print_usage()
    #     return 2

    if width < 0 or height < 0:
        print("width and height must be non-negative, but are {w} and {h}".
              format(w=width, h=height))
        print_usage()
        return 2

    if fmt not in FMTS:
        print_usage()
        return 2

    # model_path = get_cmd_option(sys.argv, 6, len(sys.argv), '--model')
    model_path = None
    # output format
    # out_fmt = get_cmd_option(sys.argv, 6, len(sys.argv), '--out-fmt')
    # if not (out_fmt is None
    #         or out_fmt == 'xml'
    #         or out_fmt == 'json'
    #         or out_fmt == 'text'):
    #     print_usage()
    #     return 2

    # pool_method = get_cmd_option(sys.argv, 6, len(sys.argv), '--pool')
    pool_method = None
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(
            ', '.join(POOL_METHODS)))
        return 2

    # show_local_explanation = cmd_option_exists(sys.argv, 6, len(sys.argv), '--local-explain')
    show_local_explanation = None

    # phone_model = cmd_option_exists(sys.argv, 6, len(sys.argv), '--phone-model')
    phone_model = None

    # enable_conf_interval = cmd_option_exists(sys.argv, 6, len(sys.argv), '--ci')
    enable_conf_interval = None

    # save_plot_dir = get_cmd_option(sys.argv, 6, len(sys.argv), '--save-plot')
    save_plot_dir = None

    if show_local_explanation and enable_conf_interval:
        print('cannot set both --local-explain and --ci flags')
        return 2

    asset = Asset(
        dataset="cmd",
        content_id=abs(hash(get_file_name_without_extension(ref_file))) %
        (10**16),
        asset_id=abs(hash(get_file_name_without_extension(ref_file))) %
        (10**16),
        workdir_root=VmafConfig.workdir_path(),
        ref_path=ref_file,
        dis_path=dis_file,
        asset_dict={
            'width': width,
            'height': height,
            'yuv_type': fmt
        })
    assets = [asset]

    if show_local_explanation:
        from vmaf.core.quality_runner_extra import VmafQualityRunnerWithLocalExplainer
        runner_class = VmafQualityRunnerWithLocalExplainer
    elif enable_conf_interval:
        from vmaf.core.quality_runner import BootstrapVmafQualityRunner
        runner_class = BootstrapVmafQualityRunner
    else:
        runner_class = VmafQualityRunner

    if model_path is None:
        optional_dict = None
    else:
        optional_dict = {'model_filepath': model_path}

    if phone_model:
        if optional_dict is None:
            optional_dict = {}
        optional_dict['enable_transform_score'] = True

    runner = runner_class(
        assets,
        None,
        fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )

    # run
    runner.run()
    result = runner.results[0]

    # pooling
    if pool_method == 'harmonic_mean':
        result.set_score_aggregate_method(ListStats.harmonic_mean)
    elif pool_method == 'min':
        result.set_score_aggregate_method(np.min)
    elif pool_method == 'median':
        result.set_score_aggregate_method(np.median)
    elif pool_method == 'perc5':
        result.set_score_aggregate_method(ListStats.perc5)
    elif pool_method == 'perc10':
        result.set_score_aggregate_method(ListStats.perc10)
    elif pool_method == 'perc20':
        result.set_score_aggregate_method(ListStats.perc20)
    else:  # None or 'mean'
        pass

    # output
    #json_save_path
    yuv_temp_name = dis_file.split("/")[-1].split(".yuv")[0] + ".json"
    yuv_temp_path = os.path.join(json_save_path, yuv_temp_name)
    with open(yuv_temp_path, 'w+') as f:
        f.write(result.to_json())
    print("written " + yuv_temp_path)
    # if out_fmt == 'xml':
    #     print(result.to_xml())
    # elif out_fmt == 'json':
    #     print(result.to_json())
    # else: # None or 'text'
    #     print(str(result))

    # local explanation
    if show_local_explanation:
        runner.show_local_explanations([result])

        if save_plot_dir is None:
            DisplayConfig.show()
        else:
            DisplayConfig.show(write_to_dir=save_plot_dir)

    return 0
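Because most CLI options are pinned to None in this adaptation, a call reduces to the positional arguments. A sketch with hypothetical paths, following the output naming logic above:

ret = vmaf_json('yuv420p', 1920, 1080,
                'ref_1920x1080.yuv', 'dis_1920x1080.yuv',
                json_save_path='/tmp/vmaf_json')
# writes /tmp/vmaf_json/dis_1920x1080.json and returns 0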
Example #21
def run_test_on_dataset(test_dataset,
                        runner_class,
                        ax,
                        result_store,
                        model_filepath,
                        parallelize=True,
                        fifo_mode=True,
                        aggregate_method=np.mean,
                        type='regressor',
                        **kwargs):

    test_assets = read_dataset(test_dataset, **kwargs)
    test_raw_assets = None
    try:
        for test_asset in test_assets:
            assert test_asset.groundtruth is not None
    except AssertionError:
        # no groundtruth, try to do subjective modeling
        from sureal.dataset_reader import RawDatasetReader
        from sureal.subjective_model import DmosModel
        subj_model_class = kwargs[
            'subj_model_class'] if 'subj_model_class' in kwargs and kwargs[
                'subj_model_class'] is not None else DmosModel
        dataset_reader_class = kwargs[
            'dataset_reader_class'] if 'dataset_reader_class' in kwargs else RawDatasetReader
        subjective_model = subj_model_class(dataset_reader_class(test_dataset))
        subjective_model.run_modeling(**kwargs)
        test_dataset_aggregate = subjective_model.to_aggregated_dataset(
            **kwargs)
        test_raw_assets = test_assets
        test_assets = read_dataset(test_dataset_aggregate, **kwargs)

    if model_filepath is not None:
        optional_dict = {'model_filepath': model_filepath}
        if 'model_720_filepath' in kwargs and kwargs[
                'model_720_filepath'] is not None:
            optional_dict['720model_filepath'] = kwargs['model_720_filepath']
        if 'model_480_filepath' in kwargs and kwargs[
                'model_480_filepath'] is not None:
            optional_dict['480model_filepath'] = kwargs['model_480_filepath']
    else:
        optional_dict = None

    if 'enable_transform_score' in kwargs and kwargs[
            'enable_transform_score'] is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['enable_transform_score'] = kwargs[
            'enable_transform_score']

    if 'disable_clip_score' in kwargs and kwargs[
            'disable_clip_score'] is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['disable_clip_score'] = kwargs['disable_clip_score']

    if 'subsample' in kwargs and kwargs['subsample'] is not None:
        if not optional_dict:
            optional_dict = {}
        optional_dict['subsample'] = kwargs['subsample']

    if 'additional_optional_dict' in kwargs and kwargs[
            'additional_optional_dict'] is not None:
        assert isinstance(kwargs['additional_optional_dict'], dict)
        if not optional_dict:
            optional_dict = {}
        optional_dict.update(kwargs['additional_optional_dict'])

    if 'processes' in kwargs and kwargs['processes'] is not None:
        assert isinstance(kwargs['processes'], int)
        processes = kwargs['processes']
    else:
        processes = None
    if processes is not None:
        assert parallelize is True

    # run
    runner = runner_class(
        test_assets,
        None,
        fifo_mode=fifo_mode,
        delete_workdir=True,
        result_store=result_store,
        optional_dict=optional_dict,
        optional_dict2=None,
    )
    runner.run(parallelize=parallelize, processes=processes)
    results = runner.results
    list_score_key = results[0].get_ordered_list_score_key()
    COLUMNS = [
        'dataset', 'content_id', 'asset_id', 'ref_name', 'dis_name', 'width',
        'height', 'DMOS'
    ]
    COLUMNS.extend(list_score_key)

    df = pd.DataFrame(columns=tuple(COLUMNS))
    for result in results:
        result.set_score_aggregate_method(aggregate_method)
        rows = []
        # result[runner_class.get_score_key()]
        list_score_key = result.get_ordered_list_score_key()
        list_aggregate_score = list(
            map(lambda key: result[key], list_score_key))
        total_score = list_aggregate_score[-1]
        #width,height=result.asset.quality_width_height()
        #for score_key, score in zip(list_score_key[:-1], list_aggregate_score[:-1]):
        row = [
            result.asset.dataset, result.asset.content_id,
            result.asset.asset_id,
            get_file_name_with_extension(result.asset.ref_path),
            get_file_name_with_extension(result.asset.dis_path),
            result.asset.asset_dict['dis_width'],
            result.asset.asset_dict['dis_height'], result.asset.groundtruth
        ]
        row.extend(list_aggregate_score)
        #rows.append(row)
        df2 = pd.DataFrame([row], columns=COLUMNS)
        df = pd.concat([df, df2])  # DataFrame.append was removed in pandas 2.0
    df.to_csv('/Users/jessica/CMT309/Project-VMAF/vmaf-master/result_4k.csv')

    try:
        model_type = runner.get_train_test_model_class()
    except:
        if type == 'regressor':
            model_type = RegressorMixin
        elif type == 'classifier':
            model_type = ClassifierMixin
        else:
            assert False

    split_test_indices_for_perf_ci = kwargs['split_test_indices_for_perf_ci'] \
        if 'split_test_indices_for_perf_ci' in kwargs else False

    # plot
    groundtruths = list(map(lambda asset: asset.groundtruth, test_assets))
    predictions = list(
        map(lambda result: result[runner_class.get_score_key()], results))
    raw_grountruths = None if test_raw_assets is None else \
        list(map(lambda asset: asset.raw_groundtruth, test_raw_assets))
    groundtruths_std = None if test_assets is None else \
        list(map(lambda asset: asset.groundtruth_std, test_assets))
    try:
        predictions_bagging = list(
            map(lambda result: result[runner_class.get_bagging_score_key()],
                results))
        predictions_stddev = list(
            map(lambda result: result[runner_class.get_stddev_score_key()],
                results))
        predictions_ci95_low = list(
            map(lambda result: result[runner_class.get_ci95_low_score_key()],
                results))
        predictions_ci95_high = list(
            map(lambda result: result[runner_class.get_ci95_high_score_key()],
                results))
        predictions_all_models = list(
            map(lambda result: result[runner_class.get_all_models_score_key()],
                results))

        # need to transpose the list of lists, so that the outer list has the predictions for each model separately
        predictions_all_models = np.array(predictions_all_models).T.tolist()
        num_models = np.shape(predictions_all_models)[0]

        stats = model_type.get_stats(
            groundtruths,
            predictions,
            ys_label_raw=raw_grountruths,
            ys_label_pred_bagging=predictions_bagging,
            ys_label_pred_stddev=predictions_stddev,
            ys_label_pred_ci95_low=predictions_ci95_low,
            ys_label_pred_ci95_high=predictions_ci95_high,
            ys_label_pred_all_models=predictions_all_models,
            ys_label_stddev=groundtruths_std,
            split_test_indices_for_perf_ci=split_test_indices_for_perf_ci)
    except Exception as e:
        print(
            'Warning: stats calculation failed, fall back to default stats calculation: {}'
            .format(e))
        stats = model_type.get_stats(
            groundtruths,
            predictions,
            ys_label_raw=raw_grountruths,
            ys_label_stddev=groundtruths_std,
            split_test_indices_for_perf_ci=split_test_indices_for_perf_ci)
        num_models = 1

    print('Stats on testing data: {}'.format(
        model_type.format_stats_for_print(stats)))

    # printing stats if multiple models are present
    if 'SRCC_across_model_distribution' in stats \
            and 'PCC_across_model_distribution' in stats \
            and 'RMSE_across_model_distribution' in stats:
        print(
            'Stats on testing data (across multiple models, using all test indices): {}'
            .format(
                model_type.format_across_model_stats_for_print(
                    model_type.extract_across_model_stats(stats))))

    if split_test_indices_for_perf_ci:
        print('Stats on testing data (single model, multiple test sets): {}'.
              format(
                  model_type.format_stats_across_test_splits_for_print(
                      model_type.extract_across_test_splits_stats(stats))))

    if ax is not None:
        content_ids = list(map(lambda asset: asset.content_id, test_assets))

        if 'point_label' in kwargs:
            if kwargs['point_label'] == 'asset_id':
                point_labels = list(
                    map(lambda asset: asset.asset_id, test_assets))
            elif kwargs['point_label'] == 'dis_path':
                point_labels = list(
                    map(
                        lambda asset: get_file_name_without_extension(
                            asset.dis_path), test_assets))
            else:
                raise AssertionError("Unknown point_label {}".format(
                    kwargs['point_label']))
        else:
            point_labels = None

        model_type.plot_scatter(ax,
                                stats,
                                content_ids=content_ids,
                                point_labels=point_labels,
                                **kwargs)
        ax.set_xlabel('True Score')
        ax.set_ylabel("Predicted Score")
        ax.grid()
        ax.set_title("{runner}{num_models}\n{stats}".format(
            dataset=test_assets[0].dataset,
            runner=runner_class.TYPE,
            stats=model_type.format_stats_for_plot(stats),
            num_models=", {} models".format(num_models)
            if num_models > 1 else "",
        ))

    return test_assets, results