예제 #1
0
    def test_add_metrics_file(self):
        """Check that metrics passed through a CSV file end up in the model spec.

        The dataset example spec is copied into the temp dir and its
        dataset section is re-keyed as a model section, so the file can be
        used with a repository created with ``repo_type=MODELS``.
        ``add_metrics`` is then called with an empty tuple of inline
        metrics and the path of a metrics CSV file.
        """
        objects_dir = os.path.join(self.tmp_dir, 'objectsfs')
        config = yaml_load('hdata/config.yaml')
        repository = LocalRepository(config, objects_dir, repo_type=MODELS)

        # Turn the dataset example into a model spec on disk.
        spec_path = os.path.join(self.tmp_dir, 'model-ex.spec')
        shutil.copy('hdata/dataset-ex.spec', spec_path)
        spec_content = yaml_load(spec_path)
        model_section = spec_content[DATASET_SPEC_KEY].copy()
        del spec_content[DATASET_SPEC_KEY]
        spec_content[MODEL_SPEC_KEY] = model_section
        yaml_save(spec_content, spec_path)

        # create_csv_file is a helper on the test class; presumably it
        # writes the mapping as CSV content -- confirm against the helper.
        csv_path = os.path.join(self.tmp_dir, 'metrics.csv')
        csv_values = {'metric_a': 10, 'metric_b': 9}
        self.create_csv_file(csv_path, csv_values)

        repository.add_metrics(spec_path, (), csv_path)

        # Reload the spec from disk and compare the stored values.
        reloaded_spec = yaml_load(spec_path)
        self.assertEqual(
            reloaded_spec[MODEL_SPEC_KEY]['metrics'].get('metric_a', ''),
            10.0)
        self.assertEqual(
            reloaded_spec[MODEL_SPEC_KEY]['metrics'].get('metric_b', ''),
            9.0)
예제 #2
0
 def test_add_metrics_wrong_entity(self):
     """Check that add_metrics leaves a dataset spec without a 'metrics' key.

     The repository is built without an explicit ``repo_type`` and the
     spec keeps its dataset section, so the inline metrics passed to
     ``add_metrics`` are expected not to be written to the spec file.
     """
     hashfs_path = os.path.join(self.tmp_dir, 'objectsfs')
     test_config = yaml_load('hdata/config.yaml')
     local_repo = LocalRepository(test_config, hashfs_path)
     spec_path = os.path.join(self.tmp_dir, 'dataset-ex.spec')
     shutil.copy('hdata/dataset-ex.spec', spec_path)
     local_repo.add_metrics(spec_path,
                            (('metric_a', '10'), ('metric_b', '9')), None)
     test_spec_file = yaml_load(spec_path)
     # assertNotIn reports the offending container on failure, unlike
     # assertFalse(x in y) which only says "True is not false".
     self.assertNotIn('metrics', test_spec_file[DATASET_SPEC_KEY])
예제 #3
0
    def test_add_metrics_with_none_metrics_options(self):
        """Check that add_metrics adds nothing when no metrics are supplied.

        A model spec is derived from the dataset example, then
        ``add_metrics`` is called with an empty tuple of inline metrics and
        ``None`` as the metrics file. The saved spec must not gain a
        'metrics' key.
        """
        hashfs_path = os.path.join(self.tmp_dir, 'objectsfs')
        test_config = yaml_load('hdata/config.yaml')
        local_repo = LocalRepository(test_config,
                                     hashfs_path,
                                     repo_type=MODELS)
        spec_path = os.path.join(self.tmp_dir, 'model-ex.spec')
        shutil.copy('hdata/dataset-ex.spec', spec_path)

        # Re-key the dataset section as a model section and persist it.
        spec_file = yaml_load(spec_path)
        model = spec_file[DATASET_SPEC_KEY].copy()
        del spec_file[DATASET_SPEC_KEY]
        spec_file[MODEL_SPEC_KEY] = model
        yaml_save(spec_file, spec_path)

        local_repo.add_metrics(spec_path, (), None)

        test_spec_file = yaml_load(spec_path)
        # assertNotIn gives a descriptive failure message, unlike
        # assertFalse(x in y).
        self.assertNotIn('metrics', test_spec_file[MODEL_SPEC_KEY])
예제 #4
0
    def test_add_metrics(self):
        """Check that inline metrics are stored in the model spec as numbers.

        A model spec is derived from the dataset example, then
        ``add_metrics`` is called with (name, value) pairs whose values are
        strings and with ``None`` as the metrics file. The reloaded spec is
        expected to hold the values 10.0 and 9.0 under 'metrics'.
        """
        hashfs_path = os.path.join(self.tmp_dir, 'objectsfs')
        test_config = yaml_load('hdata/config.yaml')
        local_repo = LocalRepository(test_config,
                                     hashfs_path,
                                     repo_type=MODELS)
        spec_path = os.path.join(self.tmp_dir, 'model-ex.spec')
        shutil.copy('hdata/dataset-ex.spec', spec_path)

        # Re-key the dataset section as a model section and persist it.
        spec_file = yaml_load(spec_path)
        model = spec_file[DATASET_SPEC_KEY].copy()
        del spec_file[DATASET_SPEC_KEY]
        spec_file[MODEL_SPEC_KEY] = model
        yaml_save(spec_file, spec_path)

        local_repo.add_metrics(spec_path,
                               (('metric_a', '10'), ('metric_b', '9')), None)

        test_spec_file = yaml_load(spec_path)
        # assertEqual shows both values on failure; assertTrue(a == b)
        # would only report "False is not true".
        self.assertEqual(
            test_spec_file[MODEL_SPEC_KEY]['metrics'].get('metric_a', ''),
            10.0)
        self.assertEqual(
            test_spec_file[MODEL_SPEC_KEY]['metrics'].get('metric_b', ''),
            9.0)