Example #1
0
def test_csv_data_relative(tmp_path):
    """A tabular source with a 'relative' root should resolve against the relative path."""
    data_dir, file_names = make_sample_data(tmp_path)

    config = make_basic_config()
    config["dataType"] = "tabular"
    config.update({
        "tabularData": {
            "sources": [{"root": "relative", "path": data_dir / file_names[0]}],
            "labelIndex": 20,
        },
    })

    ds = DatasetConfig(config, Path(tmp_path))
    assert ds.is_tabular_type()

    resolved, label_idx = ds.get_resolved_tabular_source_paths_and_labels()

    # Exactly one source resolved beneath the relative root.
    assert len(resolved) == 1
    assert resolved[0] == Path(tmp_path) / data_dir / file_names[0]
    assert label_idx == 20
Example #2
0
def test_csv_data_workspace(tmp_path):
    """Tabular sources with a 'workspace' root should resolve against juneberry.WORKSPACE_ROOT."""
    config = make_basic_config()
    data_dir, file_names = make_sample_data(tmp_path)

    config["dataType"] = "tabular"
    config.update({
        "tabularData": {
            # One workspace-rooted source per sample file.
            "sources": [{"root": "workspace", "path": data_dir / file_names[0]},
                        {"root": "workspace", "path": data_dir / file_names[1]}],
            "labelIndex": 20,
        },
    })

    juneberry.WORKSPACE_ROOT = tmp_path
    ds = DatasetConfig(config, Path('.'))
    assert ds.is_tabular_type()

    resolved, label_idx = ds.get_resolved_tabular_source_paths_and_labels()

    assert len(resolved) == len(file_names)
    for idx, name in enumerate(file_names):
        assert resolved[idx] == Path(juneberry.WORKSPACE_ROOT) / data_dir / name

    assert label_idx == 20
Example #3
0
def test_sampling():
    """A 'sampling' stanza should be surfaced via has_sampling() and the sampling attribute."""
    config = make_basic_config()
    config.update({
        "sampling": {
            "algorithm": "foo",
            "args": "none"
        },
    })

    ds = DatasetConfig(config, Path('.'))
    assert ds.has_sampling()
    assert ds.sampling['algorithm'] == "foo"
    assert ds.sampling['args'] == "none"
Example #4
0
def test_generate_image_validation_split():
    """A randomFraction validation split should carve a third of the images into val_list."""
    # Stub out directory listing so no real files are required.
    os.listdir = mock_list_image_dir

    data_set_config = DatasetConfig(make_basic_data_set_image_config())

    train_struct = test_training_config.make_basic_config()
    train_struct['validation'] = {
        "algorithm": "randomFraction",
        "arguments": {"seed": 1234, "fraction": 0.3333333},
    }
    train_config = TrainingConfig('', train_struct)
    dm = jbfs.DataManager({})

    train_list, val_list = jb_data.generate_image_list('data_root',
                                                       data_set_config,
                                                       train_config, dm)
    assert len(train_list) == 8
    assert len(val_list) == 4

    # NOTE: Another fragile secret we rely on — the validation order is reversed.
    assert_correct_list(train_list, [1, 2, 4, 5], [1, 2, 3, 4])
    assert_correct_list(val_list, [3, 0], [5, 0])
Example #5
0
def test_load_tabular_data_with_sampling(tmp_path):
    """Loading tabular data with a randomQuantity sample should keep only 'count' rows per class."""
    correct = fill_tabular_tempdir(tmp_path)
    juneberry.WORKSPACE_ROOT = Path(tmp_path) / 'myworkspace'
    juneberry.DATA_ROOT = Path(tmp_path) / 'mydataroot'

    # One sampling case is enough here; the sampling core is tested elsewhere.
    struct = make_basic_data_set_tabular_config()
    struct.update(make_sample_stanza("randomQuantity", {'seed': 1234, 'count': 3}))
    data_set_config = DatasetConfig(struct, Path(tmp_path) / 'myrelative')

    train_list, val_list = jb_data.load_tabular_data(None, data_set_config)

    # The sample data is three files, each with 4 rows, 2 per class; after
    # sampling we expect 6 training rows and no validation rows.
    assert len(train_list) == 6
    assert len(val_list) == 0

    # Every returned row must match the fixture; consume entries as they are seen.
    for data, label in train_list:
        assert correct[int(label)][int(data[0])] == int(data[1])
        del correct[int(label)][int(data[0])]

    # Three entries of each class should remain unused.
    assert len(correct[0]) == 3
    assert len(correct[1]) == 3
Example #6
0
def test_generate_image_sample_fraction():
    """Sampling by randomFraction should keep a third of the images in a known order."""
    # The internal randomizer applies random.sample to both sets in order.
    # That is a secret and fragile to this test: with seed 1234 and two pulls
    # of a count-2 sample it selects [3, 0] and [0, 5].
    os.listdir = mock_list_image_dir

    struct = make_basic_data_set_image_config()
    struct.update(
        make_sample_stanza("randomFraction", {
            'seed': 1234,
            'fraction': 0.3333333333
        }))

    data_set_config = DatasetConfig(struct)
    dm = jbfs.DataManager({})

    train_list, val_list = jb_data.generate_image_list('data_root',
                                                       data_set_config, None,
                                                       dm)
    assert len(train_list) == 4
    assert len(val_list) == 0

    # The sampled items must appear in exactly this order.
    assert_correct_list(train_list, [3, 0], [0, 5])
Example #7
0
def test_load_tabular_data_with_validation(tmp_path):
    """With a training config present, the default validation split divides the rows evenly."""
    correct = fill_tabular_tempdir(tmp_path)
    juneberry.WORKSPACE_ROOT = Path(tmp_path) / 'myworkspace'
    juneberry.DATA_ROOT = Path(tmp_path) / 'mydataroot'

    data_set_config = DatasetConfig(make_basic_data_set_tabular_config(),
                                    Path(tmp_path) / 'myrelative')
    train_config = TrainingConfig('', test_training_config.make_basic_config())

    train_list, val_list = jb_data.load_tabular_data(train_config,
                                                     data_set_config)

    # The sample data is three files, each with 4 rows and 2 per class.
    # The default validation split is 2, so 3 * 4 / 2 = 6 rows per list.
    assert len(train_list) == 6
    assert len(val_list) == 6

    # Each training row must match the fixture; consume entries as they are found.
    for data, label in train_list:
        assert correct[int(label)][int(data[0])] == int(data[1])
        del correct[int(label)][int(data[0])]

    # Half of each class should remain for the validation rows.
    assert len(correct[0]) == 3
    assert len(correct[1]) == 3

    # The validation rows must account for everything that is left.
    for data, label in val_list:
        assert correct[int(label)][int(data[0])] == int(data[1])
        del correct[int(label)][int(data[0])]

    assert len(correct[0]) == 0
    assert len(correct[1]) == 0
Example #8
0
    def test_csv_bad_label_index(self):
        """Removing the required labelIndex key should log an ERROR and exit."""
        config = make_basic_config(False)
        del config['tabularData']['labelIndex']

        with self.assertRaises(SystemExit), self.assertLogs(
                level='ERROR') as log:
            DatasetConfig(config, Path('.'))
        self.assert_error(log, "labelIndex")
Example #9
0
    def test_csv_data_missing_path(self):
        """A tabular source without a 'path' key should log an ERROR and exit."""
        config = make_basic_config(False)
        config['tabularData']['sources'] = [{"root": "dataroot"}]

        with self.assertRaises(SystemExit), self.assertLogs(
                level='ERROR') as log:
            DatasetConfig(config, Path('.'))
        self.assert_error(log, "path")
Example #10
0
    def test_image_data_missing_label(self):
        """An image source without a 'label' key should log an ERROR and exit."""
        config = make_basic_config()
        config['imageData']['sources'] = {"directory": "some/path"}

        with self.assertRaises(SystemExit), self.assertLogs(
                level='ERROR') as log:
            DatasetConfig(config, Path('.'))
        self.assert_error(log, "label")
Example #11
0
def test_image_data():
    """An imageData stanza should make the config report image type with one source."""
    config = make_basic_config()
    config.update({
        "imageData": {
            "taskType": "classification",
            "sources": [{"directory": "some/path", "label": 4}],
        },
    })

    ds = DatasetConfig(config, Path('.'))
    assert ds.is_image_type()
    assert len(ds.get_image_sources()) == 1
Example #12
0
def test_csv_glob(tmp_path):
    """Tabular source paths may contain glob patterns that expand under the data root."""
    # Use tmp_path as the dataroot and drop in some sample files.
    config = make_basic_config()
    data_dir, file_names = make_sample_data(tmp_path)

    config["dataType"] = "tabular"
    config.update({
        "tabularData": {
            "sources": [{"root": "dataroot", "path": data_dir / "file_*.txt"}],
            "labelIndex": 42,
        },
    })

    # Construct the object and inspect what it resolved.
    juneberry.DATA_ROOT = tmp_path
    ds = DatasetConfig(config, Path('.'))
    assert ds.is_tabular_type()

    resolved, label_idx = ds.get_resolved_tabular_source_paths_and_labels()

    assert len(resolved) == len(file_names)

    # Glob expansion order is unspecified, so only check membership.
    for file_name in file_names:
        assert Path(tmp_path) / data_dir / file_name in resolved

    assert label_idx == 42
Example #13
0
def test_dry_run():
    """A dry-run trainer should perform only setup plus the dry-run pass."""
    logging.basicConfig(level=logging.INFO)
    data_set_config = DatasetConfig(test_data_set.make_basic_config())
    training_config = TrainingConfig("simple_model",
                                     test_training_config.make_basic_config())

    trainer = EpochTrainerHarness(training_config,
                                  data_set_config,
                                  dry_run=True)
    trainer.train_model()

    # Setup first, then the dry-run hook; nothing else should have fired.
    assert trainer.setup_calls == [0]
    assert trainer.dry_run_calls == [1]

    print("Done")
Example #14
0
def test_generate_image_list():
    """Without sampling or validation, every mock image lands in the training list."""
    # Replace listdir so the generator sees the mock directory contents.
    os.listdir = mock_list_image_dir

    data_set_config = DatasetConfig(make_basic_data_set_image_config())
    dm = jbfs.DataManager({})

    train_list, val_list = jb_data.generate_image_list('data_root',
                                                       data_set_config, None,
                                                       dm)

    assert len(train_list) == 12
    assert len(val_list) == 0
    assert_correct_list(train_list, range(6), range(6))
Example #15
0
def test_load_tabular_data(tmp_path):
    """Loading tabular data without a training config should yield all rows for training."""
    correct = fill_tabular_tempdir(tmp_path)
    juneberry.WORKSPACE_ROOT = Path(tmp_path) / 'myworkspace'
    juneberry.DATA_ROOT = Path(tmp_path) / 'mydataroot'

    data_set_config = DatasetConfig(make_basic_data_set_tabular_config(),
                                    Path(tmp_path) / 'myrelative')

    train_list, val_list = jb_data.load_tabular_data(None, data_set_config)

    # The sample data is three files, each with 4 rows: all 12 become training rows.
    assert len(train_list) == 12
    assert len(val_list) == 0

    # Make sure every returned value is in the fixture, removing as we go.
    for data, label in train_list:
        assert correct[int(label)][int(data[0])] == int(data[1])
        del correct[int(label)][int(data[0])]
Example #16
0
def test_config_basics():
    """Exercise the simple accessors for image, object-detection, and tabular configs."""
    ds = DatasetConfig(make_basic_config(), Path("."))
    assert ds.is_image_type() is True
    assert ds.data_type == DataType.IMAGE
    assert ds.task_type == TaskType.CLASSIFICATION
    assert ds.num_model_classes == 4
    assert ds.description == "Unit test"
    assert ds.timestamp == "never"
    assert ds.format_version == "3.2.0"

    assert len(ds.label_names) == 2
    assert ds.label_names[0] == "frodo"
    assert ds.label_names[1] == "sam"

    # Image data built with the second flag off reports object detection.
    ds = DatasetConfig(make_basic_config(True, False), Path("."))
    assert ds.data_type == DataType.IMAGE
    assert ds.task_type == TaskType.OBJECTDETECTION

    # A non-image basic config reports tabular data.
    ds = DatasetConfig(make_basic_config(False), Path("."))
    assert ds.data_type == DataType.TABULAR
Example #17
0
def test_epoch_trainer():
    """Verify the full call sequencing of the epoch trainer over three epochs.

    Expected trace (call index, hook):
    >> 0 setup
    >> 1 start_epoch_training
    >> 2 process_batch
    >> 3 update_metrics
    >> 4 update_model
    >> 5 process_batch
    >> 6 update_metrics
    >> 7 update_model
    >> 8 process_batch
    >> 9 update_metrics
    >> 10 update_model
    >> 11 summarize_metrics
    >> 12 start_epoch_evaluation
    >> 13 process_batch
    >> 14 update_metrics
    >> 15 process_batch
    >> 16 update_metrics
    >> 17 summarize_metrics
    >> 18 end_epoch
    >> 19 start_epoch_training
    >> 20 process_batch
    >> 21 update_metrics
    >> 22 update_model
    >> 23 process_batch
    >> 24 update_metrics
    >> 25 update_model
    >> 26 process_batch
    >> 27 update_metrics
    >> 28 update_model
    >> 29 summarize_metrics
    >> 30 start_epoch_evaluation
    >> 31 process_batch
    >> 32 update_metrics
    >> 33 process_batch
    >> 34 update_metrics
    >> 35 summarize_metrics
    >> 36 end_epoch
    >> 37 start_epoch_training
    >> 38 process_batch
    >> 39 update_metrics
    >> 40 update_model
    >> 41 process_batch
    >> 42 update_metrics
    >> 43 update_model
    >> 44 process_batch
    >> 45 update_metrics
    >> 46 update_model
    >> 47 summarize_metrics
    >> 48 start_epoch_evaluation
    >> 49 process_batch
    >> 50 update_metrics
    >> 51 process_batch
    >> 52 update_metrics
    >> 53 summarize_metrics
    >> 54 end_epoch
    >> 55 finalize_results
    """

    print("Starting")
    logging.basicConfig(level=logging.INFO)
    data_set_config = DatasetConfig(test_data_set.make_basic_config())
    training_config = TrainingConfig("simple_model",
                                     test_training_config.make_basic_config())

    trainer = EpochTrainerHarness(training_config, data_set_config)
    trainer.train_model()

    assert trainer.setup_calls == [0]
    assert trainer.start_epoch_training_calls == [1, 19, 37]

    # Check all the recorded call indices; sequencing is important.
    assert trainer.process_batch_calls == [
        2, 5, 8, 13, 15, 20, 23, 26, 31, 33, 38, 41, 44, 49, 51
    ]
    assert trainer.start_epoch_evaluation_calls == [12, 30, 48]
    assert trainer.update_metrics_calls == [
        3, 6, 9, 14, 16, 21, 24, 27, 32, 34, 39, 42, 45, 50, 52
    ]
    assert trainer.update_model_calls == [4, 7, 10, 22, 25, 28, 40, 43, 46]
    assert trainer.summarize_metrics_calls == [11, 17, 29, 35, 47, 53]
    assert trainer.checkpoint_calls == [18, 36, 54]
    assert trainer.finalize_results_calls == [55]
    # close() fires after finalize_results, hence index 56.
    assert trainer.close_calls == [56]

    trainer.timer.log_metrics()

    print("Done")