コード例 #1
0
ファイル: test_9a_Predictor.py プロジェクト: yangxhcaf/sdmdl
class PredictorTestCase(unittest.TestCase):
    """Test cases for the Predictor class."""

    def setUp(self):
        """Build the occurrence, GIS and config handlers and a Predictor
        from the ``test_data/root`` fixture directory."""
        self.root = os.path.abspath(os.path.dirname(__file__)) + '/test_data'
        self.oh = Occurrences(self.root + '/root')
        self.oh.validate_occurrences()
        self.oh.species_dictionary()
        self.gh = GIS(self.root + '/root')
        self.gh.validate_gis()
        self.gh.validate_tif()
        self.gh.define_output()
        self.ch = Config(self.root + '/root', self.oh, self.gh)
        self.ch.search_config()
        self.ch.read_yaml()
        self.verbose = False
        self.p = Predictor(self.oh, self.gh, self.ch, self.verbose)

    def test__init__(self):
        """The constructor stores the handlers and the verbosity flag."""
        self.assertEqual(self.p.oh, self.oh)
        self.assertEqual(self.p.gh, self.gh)
        self.assertEqual(self.p.ch, self.ch)
        self.assertEqual(self.p.verbose, self.verbose)

    def test_prep_prediction_data(self):
        """prep_prediction_data must return the stacked raster contents and
        the indices of the cells holding the minimum value of band 1 of the
        empty land map."""
        myarray, index_minb1 = self.p.prep_prediction_data()
        myarray_truth = gdal.Open(
            self.root +
            '/root/gis/stack/stacked_env_variables.tif').ReadAsArray()
        # Read band 1 inside a context manager so the dataset is always
        # closed (previously the rasterio handle was never closed).
        with rasterio.open(self.root +
                           '/root/gis/layers/empty_land_map.tif') as src:
            empty_map = src.read(1)
        index_minb1_truth = np.where(empty_map == np.min(empty_map))
        self.assertEqual(myarray.tolist(), myarray_truth.tolist())
        # np.where returns a tuple of index arrays; compare them as lists.
        self.assertEqual([x.tolist() for x in index_minb1],
                         [x.tolist() for x in index_minb1_truth])

    def notest_predict_distribution(self):
        """Disabled (no ``test`` prefix): compare predict_distribution output
        for the first species against a stored reference band."""
        myarray, index_minb1 = self.p.prep_prediction_data()
        new_band = self.p.predict_distribution(self.oh.name[0], myarray,
                                               index_minb1)
        with np.load(self.root + '/predictor/new_band.npz') as data:
            new_band_truth = data[list(data.keys())[0]]
            np.testing.assert_array_equal(new_band, new_band_truth)
コード例 #2
0
ファイル: test_0_Occurrences.py プロジェクト: yangxhcaf/sdmdl
class OccurrencesTestCase(unittest.TestCase):
    """Unit tests for the Occurrences handler."""

    def setUp(self):
        """Point every test at the ``test_data/root`` fixture (forward slashes)."""
        here = os.path.abspath(os.path.dirname(__file__))
        self.root = (here + '/test_data/root').replace('\\', '/')

    def test__init__(self):
        """A fresh handler remembers its root and starts out empty."""
        self.oh = Occurrences(self.root)
        expected_state = [
            ('root', self.root),
            ('length', 0),
            ('path', []),
            ('name', []),
            ('spec_dict', {}),
        ]
        for attribute, expected in expected_state:
            self.assertEqual(getattr(self.oh, attribute), expected)

    def test_validate_occurrences(self):
        """Validation discovers both fixture CSV files, in order."""
        self.oh = Occurrences(self.root)
        self.oh.validate_occurrences()
        expected_paths = [
            self.root + '/occurrences/arachis_duranensis.csv',
            self.root + '/occurrences/solanum_bukasovii.csv'
        ]
        expected_names = ['arachis_duranensis', 'solanum_bukasovii']
        self.assertEqual(self.oh.length, 2)
        self.assertEqual(self.oh.path, expected_paths)
        self.assertEqual(self.oh.name, expected_names)

    def test_species_dictionary(self):
        """The species dictionary maps each species to its lon/lat frame."""
        self.oh = Occurrences(self.root)
        self.oh.validate_occurrences()
        self.oh.species_dictionary()
        species = self.oh.spec_dict
        self.assertIsInstance(species, dict)
        self.assertEqual(list(species),
                         ['arachis_duranensis', 'solanum_bukasovii'])
        self.assertEqual(list(species['arachis_duranensis']),
                         ['dLon', 'dLat'])
        raw = pd.read_csv(self.root + '/occurrences/solanum_bukasovii.csv')
        expected_coords = raw[['decimalLongitude',
                               'decimalLatitude']].to_numpy().tolist()
        self.assertEqual(species['solanum_bukasovii'].values.tolist(),
                         expected_coords)
コード例 #3
0
ファイル: test_3_PresenceMap.py プロジェクト: yangxhcaf/sdmdl
class PresenceMapTestCase(unittest.TestCase):
    """Test cases for the PresenceMap class."""

    def setUp(self):
        """Build validated occurrence and GIS handlers and a PresenceMap."""
        self.root = os.path.abspath(os.path.join(
            os.path.dirname(__file__))) + '/test_data'
        self.oh = Occurrences(self.root + '/root')
        self.oh.validate_occurrences()
        self.oh.species_dictionary()
        self.gh = GIS(self.root + '/root')
        self.gh.validate_gis()
        self.gh.validate_tif()
        self.gh.define_output()
        self.verbose = False

        self.cpm = PresenceMap(self.oh, self.gh, self.verbose)

    def test__init__(self):
        """The constructor stores the handlers and the verbosity flag."""
        self.assertEqual(self.cpm.oh, self.oh)
        self.assertEqual(self.cpm.gh, self.gh)
        self.assertEqual(self.cpm.verbose, self.verbose)

    @staticmethod
    def _restore_fixture(original, backup):
        """Put a backed-up fixture raster back in place.

        Deletes whatever file is at ``original`` (e.g. one generated by the
        test), then moves ``backup`` back to ``original``.
        """
        if os.path.isfile(original):
            os.remove(original)
        shutil.move(backup, original)

    def test_create_presence_map(self):
        """create_presence_map regenerates one presence raster per species
        whose band 1 matches the stored reference raster."""
        presence_dir = self.root + '/root/gis/layers/non-scaled/presence'
        species = ['arachis_duranensis', 'solanum_bukasovii']

        # Move the existing rasters aside so the test can check they are
        # regenerated. The restore step is registered immediately after each
        # move, so the fixture directory is repaired even if a later step
        # fails (previously a failure left it without these files).
        for name in species:
            original = presence_dir + '/' + name + '_presence_map.tif'
            backup = presence_dir + '/true_' + name + '_presence_map.tif'
            shutil.move(original, backup)
            self.addCleanup(self._restore_fixture, original, backup)

        for name in species:
            self.assertFalse(
                os.path.isfile(presence_dir + '/' + name + '_presence_map.tif'))

        self.cpm.create_presence_map()

        for name in species:
            result_path = presence_dir + '/' + name + '_presence_map.tif'
            truth_path = (self.root + '/presence_map/' + name +
                          '_presence_map.tif')
            self.assertTrue(os.path.isfile(result_path))
            # Context managers close both datasets even if the comparison
            # fails, which also lets the cleanup delete the generated file.
            with rasterio.open(result_path) as result, \
                    rasterio.open(truth_path) as truth:
                self.assertEqual(result.read(1).tolist(),
                                 truth.read(1).tolist())
コード例 #4
0
class TrainingDataTestCase(unittest.TestCase):
    """Test cases for the TrainingData class."""

    def setUp(self):
        """Build validated occurrence and GIS handlers and a TrainingData."""
        self.root = os.path.abspath(os.path.join(
            os.path.dirname(__file__))) + '/test_data'
        self.oh = Occurrences(self.root + '/root')
        self.oh.validate_occurrences()
        self.oh.species_dictionary()
        self.gh = GIS(self.root + '/root')
        self.gh.validate_gis()
        self.gh.validate_tif()
        self.gh.define_output()
        self.verbose = False

        self.ctd = TrainingData(self.oh, self.gh, self.verbose)

    def test__init__(self):
        """The constructor stores the handlers and the verbosity flag."""
        self.assertEqual(self.ctd.oh, self.oh)
        self.assertEqual(self.ctd.gh, self.gh)
        self.assertEqual(self.ctd.verbose, self.verbose)

    def test_prep_training_df(self):
        """prep_training_df output for the first species matches the stored
        reference arrays in ``test_data/training_data``."""
        stack_path = self.root + '/root/gis/stack/stacked_env_variables.tif'
        truth_dir = self.root + '/training_data'
        # The context manager closes ``src`` even when an assertion fails
        # (previously src.close() was only reached on success). Assertions
        # stay inside the block in case a returned object still reads from
        # the open dataset.
        with rasterio.open(stack_path) as src:
            inRas = gdal.Open(stack_path)
            (spec, ppa, lon, lat, row, col, myarray,
             mean_std) = self.ctd.prep_training_df(src, inRas,
                                                   self.oh.name[0])
            ppa_truth = np.load(truth_dir + '/ppa.npy')
            lon_truth = np.load(truth_dir + '/long.npy')
            lat_truth = np.load(truth_dir + '/lati.npy')
            row_truth = np.load(truth_dir + '/row.npy')
            col_truth = np.load(truth_dir + '/col.npy')
            mean_std_truth = np.load(truth_dir + '/mean_std.npy')
            self.assertEqual(spec, self.oh.name[0])
            self.assertEqual(ppa.to_numpy().tolist(), ppa_truth.tolist())
            self.assertEqual(lon.tolist(), lon_truth.tolist())
            self.assertEqual(lat.tolist(), lat_truth.tolist())
            self.assertEqual(row, row_truth.tolist())
            self.assertEqual(col, col_truth.tolist())
            self.assertEqual(myarray.tolist(), inRas.ReadAsArray().tolist())
            self.assertEqual(mean_std.tolist(), mean_std_truth.tolist())

    def test_create_training_df(self):
        """create_training_df regenerates both species' environment
        dataframes, matching the references in columns and values."""
        result_path_a = (self.root +
                         '/root/spec_ppa_env/arachis_duranensis_env_dataframe.csv')
        result_path_b = (self.root +
                         '/root/spec_ppa_env/solanum_bukasovii_env_dataframe.csv')
        os.remove(result_path_a)
        os.remove(result_path_b)
        self.assertFalse(os.path.isfile(result_path_a))
        self.assertFalse(os.path.isfile(result_path_b))
        self.ctd.create_training_df()
        self.assertTrue(os.path.isfile(result_path_a))
        self.assertTrue(os.path.isfile(result_path_b))
        result_a = pd.read_csv(result_path_a)
        result_b = pd.read_csv(result_path_b)
        truth_a = pd.read_csv(
            self.root + '/training_data/arachis_duranensis_env_dataframe.csv')
        truth_b = pd.read_csv(
            self.root + '/training_data/solanum_bukasovii_env_dataframe.csv')
        self.assertEqual(list(result_a.columns), list(truth_a.columns))
        # Previously this line repeated the result_a check, so the second
        # species' columns were never verified.
        self.assertEqual(list(result_b.columns), list(truth_b.columns))
        self.assertEqual(result_a.to_numpy().tolist(),
                         truth_a.to_numpy().tolist())
        self.assertEqual(result_b.to_numpy().tolist(),
                         truth_b.to_numpy().tolist())
コード例 #5
0
ファイル: test_9_Trainer.py プロジェクト: yangxhcaf/sdmdl
class TrainerTestCase(unittest.TestCase):
    """Test cases for the Trainer class."""

    def setUp(self):
        """Build validated handlers from the fixture root and a Trainer."""
        self.root = os.path.abspath(os.path.join(os.path.dirname(__file__))) + '/test_data'
        self.oh = Occurrences(self.root + '/root')
        self.oh.validate_occurrences()
        self.oh.species_dictionary()
        self.gh = GIS(self.root + '/root')
        self.gh.validate_gis()
        self.gh.validate_tif()
        self.gh.define_output()
        self.ch = Config(self.root + '/root', self.oh, self.gh)
        self.ch.search_config()
        self.ch.read_yaml()
        self.verbose = False
        self.t = Trainer(self.oh, self.gh, self.ch, self.verbose)

    @staticmethod
    def _allow_gpu_growth():
        """Create a TF1 session configured for on-demand GPU memory.

        The session object is discarded, exactly as in the original tests.
        NOTE(review): confirm whether this has any effect on later training.
        """
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        tf.Session(config=config)

    def test__init__(self):
        """A new Trainer copies hyper-parameters from the config and starts
        with empty metric lists."""
        self.assertEqual(self.t.oh, self.oh)
        self.assertEqual(self.t.gh, self.gh)
        self.assertEqual(self.t.ch, self.ch)
        self.assertEqual(self.t.verbose, self.verbose)
        self.assertEqual(self.t.spec, '')
        self.assertEqual(self.t.variables, [])
        self.assertEqual(self.t.test_loss, [])
        self.assertEqual(self.t.test_acc, [])
        self.assertEqual(self.t.test_AUC, [])
        self.assertEqual(self.t.test_tpr, [])
        self.assertEqual(self.t.test_uci, [])
        self.assertEqual(self.t.test_lci, [])
        self.assertEqual(self.t.best_model_auc, [0])
        self.assertEqual(self.t.occ_len, 0)
        self.assertEqual(self.t.abs_len, 0)
        self.assertEqual(self.t.random_seed, self.ch.random_seed)
        self.assertEqual(self.t.batch, self.ch.batchsize)
        self.assertEqual(self.t.epoch, self.ch.epoch)
        self.assertEqual(self.t.model_layers, self.ch.model_layers)
        self.assertEqual(self.t.model_dropout, self.ch.model_dropout)

    def test_create_eval(self):
        """create_eval recreates DNN_eval.txt matching the reference file."""
        eval_path = self.root + '/root/results/_DNN_performance/DNN_eval.txt'
        os.remove(eval_path)
        self.assertFalse(os.path.isfile(eval_path))
        self.t.create_eval()
        self.assertTrue(os.path.isfile(eval_path))
        dnn_eval = pd.read_csv(eval_path, delimiter='\t')
        dnn_eval_truth = pd.read_csv(self.root + '/trainer/create_eval.txt', delimiter='\t')
        self.assertEqual(dnn_eval.to_numpy().tolist(), dnn_eval_truth.to_numpy().tolist())

    def test_create_input_data(self):
        """create_input_data for the first species matches the stored arrays."""
        self.t.spec = self.oh.name[0]
        X, X_train, X_test, y_train, y_test, test_set, shuffled_X_train, shuffled_X_test = self.t.create_input_data()
        X_truth = np.load(self.root + '/trainer/X.npy')
        X_train_truth = np.load(self.root + '/trainer/X_train.npy')
        X_test_truth = np.load(self.root + '/trainer/X_test.npy')
        y_train_truth = np.load(self.root + '/trainer/y_train.npy')
        y_test_truth = np.load(self.root + '/trainer/y_test.npy')
        test_set_truth = np.load(self.root + '/trainer/test_set.npy')
        shuffled_X_train_truth = np.load(self.root + '/trainer/shuffled_X_train.npy')
        shuffled_X_test_truth = np.load(self.root + '/trainer/shuffled_X_test.npy')
        self.assertEqual(X.tolist(), X_truth.tolist())
        self.assertEqual(X_train.tolist(), X_train_truth.tolist())
        self.assertEqual(X_test.tolist(), X_test_truth.tolist())
        self.assertEqual(y_train.tolist(), y_train_truth.tolist())
        self.assertEqual(y_test.tolist(), y_test_truth.tolist())
        self.assertEqual(test_set.to_numpy().tolist(), test_set_truth.tolist())
        self.assertEqual(shuffled_X_train.tolist(), shuffled_X_train_truth.tolist())
        self.assertEqual(shuffled_X_test.tolist(), shuffled_X_test_truth.tolist())

    def test_create_model_architecture(self):
        """The built (untrained) model matches the stored model's config and
        initial weights."""
        self.t.spec = self.oh.name[0]
        X, _, _, _, _, _, _, _ = self.t.create_input_data()
        model = self.t.create_model_architecture(X)
        model_truth = keras.models.load_model(self.root + '/trainer/model.h5')
        self.assertEqual(model.get_config(), model_truth.get_config())
        weights = [x.tolist() for x in model.get_weights()]
        weights_truth = [x.tolist() for x in model_truth.get_weights()]
        self.assertEqual(weights, weights_truth)

    def notest_train_model(self):
        """Disabled (no ``test`` prefix): train on the first species and
        compare the AUC and trained weights with stored references."""
        self.t.spec = self.oh.name[0]
        X, X_train, X_test, y_train, y_test, _, _, _ = self.t.create_input_data()
        model = self.t.create_model_architecture(X)
        self._allow_gpu_growth()
        AUC, model = self.t.train_model(model, X_train, X_test, y_train, y_test)
        AUC_truth = 0.9930313588850174
        model_truth = keras.models.load_model(self.root + '/trainer/model_trained.h5')
        self.assertAlmostEqual(AUC, AUC_truth)
        # NOTE: comparing model.get_config() with model_truth.get_config()
        # crashed when the whole suite ran but passed in isolation; it is
        # left out until that is investigated.
        weights = model.get_weights()
        weights_truth = model_truth.get_weights()
        # Previously a differing layer count or array length silently skipped
        # the comparison, so the test passed vacuously. Assert the count and
        # let assert_almost_equal check each array's shape and values.
        self.assertEqual(len(weights), len(weights_truth))
        for layer, layer_truth in zip(weights, weights_truth):
            np.testing.assert_almost_equal(layer, layer_truth, 6)

    def notest_update_performance_metrics(self):
        """Disabled (no ``test`` prefix): after training, the metrics row
        written to DNN_eval.txt matches the stored reference."""
        self.t.spec = self.oh.name[0]
        X, X_train, X_test, y_train, y_test, _, _, _ = self.t.create_input_data()
        model = self.t.create_model_architecture(X)
        self._allow_gpu_growth()
        self.t.train_model(model, X_train, X_test, y_train, y_test)
        eval_path = self.root + '/root/results/_DNN_performance/DNN_eval.txt'
        os.remove(eval_path)
        self.assertFalse(os.path.isfile(eval_path))
        self.t.create_eval()
        self.t.update_performance_metrics()
        self.assertTrue(os.path.isfile(eval_path))
        dnn_eval = pd.read_csv(eval_path, delimiter='\t')
        dnn_eval_truth = pd.read_csv(self.root + '/trainer/update_performance_metrics.txt', delimiter='\t')
        # First column is compared exactly, the remaining metric columns to
        # 6 decimal places.
        self.assertEqual(dnn_eval.to_numpy()[0][0], dnn_eval_truth.to_numpy()[0][0])
        np.testing.assert_almost_equal(dnn_eval.to_numpy()[0][1:], dnn_eval_truth.to_numpy()[0][1:], 6)
コード例 #6
0
class PresencePseudoAbsenceTestCase(unittest.TestCase):
    """Unit tests for the PresencePseudoAbsence step."""

    def setUp(self):
        """Prepare handlers from the fixture data and force random_seed = 1."""
        here = os.path.abspath(os.path.dirname(__file__))
        self.root = (here + '/test_data').replace('\\', '/')
        data_root = self.root + '/root'

        self.oh = Occurrences(data_root)
        self.oh.validate_occurrences()
        self.oh.species_dictionary()

        self.gh = GIS(data_root)
        self.gh.validate_gis()
        self.gh.validate_tif()
        self.gh.define_output()

        # NOTE(review): every other test case builds Config from
        # ``self.root + '/root'``; this one uses ``self.root`` itself.
        # Confirm which directory the reference outputs were generated with.
        self.ch = Config(self.root, self.oh, self.gh)
        self.ch.search_config()
        self.ch.read_yaml()
        # Fixed seed so the random absence sample is reproducible.
        self.ch.random_seed = 1
        self.verbose = False

        self.ppa = PresencePseudoAbsence(self.oh, self.gh, self.ch,
                                         self.verbose)

    def test__init__(self):
        """The constructor keeps the handlers and copies sampling settings."""
        pairs = (
            (self.ppa.oh, self.oh),
            (self.ppa.gh, self.gh),
            (self.ppa.ch, self.ch),
            (self.ppa.verbose, self.verbose),
            (self.ppa.random_sample_size, self.ch.pseudo_freq),
            (self.ppa.random_seed, self.ch.random_seed),
        )
        for actual, expected in pairs:
            self.assertEqual(actual, expected)

    def test_draw_random_absence(self):
        """Drawing absences for the first species reproduces stored samples."""
        species = self.oh.name[0]
        presence, sampled_lon_lats, size = self.ppa.draw_random_absence(
            species)
        reference_dir = self.root + '/presence_pseudo_absence'
        expected_presence = np.load(reference_dir + '/presence_data.npy',
                                    allow_pickle=True)
        expected_lon_lats = np.load(reference_dir + '/outer_random_sample.npy')
        self.assertEqual(presence.to_numpy().tolist(),
                         expected_presence.tolist())
        self.assertEqual(sampled_lon_lats.tolist(),
                         expected_lon_lats.tolist())
        self.assertEqual(size, self.ch.pseudo_freq)

    def test_create_presence_pseudo_absence(self):
        """The step regenerates both species' PPA CSVs with reference data."""
        generated = [
            self.root + '/root/spec_ppa/arachis_duranensis_ppa_dataframe.csv',
            self.root + '/root/spec_ppa/solanum_bukasovii_ppa_dataframe.csv',
        ]
        references = [
            self.root +
            '/presence_pseudo_absence/arachis_duranensis_ppa_dataframe.csv',
            self.root +
            '/presence_pseudo_absence/solanum_bukasovii_ppa_dataframe.csv',
        ]
        for csv_path in generated:
            os.remove(csv_path)
        for csv_path in generated:
            self.assertFalse(os.path.isfile(csv_path))

        self.ppa.create_presence_pseudo_absence()

        for csv_path in generated:
            self.assertTrue(os.path.isfile(csv_path))
        produced_frames = [pd.read_csv(p) for p in generated]
        reference_frames = [pd.read_csv(p) for p in references]
        for produced, reference in zip(produced_frames, reference_frames):
            self.assertEqual(produced.to_numpy().tolist(),
                             reference.to_numpy().tolist())
コード例 #7
0
ファイル: sdmdl_main.py プロジェクト: yangxhcaf/sdmdl
class sdmdl:
    """Entry point that drives one sdmdl project end to end.

    Builds the occurrence, GIS and config handlers for a repository and
    exposes the pipeline phases ``prep``, ``train`` and ``predict``, plus
    ``clean`` to remove generated intermediate files.

    :param root: string path to the root of the cloned or copied GitHub
        repository.
    :param dat_root: data directory holding the raster layers. When left at
        the default ``'/data'`` it is appended to ``root``; any other value
        is used as given. A user-supplied directory must contain the same
        required files as the repository's data folder.
    :param occ_root: directory holding the occurrence files. When left at
        the default ``'/data/occurrences'`` it is appended to ``root``; any
        other value is used as given.
    """
    def __init__(self, root, dat_root='/data', occ_root='/data/occurrences'):
        """Resolve directories, validate inputs and load the config file."""

        self.root = root
        # Only the default locations are anchored at ``root``; custom
        # locations are taken verbatim.
        if occ_root == '/data/occurrences':
            self.occ_root = self.root + occ_root
        else:
            self.occ_root = occ_root
        if dat_root == '/data':
            self.dat_root = self.root + dat_root
        else:
            self.dat_root = dat_root

        occurrences = self.oh = Occurrences(self.occ_root)
        occurrences.validate_occurrences()
        occurrences.species_dictionary()

        gis = self.gh = GIS(self.dat_root)
        gis.validate_gis()
        gis.validate_tif()
        gis.define_output()

        config = self.ch = Config(self.dat_root, self.oh, self.gh)
        config.search_config()
        config.read_yaml()

        self.verbose = self.ch.verbose
        if self.verbose:
            return
        # Quiet mode: silence tensorflow backend deprecation warnings.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        logging.getLogger("tensorflow").setLevel(logging.ERROR)

    def reload_config(self):
        """Not implemented yet; intended to pick up edits to the config file."""

        pass

    def prep(self):
        """Run every data pre-processing step, in order."""

        PresenceMap(self.oh, self.gh, self.verbose).create_presence_map()

        # The raster layers are validated again so the presence maps written
        # above are detected. This could instead be folded into
        # PresenceMap.create_presence_map.
        # Known issue: if a new sdmdl object is created for data that was
        # already pre-processed and ``train`` is called without ``prep``, the
        # presence maps are missing from the layer list, because the
        # generated config.yml does not include them.
        self.gh.validate_tif()

        RasterStack(self.gh, self.verbose).create_raster_stack()
        PresencePseudoAbsence(self.oh, self.gh, self.ch,
                              self.verbose).create_presence_pseudo_absence()
        BandStatistics(self.gh, self.verbose).calc_band_mean_and_stddev()
        TrainingData(self.oh, self.gh, self.verbose).create_training_df()
        PredictionData(self.gh, self.verbose).create_prediction_df()

    def train(self):
        """Train the models via a Trainer built from this project's handlers."""

        Trainer(self.oh, self.gh, self.ch, self.verbose).train()

    def predict(self):
        """Run prediction via a Predictor built from this project's handlers."""

        Predictor(self.oh, self.gh, self.ch, self.verbose).predict_model()

    def clean(self):
        """Delete the intermediate outputs found under the GIS handler paths.

        Removes the files inside ``<non_scaled>/presence``, ``spec_ppa`` and
        ``spec_ppa_env`` (then the directories themselves), the stacked
        raster and its ``stack`` directory, ``env_bio_mean_std.txt``,
        ``world_prediction_array.npy``, ``world_prediction_row_col.csv`` and
        ``<root>/filtered.csv``. Missing paths are skipped. Only regular
        files are deleted from the directories; ``os.rmdir`` then raises
        ``OSError`` if a directory is still not empty.
        """
        gh = self.gh

        def discard(path):
            if os.path.isfile(path):
                os.remove(path)

        def drop_dir(path):
            if os.path.isdir(path):
                os.rmdir(path)

        def purge(directory):
            entries = os.listdir(directory) if os.path.isdir(directory) else []
            for entry in entries:
                discard(directory + '/' + entry)
            drop_dir(directory)

        purge(gh.non_scaled + '/presence')
        discard(gh.stack + '/stacked_env_variables.tif')
        drop_dir(gh.stack)
        purge(gh.spec_ppa)
        discard(gh.gis + '/env_bio_mean_std.txt')
        purge(gh.spec_ppa_env)
        for leftover in (gh.gis + '/world_prediction_array.npy',
                         gh.gis + '/world_prediction_row_col.csv',
                         gh.root + '/filtered.csv'):
            discard(leftover)
コード例 #8
0
ファイル: test_2_Config.py プロジェクト: yangxhcaf/sdmdl
class ConfigTestCase(unittest.TestCase):
    """Test cases for Config Handler class."""
    def setUp(self):
        """Build validated occurrence and GIS handlers and an unread Config."""
        self.root = (os.path.abspath(os.path.join(os.path.dirname(__file__))) +
                     '/test_data').replace('\\', '/')
        self.oh = Occurrences(self.root + '/root')
        self.oh.validate_occurrences()
        self.oh.species_dictionary()
        self.gh = GIS(self.root + '/root')
        self.gh.validate_gis()
        self.gh.validate_tif()
        self.gh.define_output()

        self.ch = Config(self.root + '/root', self.oh, self.gh)

    @staticmethod
    def _remove_if_exists(path):
        """Delete ``path`` if it is a file; used as a test cleanup."""
        if os.path.isfile(path):
            os.remove(path)

    def test__init__(self):
        """A fresh Config holds its inputs and empty/zero defaults."""
        self.assertEqual(self.ch.oh, self.oh)
        self.assertEqual(self.ch.gh, self.gh)
        self.assertEqual(self.ch.root, self.root + '/root')
        self.assertEqual(self.ch.config, [])
        self.assertEqual(self.ch.yml_names, [
            'data_path', 'occurrence_path', 'result_path', 'occurrences',
            'layers', 'random_seed', 'pseudo_freq', 'batchsize', 'epoch',
            'model_layers', 'model_dropout', 'verbose'
        ])
        self.assertEqual(self.ch.data_path, None)
        self.assertEqual(self.ch.occ_path, None)
        self.assertEqual(self.ch.result_path, None)
        self.assertEqual(self.ch.yml, None)
        self.assertEqual(self.ch.random_seed, 0)
        self.assertEqual(self.ch.pseudo_freq, 0)
        self.assertEqual(self.ch.batchsize, 0)
        self.assertEqual(self.ch.epoch, 0)
        self.assertEqual(self.ch.model_layers, [])
        self.assertEqual(self.ch.model_dropout, [])
        self.assertEqual(self.ch.verbose, None)

    def test_search_config(self):
        """search_config finds root/config.yml and raises IOError for a
        directory without one."""
        self.ch.search_config()
        self.assertEqual(self.ch.config, self.root + '/root/config.yml')
        with self.assertRaises(IOError):
            self.ch = Config(self.root + '/config', self.oh, self.gh)
            self.ch.search_config()

    def test_create_yaml(self):
        """create_yaml writes all twelve settings, in order, with the
        expected default values."""
        self.ch.search_config()
        test_config = self.root + '/root/test_config.yml'
        self.ch.config = test_config
        # Register the removal before writing so the temporary file is
        # deleted even if create_yaml or an assertion fails (previously it
        # was only removed when every assertion passed).
        self.addCleanup(self._remove_if_exists, test_config)
        self.ch.create_yaml()
        with open(test_config, 'r') as stream:
            yml = yaml.safe_load(stream)
        # Settings are checked by position, relying on the mapping order
        # preserved by the YAML loader.
        values = list(yml.values())
        expected = [
            self.root + '/root',
            self.root + '/root/occurrences',
            self.root + '/root/results',
            dict(zip(self.oh.name, self.oh.path)),
            dict(zip(self.gh.names, self.gh.variables)),
            42,
            2000,
            75,
            150,
            [250, 200, 150, 100],
            [0.3, 0.5, 0.3, 0.5],
            True,
        ]
        for index, expected_value in enumerate(expected):
            self.assertEqual(values[index], expected_value)

    def test_read_yaml(self):
        """read_yaml loads paths, handler lists and hyper-parameters."""
        self.ch.search_config()
        self.ch.read_yaml()
        self.assertEqual(self.ch.data_path, self.root + '/root')
        self.assertEqual(self.ch.occ_path, self.root + '/root/occurrences')
        self.assertEqual(self.ch.result_path, self.root + '/root/results')
        self.assertEqual(self.ch.oh.name,
                         list(dict(zip(self.oh.name, self.oh.path)).keys()))
        self.assertEqual(self.ch.oh.path,
                         list(dict(zip(self.oh.name, self.oh.path)).values()))
        self.assertEqual(
            self.ch.gh.names,
            list(dict(zip(self.gh.names, self.gh.variables)).keys()))
        self.assertEqual(
            self.ch.gh.variables,
            list(dict(zip(self.gh.names, self.gh.variables)).values()))
        self.assertEqual(self.ch.random_seed, 42)
        self.assertEqual(self.ch.pseudo_freq, 2000)
        self.assertEqual(self.ch.batchsize, 75)
        self.assertEqual(self.ch.epoch, 150)
        self.assertEqual(self.ch.model_layers, [250, 200, 150, 100])
        self.assertEqual(self.ch.model_dropout, [0.3, 0.5, 0.3, 0.5])
        self.assertEqual(self.ch.verbose, True)