def setUp(self):
    """Create temp dirs and a rewritten copy of the learncurve test config.

    The copy lives at ``self.tmp_config_path`` (inside ``self.tmp_config_dir``)
    and is edited so that:

    * ``[LEARNCURVE] {train,test,val}_vds_path`` point at the single
      ``*.{stem}.vds.json`` file found in the test-data ``vds`` directory;
    * ``[PREP] output_dir`` and ``[LEARNCURVE] root_results_dir`` point at
      the fresh temporary directory ``self.tmp_output_dir``;
    * ``[PREP] data_dir`` points at the ``cbins/gy6or6/032312`` test data.
    """
    self.tmp_output_dir = tempfile.mkdtemp()
    test_learncurve_config = TEST_CONFIGS_DIR.joinpath('test_learncurve_config.ini')

    # Now we want a copy (of the changed version) to use for tests
    # since this is what the test data was made with
    self.tmp_config_dir = tempfile.mkdtemp()
    self.tmp_config_path = Path(self.tmp_config_dir).joinpath('tmp_test_learncurve_config.ini')
    shutil.copy(test_learncurve_config, self.tmp_config_path)

    # rewrite config so it points to data for testing + temporary output dirs
    config = ConfigParser()
    config.read(self.tmp_config_path)

    vds_dirs = list(TEST_DATA_DIR.glob('vds'))
    # fail with a readable message instead of a bare IndexError
    self.assertTrue(vds_dirs, msg=f'no vds directory found in {TEST_DATA_DIR}')
    test_data_vds_path = vds_dirs[0]

    for stem in ['train', 'test', 'val']:
        vds_paths = list(test_data_vds_path.glob(f'*.{stem}.vds.json'))
        # exactly one dataset file per split; assertEqual reports what was found
        self.assertEqual(
            len(vds_paths), 1,
            msg=f'expected exactly one *.{stem}.vds.json in {test_data_vds_path}, found: {vds_paths}',
        )
        config['LEARNCURVE'][f'{stem}_vds_path'] = str(vds_paths[0])

    config['PREP']['output_dir'] = str(self.tmp_output_dir)
    config['PREP']['data_dir'] = str(TEST_DATA_DIR.joinpath('cbins', 'gy6or6', '032312'))
    config['LEARNCURVE']['root_results_dir'] = str(self.tmp_output_dir)

    with open(self.tmp_config_path, 'w') as fp:
        config.write(fp)
def _add_dirs_to_config_and_save_as_tmp(self, config_file):
    """Write a copy of ``config_file`` whose path options point at real dirs.

    Helper called by the unit tests so that parsing does not fail with
    spurious NotADirectory errors. For each of the PREP, TRAIN, LEARNCURVE
    and PREDICT sections that ``config_file`` contains, the path-valued
    options listed below are overwritten with the matching ``self.tmp_*``
    attribute. Attributes are only read for sections that are present.

    Parameters
    ----------
    config_file : str or path-like
        Path to an existing .ini config file.

    Returns
    -------
    str
        Absolute path of a new ``config*.ini`` file in ``self.tmp_config_dir``.
        It is created with ``delete=False``, so removing it is left to the
        caller.

    Raises
    ------
    FileNotFoundError
        If ``config_file`` could not be read. ``ConfigParser.read`` silently
        skips unreadable files, which would otherwise yield an empty config
        and a confusing failure later in the test.
    """
    # section -> ((option, name of self attribute holding its value), ...)
    section_options = {
        'PREP': (
            ('data_dir', 'tmp_data_dir'),
            ('output_dir', 'tmp_data_output_dir'),
        ),
        'TRAIN': (
            ('train_vds_path', 'tmp_train_vds_path'),
            ('val_vds_path', 'tmp_val_vds_path'),
            ('root_results_dir', 'tmp_root_dir'),
            ('results_dir_made_by_main_script', 'tmp_results_dir'),
        ),
        'LEARNCURVE': (
            ('train_vds_path', 'tmp_train_vds_path'),
            ('val_vds_path', 'tmp_val_vds_path'),
            ('test_vds_path', 'tmp_test_vds_path'),
            ('root_results_dir', 'tmp_root_dir'),
            ('results_dir_made_by_main_script', 'tmp_results_dir'),
        ),
        'PREDICT': (
            ('checkpoint_path', 'tmp_checkpoint_dir'),
            ('train_vds_path', 'tmp_train_vds_path'),
            ('predict_vds_path', 'tmp_predict_vds_path'),
            ('spect_scaler_path', 'tmp_spect_scaler_path'),
        ),
    }

    config = ConfigParser()
    # read() returns the list of files it actually parsed; empty means failure
    if not config.read(config_file):
        raise FileNotFoundError(f'could not read config file: {config_file}')

    for section, options in section_options.items():
        if not config.has_section(section):
            continue
        for option, attr_name in options:
            config[section][option] = getattr(self, attr_name)

    file_obj = tempfile.NamedTemporaryFile(prefix='config', suffix='.ini', mode='w',
                                           dir=self.tmp_config_dir, delete=False)
    with file_obj as config_file_out:
        config.write(config_file_out)
    return os.path.abspath(file_obj.name)
def test_invalid_network_option_raises(self):
    """parse_config must reject an option that a network section does not define.

    A bogus ``bungalow`` option is injected into the ``TweetyNet`` section of
    a temporary copy of the learncurve test config, and parsing that copy is
    expected to raise ``ValueError``.
    """
    source_config_path = os.path.join(TEST_CONFIGS_PATH, 'test_learncurve_config.ini')
    tmp_path = self._add_dirs_to_config_and_save_as_tmp(source_config_path)

    parser = ConfigParser()
    parser.read(tmp_path)
    # inject the unknown option and write the config back in place
    parser['TweetyNet']['bungalow'] = '12'
    with open(tmp_path, 'w') as out_file:
        parser.write(out_file)

    with self.assertRaises(ValueError):
        vak.config.parse_config(tmp_path)
def test_defined_sections_not_None(self):
    """Every section in each test config parses to a non-None value.

    For each ``test_*_config.ini`` in TEST_CONFIGS_PATH, a copy with real
    directories filled in is parsed with ``vak.config.parse_config``. Then:

    * a section listed in ``self.section_to_attr_map`` must map to a
      non-None attribute of the parsed config object;
    * otherwise, a section that is a key of ``config_obj.networks`` must map
      to a non-None entry there.
    """
    test_configs = glob(os.path.join(TEST_CONFIGS_PATH, 'test_*_config.ini'))
    # guard against the loop silently passing because nothing matched
    self.assertTrue(test_configs, msg=f'no test_*_config.ini files found in {TEST_CONFIGS_PATH}')
    for test_config in test_configs:
        with self.subTest(config=test_config):
            tmp_config_file = self._add_dirs_to_config_and_save_as_tmp(test_config)
            config = ConfigParser()
            config.read(tmp_config_file)
            config_obj = vak.config.parse_config(tmp_config_file)
            for section in config.sections():
                if section in self.section_to_attr_map:
                    # check sections that any config.ini file can have, non-network specific
                    attr_name = self.section_to_attr_map[section]
                    self.assertIsNotNone(getattr(config_obj, attr_name))
                elif section in config_obj.networks:
                    # check network specific sections. ``networks`` is treated as a
                    # mapping keyed by the section name exactly as written in the
                    # .ini (e.g. 'TweetyNet') -- assumption, TODO confirm against
                    # parse_config. The earlier ``section.lower()`` test combined with
                    # ``getattr(networks, section)`` could never check anything.
                    self.assertIsNotNone(config_obj.networks[section])
def test_network_sections_match_config(self):
    """Each network section in a test config is parsed with all its fields.

    For every ``test_*_config.ini`` in TEST_CONFIGS_PATH, each section whose
    name is a network returned by ``vak.models._load()`` must appear in
    ``config_obj.networks``. The parsed entry must expose every field that
    the network's ``Config._fields`` declares.
    """
    test_configs = glob(os.path.join(TEST_CONFIGS_PATH, 'test_*_config.ini'))
    # guard against the loop silently passing because nothing matched
    self.assertTrue(test_configs, msg=f'no test_*_config.ini files found in {TEST_CONFIGS_PATH}')
    NETWORKS = vak.models._load()
    available_net_names = set(NETWORKS.keys())
    for test_config in test_configs:
        with self.subTest(config=test_config):
            tmp_config_file = self._add_dirs_to_config_and_save_as_tmp(test_config)
            config = ConfigParser()
            config.read(tmp_config_file)
            config_obj = vak.config.parse_config(tmp_config_file)
            for section in config.sections():
                if section not in available_net_names:
                    continue
                self.assertIn(section, config_obj.networks)
                # check network specific sections
                net_config = config_obj.networks[section]
                # net_config may be a namedtuple instance, where field names live in
                # ``_fields`` and ``in`` on the instance would test *values*. It may
                # also be a plain mapping, where ``in`` tests keys. Accept either.
                # (Previously a generator was passed to assertTrue, which is always
                # truthy, so no field was ever actually checked.)
                present_fields = getattr(net_config, '_fields', net_config)
                for field in NETWORKS[section].Config._fields:
                    self.assertIn(field, present_fields)