Example #1
0
def test_generate_image_validation_split():
    """Check that a seeded randomFraction validation split partitions the image list.

    ``os.listdir`` is replaced by ``mock_list_image_dir`` for the duration of the
    test and restored afterwards so the mock cannot leak into other tests.
    """
    original_listdir = os.listdir
    os.listdir = mock_list_image_dir
    try:
        data_set_struct = make_basic_data_set_image_config()
        data_set_config = DatasetConfig(data_set_struct)

        train_struct = test_training_config.make_basic_config()
        train_struct['validation'] = {
            "algorithm": "randomFraction",
            "arguments": {
                "seed": 1234,
                "fraction": 0.3333333
            }
        }
        train_config = TrainingConfig('', train_struct)
        dm = jbfs.DataManager({})

        train_list, val_list = jb_data.generate_image_list('data_root',
                                                           data_set_config,
                                                           train_config, dm)
        assert len(train_list) == 8
        assert len(val_list) == 4

        # NOTE: Another fragile secret we rely on: the order of the validation list is reversed.
        assert_correct_list(train_list, [1, 2, 4, 5], [1, 2, 3, 4])
        assert_correct_list(val_list, [3, 0], [5, 0])
    finally:
        # Always undo the global monkeypatch, even when an assertion fails.
        os.listdir = original_listdir
Example #2
0
def test_generate_image_sample_fraction():
    """Check that a seeded randomFraction sampling stanza keeps only the sampled images.

    If we pass in a sampling fraction we should just get that fraction of each set.
    We know how the internal randomizer works: it uses ``random.sample`` on both
    sets in order. This is a secret and fragile to this test. With a seed of 1234
    and a fraction of 1/3 (two images from each set), it pulls [3, 0] and [0, 5].

    ``os.listdir`` is replaced by ``mock_list_image_dir`` for the duration of the
    test and restored afterwards so the mock cannot leak into other tests.
    """
    original_listdir = os.listdir
    os.listdir = mock_list_image_dir
    try:
        data_set_struct = make_basic_data_set_image_config()
        data_set_struct.update(
            make_sample_stanza("randomFraction", {
                'seed': 1234,
                'fraction': 0.3333333333
            }))

        data_set_config = DatasetConfig(data_set_struct)
        dm = jbfs.DataManager({})

        # No training config is passed, so no validation split is made.
        train_list, val_list = jb_data.generate_image_list('data_root',
                                                           data_set_config, None,
                                                           dm)
        assert len(train_list) == 4
        assert len(val_list) == 0

        # Make sure they are in this order
        assert_correct_list(train_list, [3, 0], [0, 5])
    finally:
        # Always undo the global monkeypatch, even when an assertion fails.
        os.listdir = original_listdir
Example #3
0
def test_data_manager():
    """Exercise DataManager path construction and the image-cache lookup.

    Creates a cache directory under ``<cwd>/test/cache`` and always removes it,
    even when an assertion or filesystem call fails part-way through.
    """
    import shutil

    data_set_config = {}
    dm = jbfs.DataManager(data_set_config)
    assert dm.version_path == Path()
    dm = jbfs.DataManager(data_set_config, version='009')
    assert dm.version_path == Path('009')

    data_set_config['dataSetPath'] = 'testDataSet'
    dm = jbfs.DataManager(data_set_config)
    assert dm.version_path == Path('testDataSet')
    dm = jbfs.DataManager(data_set_config, version='009')
    assert dm.version_path == Path('testDataSet') / '009'

    category = 'cars'
    assert dm.get_directory_path(
        category) == Path('testDataSet') / '009' / 'cars'

    image = 'miata.png'
    assert dm.get_file_path(
        category, image) == Path('testDataSet') / '009' / 'cars' / 'miata.png'

    data_set_config['imageData'] = {
        'properties': {
            'dimensions': '64,64',
            'colorspace': 'gray',
        }
    }
    dm = jbfs.DataManager(data_set_config, version='009')

    # Nothing has been cached yet, so the lookup reports a miss.
    assert dm.check_cache('./test', category, image) == ('miata.png', False)

    # Presumably the location check_cache('./test', ...) looks in for a
    # 64x64 gray image -- confirm against DataManager.check_cache.
    cache_root = Path.cwd() / 'test' / 'cache'
    temp_dir = cache_root / 'testDataSet' / '009' / '64x64_gray' / 'cars'
    temp_image = temp_dir / 'miata.png'

    try:
        temp_dir.mkdir(parents=True, exist_ok=True)
        temp_image.touch()
        # TODO(review): the cached image is created but never looked up again;
        # a check_cache assertion for the cache-hit case appears to be missing.
    finally:
        # Clean up even on failure; ignore_errors so a cleanup problem cannot
        # mask the original exception.
        shutil.rmtree(cache_root, ignore_errors=True)
Example #4
0
def test_generate_image_list():
    """Check that without sampling or validation every mocked image lands in the train list.

    ``os.listdir`` is replaced by ``mock_list_image_dir`` for the duration of the
    test and restored afterwards so the mock cannot leak into other tests.
    """
    original_listdir = os.listdir
    os.listdir = mock_list_image_dir
    try:
        data_set_struct = make_basic_data_set_image_config()

        data_set_config = DatasetConfig(data_set_struct)
        dm = jbfs.DataManager({})

        train_list, val_list = jb_data.generate_image_list('data_root',
                                                           data_set_config, None,
                                                           dm)
        assert len(train_list) == 12
        assert len(val_list) == 0

        assert_correct_list(train_list, range(6), range(6))
    finally:
        # Always undo the global monkeypatch, even when an assertion fails.
        os.listdir = original_listdir
    def setup(self):
        """Construct the helper objects and training components used by the trainer.

        Builds the data manager, an optional TensorBoard manager, seeds the RNGs,
        then sets up hardware, data loaders and the model, followed by the loss,
        optimizer, LR scheduler, accuracy function and acceptance checker, and
        finally initializes the metric history.

        The call order matters: the optimizer is built from ``self.model``, so
        ``setup_model()`` must run first, and the LR scheduler wraps
        ``self.optimizer``.
        """
        # Construct helper objects
        self.data_manager = jbfs.DataManager(self.data_set_config.config, self.data_version)

        # self.tb_mgr is only assigned when a TensorBoard root is configured.
        # NOTE(review): confirm it is initialized elsewhere (e.g. to None) for the other case.
        if juneberry.TENSORBOARD_ROOT:
            self.tb_mgr = jbtb.TensorBoardManager(juneberry.TENSORBOARD_ROOT, self.model_manager)

        # Seed before building loaders/model so any randomness they use is reproducible.
        pyt_utils.set_seeds(self.training_config.seed)

        self.setup_hardware()
        self.setup_data_loaders()
        self.setup_model()

        # These depend on self.model / self.optimizer, so they must follow setup_model().
        self.loss_function = pyt_utils.make_criterion(self.pytorch_options, self.binary)
        self.optimizer = pyt_utils.make_optimizer(self.pytorch_options, self.model)
        self.lr_scheduler = pyt_utils.make_lr_scheduler(self.pytorch_options, self.optimizer)
        self.accuracy_function = pyt_utils.make_accuracy(self.pytorch_options, self.binary)
        self.setup_acceptance_checker()

        # Presumably self.training_iterable is created by setup_data_loaders() -- confirm.
        self.num_batches = len(self.training_iterable)

        # Metric history; presumably one entry is appended per epoch -- confirm against the training loop.
        self.history = {'loss': [], 'accuracy': [], 'val_loss': [], 'val_accuracy': [], 'epoch_duration': [], 'lr': []}