コード例 #1
0
ファイル: routine_test.py プロジェクト: Netflix/vmaf
 def test_train_test_on_raw_dataset_with_dis1st_thr(self):
     """Train and test VMAF on the raw-dataset sample and check predictions.

     The same sample dataset is used as both training and test set, so the
     train and test predictions are checked against the same reference
     values. Predictions come out of a model-training pipeline, so they are
     compared with a tolerance (``places=3``) instead of exact float
     equality, which is brittle across platforms and library versions.
     """
     train_dataset = import_python_file(
         config.ROOT + '/python/test/resource/raw_dataset_sample.py')
     model_param = import_python_file(
         config.ROOT + '/python/test/resource/model_param_sample.py')
     feature_param = import_python_file(
         config.ROOT + '/python/test/resource/feature_param_sample.py')
     (train_fassembler, _, train_stats,
      _, _, test_stats, _) = train_test_vmaf_on_dataset(
         train_dataset=train_dataset, test_dataset=train_dataset,
         feature_param=feature_param, model_param=model_param,
         train_ax=None, test_ax=None, result_store=None,
         parallelize=True,
         logger=None,
         fifo_mode=True,
         output_model_filepath=self.output_model_filepath)
     # Stored on the instance; presumably used by tearDown for cleanup
     # (TODO confirm against the enclosing test class).
     self.train_fassembler = train_fassembler
     self.assertTrue(os.path.exists(self.output_model_filepath))

     expected_preds = sorted([93.565459224020742, 60.451618249440827,
                              93.565460383297108, 92.417462071278933])
     for stats in (train_stats, test_stats):
         # Sort both sides so the check stays order-insensitive, like the
         # assertItemsEqual it replaces, while tolerating float noise.
         preds = sorted(stats['ys_label_pred'])
         self.assertEqual(len(preds), len(expected_preds))
         for pred, expected in zip(preds, expected_preds):
             self.assertAlmostEqual(pred, expected, places=3)
コード例 #2
0
ファイル: routine_test.py プロジェクト: Netflix/vmaf
 def test_train_test_on_dataset_with_dis1st_thr(self):
     """Train and test VMAF on the sample dataset and check predictions.

     The same sample dataset is used as both training and test set, so the
     train and test predictions are checked against the same reference
     values. Predictions come out of a model-training pipeline, so they are
     compared with a tolerance (``places=3``) instead of exact float
     equality, which is brittle across platforms and library versions.
     """
     train_dataset = import_python_file(
         config.ROOT + '/python/test/resource/dataset_sample.py')
     model_param = import_python_file(
         config.ROOT + '/python/test/resource/model_param_sample.py')
     feature_param = import_python_file(
         config.ROOT + '/python/test/resource/feature_param_sample.py')
     (train_fassembler, _, train_stats,
      _, _, test_stats, _) = train_test_vmaf_on_dataset(
         train_dataset=train_dataset, test_dataset=train_dataset,
         feature_param=feature_param, model_param=model_param,
         train_ax=None, test_ax=None, result_store=None,
         parallelize=True,
         logger=None,
         fifo_mode=True,
         output_model_filepath=self.output_model_filepath,
     )
     # Stored on the instance; presumably used by tearDown for cleanup
     # (TODO confirm against the enclosing test class).
     self.train_fassembler = train_fassembler
     self.assertTrue(os.path.exists(self.output_model_filepath))

     expected_preds = sorted([90.753010402770798, 59.223801498461015,
                              90.753011435798058, 89.270176556597008])
     for stats in (train_stats, test_stats):
         # Sort both sides so the check stays order-insensitive, like the
         # assertItemsEqual it replaces, while tolerating float noise.
         preds = sorted(stats['ys_label_pred'])
         self.assertEqual(len(preds), len(expected_preds))
         for pred, expected in zip(preds, expected_preds):
             self.assertAlmostEqual(pred, expected, places=3)
コード例 #3
0
 def test_train_test_on_dataset_with_dis1st_thr(self):
     """Train and test VMAF on the sample dataset and spot-check a score.

     The sample dataset serves as both training and test set. The first
     predicted score on each side must match the reference to 3 places.
     """
     resource_dir = config.ROOT + '/python/test/resource'
     dataset = import_python_file(resource_dir + '/dataset_sample.py')
     model_param = import_python_file(resource_dir + '/model_param_sample.py')
     feature_param = import_python_file(resource_dir + '/feature_param_sample.py')

     results = train_test_vmaf_on_dataset(
         train_dataset=dataset,
         test_dataset=dataset,
         feature_param=feature_param,
         model_param=model_param,
         train_ax=None,
         test_ax=None,
         result_store=None,
         parallelize=True,
         logger=None,
         fifo_mode=True,
         output_model_filepath=self.output_model_filepath,
     )
     fassembler, _, stats_on_train, _, _, stats_on_test, _ = results
     # Stored on the instance; presumably used by tearDown for cleanup
     # (TODO confirm against the enclosing test class).
     self.train_fassembler = fassembler
     self.assertTrue(os.path.exists(self.output_model_filepath))

     reference_score = 90.753010402770798
     for stats in (stats_on_train, stats_on_test):
         self.assertAlmostEqual(stats['ys_label_pred'][0], reference_score,
                                places=3)
コード例 #4
0
 def test_train_test_on_raw_dataset_with_dis1st_thr(self):
     """Train and test VMAF on the raw-dataset sample and check predictions.

     The same sample dataset is used as both training and test set, so the
     train and test predictions are checked against the same reference
     values. Predictions come out of a model-training pipeline, so they are
     compared with a tolerance (``places=3``) instead of exact float
     equality, which is brittle across platforms and library versions.
     """
     train_dataset = import_python_file(
         config.ROOT + '/python/test/resource/raw_dataset_sample.py')
     model_param = import_python_file(
         config.ROOT + '/python/test/resource/model_param_sample.py')
     feature_param = import_python_file(
         config.ROOT + '/python/test/resource/feature_param_sample.py')
     (train_fassembler, _, train_stats,
      _, _, test_stats, _) = train_test_vmaf_on_dataset(
         train_dataset=train_dataset, test_dataset=train_dataset,
         feature_param=feature_param, model_param=model_param,
         train_ax=None, test_ax=None, result_store=None,
         parallelize=True,
         logger=None,
         fifo_mode=True,
         output_model_filepath=self.output_model_filepath)
     # Stored on the instance; presumably used by tearDown for cleanup
     # (TODO confirm against the enclosing test class).
     self.train_fassembler = train_fassembler
     self.assertTrue(os.path.exists(self.output_model_filepath))

     expected_preds = sorted([93.565459224020742, 60.451618249440827,
                              93.565460383297108, 92.417462071278933])
     for stats in (train_stats, test_stats):
         # Sort both sides so the check stays order-insensitive, like the
         # assertItemsEqual it replaces, while tolerating float noise.
         preds = sorted(stats['ys_label_pred'])
         self.assertEqual(len(preds), len(expected_preds))
         for pred, expected in zip(preds, expected_preds):
             self.assertAlmostEqual(pred, expected, places=3)
コード例 #5
0
ファイル: routine_test.py プロジェクト: yuhjay/vmaf
 def test_train_test_on_dataset_with_dis1st_thr(self):
     """Train VMAF on ``self.train_dataset`` and check the training predictions.

     No test dataset is passed, so only the training-side predictions are
     checked. They come out of a model-training pipeline, so they are
     compared with a tolerance (``places=3``) instead of exact float
     equality, which is brittle across platforms and library versions.
     """
     model_param = import_python_file(
         config.ROOT + '/python/test/resource/model_param_sample.py')
     feature_param = import_python_file(
         config.ROOT + '/python/test/resource/feature_param_sample.py')
     (train_fassembler, _, train_stats,
      _, _, _, _) = train_test_vmaf_on_dataset(
         train_dataset=self.train_dataset, test_dataset=None,
         feature_param=feature_param, model_param=model_param,
         train_ax=None, test_ax=None, result_store=None,
         parallelize=True,
         logger=None,
         fifo_mode=True,
         output_model_filepath=self.output_model_filepath,
     )
     # Stored on the instance; presumably used by tearDown for cleanup
     # (TODO confirm against the enclosing test class).
     self.train_fassembler = train_fassembler
     self.assertTrue(os.path.exists(self.output_model_filepath))

     expected_preds = sorted([90.753010402770798, 59.223801498461015,
                              90.753011435798058, 89.270176556597008])
     # Sort both sides so the check stays order-insensitive, like the
     # assertItemsEqual it replaces, while tolerating float noise.
     preds = sorted(train_stats['ys_label_pred'])
     self.assertEqual(len(preds), len(expected_preds))
     for pred, expected in zip(preds, expected_preds):
         self.assertAlmostEqual(pred, expected, places=3)
コード例 #6
0
def main():
    """Train a VMAF model from command-line arguments and save it.

    Expected command line (positional arguments first, then options)::

        <train_dataset> <feature_param> <model_param> <output_model>
            [--cache-result] [--parallelize] [--pool <method>]

    Returns:
        Process exit code: 0 on success, 1 when an input file cannot be
        imported, 2 on a usage error.

    When matplotlib is importable, the training-set plot is shown after
    training. Otherwise a warning is printed and training runs without
    plotting. Training always runs exactly once.
    """
    if len(sys.argv) < 5:
        print_usage()
        return 2

    # Safe: the length check above guarantees argv[1..4] exist.
    (train_dataset_filepath, feature_param_filepath,
     model_param_filepath, output_model_filepath) = sys.argv[1:5]

    try:
        train_dataset = import_python_file(train_dataset_filepath)
        feature_param = import_python_file(feature_param_filepath)
        model_param = import_python_file(model_param_filepath)
    except Exception as e:
        print("Error: " + str(e))
        return 1

    # Options can only follow the four positional arguments, so scan from
    # argv[5]. Starting earlier would let a file path be read as an option.
    options_begin = 5
    argc = len(sys.argv)
    cache_result = cmd_option_exists(sys.argv, options_begin, argc, '--cache-result')
    parallelize = cmd_option_exists(sys.argv, options_begin, argc, '--parallelize')

    pool_method = get_cmd_option(sys.argv, options_begin, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    result_store = FileSystemResultStore() if cache_result else None

    # Score pooling. None, 'mean' or any unlisted method uses the arithmetic mean.
    aggregate_methods = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }
    aggregate_method = aggregate_methods.get(pool_method, np.mean)

    logger = None

    # Keep the try narrow: only a failure to set up matplotlib should fall
    # back to unplotted training. Previously the try also wrapped training,
    # so an ImportError raised during training was misreported as "matplotlib
    # missing" and training ran a second time.
    try:
        import matplotlib.pyplot as plt
        _, ax = plt.subplots(figsize=(5, 5), nrows=1, ncols=1)
    except ImportError:
        print_matplotlib_warning()
        plt, ax = None, None

    train_test_vmaf_on_dataset(train_dataset=train_dataset, test_dataset=None,
                               feature_param=feature_param, model_param=model_param,
                               train_ax=ax, test_ax=None,
                               result_store=result_store,
                               parallelize=parallelize,
                               logger=logger,
                               output_model_filepath=output_model_filepath,
                               aggregate_method=aggregate_method)

    if ax is not None:
        bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 20}
        ax.annotate('Training Set', xy=(0.1, 0.85), xycoords='axes fraction', bbox=bbox)
        plt.tight_layout()
        plt.show()

    return 0
コード例 #7
0
ファイル: run_vmaf_training.py プロジェクト: Netflix/vmaf
def main():
    """Train a VMAF model from command-line arguments and save it.

    Expected command line (positional arguments first, then options)::

        <train_dataset> <feature_param> <model_param> <output_model>
            [--cache-result] [--parallelize] [--pool <method>]
            [--subj-model <name>]

    Returns:
        Process exit code: 0 on success, 1 when an input file or the
        requested subjective model cannot be loaded, 2 on a usage error.

    When matplotlib is importable, the training-set plot is shown after
    training. Otherwise a warning is printed and training runs without
    plotting. Training always runs exactly once.
    """
    if len(sys.argv) < 5:
        print_usage()
        return 2

    # Safe: the length check above guarantees argv[1..4] exist.
    (train_dataset_filepath, feature_param_filepath,
     model_param_filepath, output_model_filepath) = sys.argv[1:5]

    try:
        train_dataset = import_python_file(train_dataset_filepath)
        feature_param = import_python_file(feature_param_filepath)
        model_param = import_python_file(model_param_filepath)
    except Exception as e:
        print("Error: " + str(e))
        return 1

    # Options can only follow the four positional arguments, so scan from
    # argv[5]. Starting earlier would let a file path be read as an option.
    options_begin = 5
    argc = len(sys.argv)
    cache_result = cmd_option_exists(sys.argv, options_begin, argc, '--cache-result')
    parallelize = cmd_option_exists(sys.argv, options_begin, argc, '--parallelize')

    pool_method = get_cmd_option(sys.argv, options_begin, argc, '--pool')
    if not (pool_method is None or pool_method in POOL_METHODS):
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    subj_model = get_cmd_option(sys.argv, options_begin, argc, '--subj-model')
    subj_model_class = None
    if subj_model is not None:
        try:
            subj_model_class = SubjectiveModel.find_subclass(subj_model)
        except Exception as e:
            print("Error: " + str(e))
            return 1

    result_store = FileSystemResultStore() if cache_result else None

    # Score pooling. None, 'mean' or any unlisted method uses the arithmetic mean.
    aggregate_methods = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }
    aggregate_method = aggregate_methods.get(pool_method, np.mean)

    logger = None

    # Keep the try narrow: only a failure to set up matplotlib should fall
    # back to unplotted training. Previously the try also wrapped training,
    # so an ImportError raised during training was misreported as "matplotlib
    # missing" and training ran a second time.
    try:
        import matplotlib.pyplot as plt
        _, ax = plt.subplots(figsize=(5, 5), nrows=1, ncols=1)
    except ImportError:
        print_matplotlib_warning()
        plt, ax = None, None

    train_test_vmaf_on_dataset(train_dataset=train_dataset, test_dataset=None,
                               feature_param=feature_param, model_param=model_param,
                               train_ax=ax, test_ax=None,
                               result_store=result_store,
                               parallelize=parallelize,
                               logger=logger,
                               output_model_filepath=output_model_filepath,
                               aggregate_method=aggregate_method,
                               subj_model_class=subj_model_class,
                               )

    if ax is not None:
        bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 20}
        ax.annotate('Training Set', xy=(0.1, 0.85), xycoords='axes fraction', bbox=bbox)
        plt.tight_layout()
        plt.show()

    return 0