Ejemplo n.º 1
0
    def setUp(self):
        """Prepare a temporary output directory and a temporary, rewritten
        copy of 'test_learncurve_config.ini' for the tests to use.

        Sets ``self.tmp_output_dir``, ``self.tmp_config_dir`` and
        ``self.tmp_config_path``.
        """
        self.tmp_output_dir = tempfile.mkdtemp()

        # work on a copy placed in its own temporary directory,
        # so the config under TEST_CONFIGS_DIR itself is never rewritten
        original_config = TEST_CONFIGS_DIR.joinpath('test_learncurve_config.ini')
        self.tmp_config_dir = tempfile.mkdtemp()
        self.tmp_config_path = Path(self.tmp_config_dir) / 'tmp_test_learncurve_config.ini'
        shutil.copy(original_config, self.tmp_config_path)

        parser = ConfigParser()
        parser.read(self.tmp_config_path)

        # point each *_vds_path option at the single matching .vds.json test file
        vds_dir = list(TEST_DATA_DIR.glob('vds'))[0]
        for split in ('train', 'test', 'val'):
            matches = list(vds_dir.glob(f'*.{split}.vds.json'))
            # exactly one file per split is expected in the test data
            self.assertTrue(len(matches) == 1)
            parser['LEARNCURVE'][f'{split}_vds_path'] = str(matches[0])

        # send all output to the temporary directory, read input from the test data
        tmp_output = str(self.tmp_output_dir)
        cbins_dir = TEST_DATA_DIR.joinpath('cbins', 'gy6or6', '032312')
        parser['PREP']['output_dir'] = tmp_output
        parser['PREP']['data_dir'] = str(cbins_dir)
        parser['LEARNCURVE']['root_results_dir'] = tmp_output

        with open(self.tmp_config_path, 'w') as config_out:
            parser.write(config_out)
Ejemplo n.º 2
0
    def _add_dirs_to_config_and_save_as_tmp(self, config_file):
        """Helper called by unit tests: rewrite the path options of a config
        file so they point at paths that actually exist (avoiding spurious
        NotADirectory errors), and save the result as a new temporary file.

        Only the PREP, TRAIN, LEARNCURVE and PREDICT sections are touched,
        and only when they are present in ``config_file``. The new file is
        created in ``self.tmp_config_dir`` and is not deleted automatically.

        Returns the absolute path of the newly written config file.
        """
        # section -> ((option, name of the attribute on self holding its value), ...)
        # Attributes are looked up lazily, only for sections found in the file.
        path_overrides = (
            ('PREP', (
                ('data_dir', 'tmp_data_dir'),
                ('output_dir', 'tmp_data_output_dir'),
            )),
            ('TRAIN', (
                ('train_vds_path', 'tmp_train_vds_path'),
                ('val_vds_path', 'tmp_val_vds_path'),
                ('root_results_dir', 'tmp_root_dir'),
                ('results_dir_made_by_main_script', 'tmp_results_dir'),
            )),
            ('LEARNCURVE', (
                ('train_vds_path', 'tmp_train_vds_path'),
                ('val_vds_path', 'tmp_val_vds_path'),
                ('test_vds_path', 'tmp_test_vds_path'),
                ('root_results_dir', 'tmp_root_dir'),
                ('results_dir_made_by_main_script', 'tmp_results_dir'),
            )),
            ('PREDICT', (
                ('checkpoint_path', 'tmp_checkpoint_dir'),
                ('train_vds_path', 'tmp_train_vds_path'),
                ('predict_vds_path', 'tmp_predict_vds_path'),
                ('spect_scaler_path', 'tmp_spect_scaler_path'),
            )),
        )

        parser = ConfigParser()
        parser.read(config_file)

        for section_name, options in path_overrides:
            if not parser.has_section(section_name):
                continue
            for option, attr_name in options:
                parser[section_name][option] = getattr(self, attr_name)

        # delete=False so the file outlives the context manager and can be re-read by the caller
        with tempfile.NamedTemporaryFile(prefix='config', suffix='.ini', mode='w',
                                         dir=self.tmp_config_dir, delete=False) as tmp_file:
            parser.write(tmp_file)
        return os.path.abspath(tmp_file.name)
Ejemplo n.º 3
0
 def test_invalid_network_option_raises(self):
     """parse_config should raise ValueError when a network section
     contains an option that is not valid for that network."""
     learncurve_config = os.path.join(TEST_CONFIGS_PATH,
                                      'test_learncurve_config.ini')
     tmp_path = self._add_dirs_to_config_and_save_as_tmp(learncurve_config)

     # inject an option into the TweetyNet section that the test expects to be rejected
     parser = ConfigParser()
     parser.read(tmp_path)
     parser['TweetyNet']['bungalow'] = '12'
     with open(tmp_path, 'w') as config_out:
         parser.write(config_out)

     with self.assertRaises(ValueError):
         vak.config.parse_config(tmp_path)
Ejemplo n.º 4
0
 def test_defined_sections_not_None(self):
     """For every test config, each section defined in the .ini file should
     have a non-None counterpart on the object returned by parse_config.

     Sections that are neither in ``self.section_to_attr_map`` nor (lower-cased)
     in ``config_obj.networks`` are skipped.
     """
     config_pattern = os.path.join(TEST_CONFIGS_PATH, 'test_*_config.ini')
     for config_path in glob(config_pattern):
         tmp_path = self._add_dirs_to_config_and_save_as_tmp(config_path)
         raw_config = ConfigParser()
         raw_config.read(tmp_path)
         parsed = vak.config.parse_config(tmp_path)
         for section_name in raw_config.sections():
             if section_name in self.section_to_attr_map:
                 # sections any config.ini file can have, not specific to a network
                 parsed_section = getattr(parsed,
                                          self.section_to_attr_map[section_name])
             elif section_name.lower() in parsed.networks:
                 # network-specific sections
                 parsed_section = getattr(parsed.networks, section_name)
             else:
                 continue
             self.assertTrue(parsed_section is not None)
Ejemplo n.º 5
0
 def test_network_sections_match_config(self):
     """For every test config, each section named after an available network
     should appear in the parsed config's ``networks``, and that entry should
     contain every field declared by the network's ``Config._fields``.
     """
     test_configs = glob(os.path.join(TEST_CONFIGS_PATH,
                                      'test_*_config.ini'))
     NETWORKS = vak.models._load()
     available_net_names = list(NETWORKS.keys())
     for test_config in test_configs:
         tmp_config_file = self._add_dirs_to_config_and_save_as_tmp(test_config)
         config = ConfigParser()
         config.read(tmp_config_file)
         config_obj = vak.config.parse_config(tmp_config_file)
         for section in config.sections():
             if section not in available_net_names:
                 # not a network-specific section; nothing to check here
                 continue
             self.assertIn(section, config_obj.networks)
             net_config = config_obj.networks[section]
             # Assert each field individually. Passing a generator expression
             # to assertTrue always passes, because a generator object is
             # truthy regardless of the values it would yield.
             # NOTE(review): assumes `field in net_config` tests for the field
             # *name* (e.g. net_config is a dict) -- confirm against parse_config.
             for field in NETWORKS[section].Config._fields:
                 self.assertIn(field, net_config)