def test_generate_image_validation_split():
    """Split the mocked image list into training and validation sets.

    The training config requests a ``randomFraction`` validation split with a
    fixed seed, so the selected images are deterministic for this seed.
    """
    # Temporarily replace os.listdir with the mock and restore it afterwards so
    # the patch does not leak into tests that run later.
    original_listdir = os.listdir
    os.listdir = mock_list_image_dir
    try:
        data_set_struct = make_basic_data_set_image_config()
        data_set_config = DatasetConfig(data_set_struct)

        train_struct = test_training_config.make_basic_config()
        train_struct['validation'] = {
            "algorithm": "randomFraction",
            "arguments": {
                "seed": 1234,
                "fraction": 0.3333333
            }
        }
        train_config = TrainingConfig('', train_struct)

        dm = jbfs.DataManager({})
        train_list, val_list = jb_data.generate_image_list('data_root', data_set_config, train_config, dm)

        # Roughly a third of the images is moved into the validation set.
        assert len(train_list) == 8
        assert len(val_list) == 4

        # NOTE: Another fragile secret we know is that the order from the validation is reversed.
        assert_correct_list(train_list, [1, 2, 4, 5], [1, 2, 3, 4])
        assert_correct_list(val_list, [3, 0], [5, 0])
    finally:
        os.listdir = original_listdir
def test_generate_image_sample_fraction():
    """Sample a fixed fraction of each image set via a ``randomFraction`` stanza."""
    # If we pass in a sampling fraction we should get only that fraction of each set.
    # We know how the internal randomizer works. We know it uses random.sample on both
    # sets in order. This is a secret and fragile to this test.
    # With a seed of 1234 and a fraction of one third (two images from each set),
    # it pulls [3, 0] and [0, 5].
    # Temporarily replace os.listdir with the mock and restore it afterwards so
    # the patch does not leak into tests that run later.
    original_listdir = os.listdir
    os.listdir = mock_list_image_dir
    try:
        data_set_struct = make_basic_data_set_image_config()
        data_set_struct.update(
            make_sample_stanza("randomFraction", {
                'seed': 1234,
                'fraction': 0.3333333333
            }))
        data_set_config = DatasetConfig(data_set_struct)

        dm = jbfs.DataManager({})
        # No training config is passed, so no validation split is made.
        train_list, val_list = jb_data.generate_image_list('data_root', data_set_config, None, dm)

        assert len(train_list) == 4
        assert len(val_list) == 0

        # Make sure they are in this order
        assert_correct_list(train_list, [3, 0], [0, 5])
    finally:
        os.listdir = original_listdir
def test_data_manager():
    """Check DataManager version/directory/file path construction and cache lookup."""
    import shutil

    # Without a dataSetPath, the version path is just the (optional) version.
    data_set_config = {}
    dm = jbfs.DataManager(data_set_config)
    assert dm.version_path == Path()

    dm = jbfs.DataManager(data_set_config, version='009')
    assert dm.version_path == Path('009')

    # With a dataSetPath, the version is nested beneath it.
    data_set_config['dataSetPath'] = 'testDataSet'
    dm = jbfs.DataManager(data_set_config)
    assert dm.version_path == Path('testDataSet')

    dm = jbfs.DataManager(data_set_config, version='009')
    assert dm.version_path == Path('testDataSet') / '009'

    category = 'cars'
    assert dm.get_directory_path(category) == Path('testDataSet') / '009' / 'cars'

    image = 'miata.png'
    assert dm.get_file_path(category, image) == Path('testDataSet') / '009' / 'cars' / 'miata.png'

    # Add image properties, which the cache path below appears to encode as '64x64_gray'.
    data_set_config['imageData'] = {
        'properties': {
            'dimensions': '64,64',
            'colorspace': 'gray',
        }
    }
    dm = jbfs.DataManager(data_set_config, version='009')
    assert dm.check_cache('./test', category, image) == ('miata.png', False)

    cache_root = Path.cwd() / 'test' / 'cache'
    temp_dir = cache_root / 'testDataSet' / '009' / '64x64_gray' / 'cars'
    temp_image = temp_dir / 'miata.png'
    try:
        temp_dir.mkdir(parents=True, exist_ok=True)
        temp_image.touch()
        # NOTE(review): the cache-hit result of check_cache is never asserted
        # after creating this file; consider adding that check.
    finally:
        # Always remove the on-disk cache tree, even if creating it failed.
        shutil.rmtree(cache_root, ignore_errors=True)
def test_generate_image_list():
    """Without sampling or a training config, every mocked image lands in the training list."""
    # Just replace listdir, and restore it afterwards so the patch does not
    # leak into tests that run later.
    original_listdir = os.listdir
    os.listdir = mock_list_image_dir
    try:
        data_set_struct = make_basic_data_set_image_config()
        data_set_config = DatasetConfig(data_set_struct)

        dm = jbfs.DataManager({})
        train_list, val_list = jb_data.generate_image_list('data_root', data_set_config, None, dm)

        assert len(train_list) == 12
        assert len(val_list) == 0

        assert_correct_list(train_list, range(6), range(6))
    finally:
        os.listdir = original_listdir
def setup(self):
    """Construct the helper objects and per-run state needed for training.

    The steps run in a fixed order. Later steps read attributes produced by
    earlier ones: the optimizer is built from ``self.model``, which is set up
    by ``setup_model``, and the LR scheduler wraps ``self.optimizer``.
    """
    # Construct helper objects
    self.data_manager = jbfs.DataManager(self.data_set_config.config, self.data_version)
    # NOTE(review): tb_mgr is only assigned when a TensorBoard root is configured;
    # confirm it is initialized elsewhere (e.g. __init__) for the other case.
    if juneberry.TENSORBOARD_ROOT:
        self.tb_mgr = jbtb.TensorBoardManager(juneberry.TENSORBOARD_ROOT, self.model_manager)

    # Seeds are set before hardware/data/model setup, presumably so that
    # data loading and model initialization are reproducible.
    pyt_utils.set_seeds(self.training_config.seed)

    self.setup_hardware()
    self.setup_data_loaders()
    self.setup_model()

    # Training components, built from the pytorch options.
    self.loss_function = pyt_utils.make_criterion(self.pytorch_options, self.binary)
    self.optimizer = pyt_utils.make_optimizer(self.pytorch_options, self.model)
    self.lr_scheduler = pyt_utils.make_lr_scheduler(self.pytorch_options, self.optimizer)
    self.accuracy_function = pyt_utils.make_accuracy(self.pytorch_options, self.binary)

    self.setup_acceptance_checker()

    # Presumably training_iterable is created by setup_data_loaders — confirm.
    self.num_batches = len(self.training_iterable)

    # Per-epoch metric history; each key starts as an empty list.
    self.history = {'loss': [], 'accuracy': [], 'val_loss': [], 'val_accuracy': [], 'epoch_duration': [], 'lr': []}