Beispiel #1
0
 def test_train(self, mocked_do_eval):
     """`train` should leave one zarr cache, one model file, and eval once.

     ``mocked_do_eval`` is presumably injected by a ``mock.patch`` decorator
     outside this view; only its ``call_count`` is inspected here.
     """
     # Build argv as a list so a config path containing whitespace stays a
     # single argument (splitting a formatted string would break it apart).
     cli.main(["train", str(self.config_file)])
     cache_dirs = list(self.conf.dir.glob("zarrcache_*"))
     models = list(self.conf.dir.glob("*.hdf5"))
     self.assertEqual(len(cache_dirs), 1, cache_dirs)
     self.assertEqual(len(models), 1, models)
     self.assertEqual(mocked_do_eval.call_count, 1)
Beispiel #2
0
 def test_apply_creates_predictions(self):
     """`apply` should write the prediction outputs for the model.

     Outputs are expected in the directory named after the model file with
     its ".hdf5" extension removed.
     """
     # argv as a list keeps paths containing whitespace as single arguments.
     cli.main(["apply", str(self.config_file), str(self.nn_hdf5_file)])
     plot_dir = pathlib.Path(self.nn_hdf5_file[:-len(".hdf5")])
     expected = [
         "predictions.txt",
         "predictions.pdf",
     ]
     # Report every missing output at once instead of stopping at the first.
     missing = [name for name in expected if not (plot_dir / name).exists()]
     self.assertEqual(missing, [], f"outputs missing from {plot_dir}")
Beispiel #3
0
 def test_eval_creates_plots(self):
     """`eval` should write the expected PDF plots for the model.

     Plots are expected in the directory named after the model file with
     its ".hdf5" extension removed.
     """
     # argv as a list keeps paths containing whitespace as single arguments.
     cli.main(["eval", str(self.config_file), str(self.nn_hdf5_file)])
     plot_dir = pathlib.Path(self.nn_hdf5_file[:-len(".hdf5")])
     expected = [
         "genotype_matrices.pdf",
         "roc.pdf",
         "accuracy.pdf",
         "confusion.pdf",
         "reliability.pdf",
     ]
     # Report every missing plot at once instead of stopping at the first.
     missing = [name for name in expected if not (plot_dir / name).exists()]
     self.assertEqual(missing, [], f"plots missing from {plot_dir}")
Beispiel #4
0
    def setUpClass(cls):
        """Create a temp working dir with a patched config and fixtures.

        The example config ``examples/test-example.toml`` is loaded, its
        ``dir`` is pointed at a fresh temporary directory and
        ``vcf.phased`` is set from ``cls.phased``; the result is written to
        ``config.toml`` in that directory. If ``cls.need_sim_data`` is set,
        data is simulated (with ``genomatnn.sim._models`` patched), and if
        ``cls.need_trained_model`` is also set, a model is trained and its
        path stored in ``cls.nn_hdf5_file``.

        Raises:
            RuntimeError: training did not produce exactly one ``*.hdf5``.
        """
        cls.temp_dir = tempfile.TemporaryDirectory()
        try:
            cls.config_file = pathlib.Path(cls.temp_dir.name) / "config.toml"
            # load the example toml, patch it, then write it back out
            d = toml.load("examples/test-example.toml")
            d["dir"] = cls.temp_dir.name
            d["vcf"]["phased"] = cls.phased
            # Explicit UTF-8 so the file does not depend on the locale.
            with open(cls.config_file, "w", encoding="utf-8") as f:
                toml.dump(d, f)
            cls.conf = config.Config(cls.config_file)

            if cls.need_sim_data:
                with mock.patch("genomatnn.sim._models", new=sim__models()):
                    # argv as a list keeps a config path containing
                    # whitespace as a single argument.
                    cli.main(
                        ["sim", "-n", "10", "--seed", "1", str(cls.config_file)]
                    )
                    if cls.need_trained_model:
                        cli.main(["train", "--seed", "1", str(cls.config_file)])
                        models = list(cls.conf.dir.glob("*.hdf5"))
                        # Explicit check: a bare assert is stripped under -O.
                        if len(models) != 1:
                            raise RuntimeError(
                                f"expected exactly one trained model, "
                                f"found {models}"
                            )
                        cls.nn_hdf5_file = str(models[0])
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises,
            # so remove the temporary directory here instead of leaking it.
            cls.temp_dir.cleanup()
            raise
Beispiel #5
0
 def test_sim(self):
     """`sim -n 2` should write 2 * 3 ``.trees`` files under the temp dir.

     NOTE(review): the factor of 3 presumably reflects how many simulation
     scenarios the example config defines — TODO confirm.
     """
     # argv as a list keeps a config path containing whitespace intact.
     cli.main(["sim", "-n", "2", str(self.config_file)])
     path = pathlib.Path(self.temp_dir.name)
     ts_files = list(path.glob("**/*.trees"))
     self.assertEqual(len(ts_files), 2 * 3)
Beispiel #6
0
 def test_missing_config_file(self):
     """Running `sim` on a config path that does not exist must fail.

     The CLI is expected to surface the missing file as FileNotFoundError.
     """
     argv = ["sim", "nonexistent.toml"]
     self.assertRaises(FileNotFoundError, cli.main, argv)
Beispiel #7
0
 def test_vcfplot_creates_plot(self):
     """`vcfplot` should write the requested plot file into the temp dir."""
     plot_file = pathlib.Path(self.temp_dir.name) / "plot.pdf"
     # Paths are passed as single argv entries so whitespace in them is
     # preserved; the region is still split on whitespace as before.
     cli.main(
         ["vcfplot", str(self.config_file), str(plot_file),
          *str(self.region).split()])
     self.assertTrue(plot_file.exists(), f"{plot_file} was not created")
Beispiel #8
0
 def test_missing_config_file(self):
     """`vcfplot` with a nonexistent config path raises FileNotFoundError."""
     plot_file = pathlib.Path(self.temp_dir.name) / "plot.pdf"
     with self.assertRaises(FileNotFoundError):
         # The plot path is a single argv entry so whitespace is preserved;
         # the region is still split on whitespace as before.
         cli.main(
             ["vcfplot", "nonexistent.toml", str(plot_file),
              *str(self.region).split()])
Beispiel #9
0
 def test_missing_nn_hdf5_file(self):
     """`apply` with a nonexistent model file should raise an OSError.

     OSError covers FileNotFoundError as well as other file-open failures;
     in Python 3, IOError is merely an alias of OSError.
     """
     with self.assertRaises(OSError):
         # argv as a list keeps a config path containing whitespace intact.
         cli.main(["apply", str(self.config_file), "nonexistent.hdf5"])
Beispiel #10
0
 def test_missing_config_file(self):
     """`apply` with a nonexistent config path raises FileNotFoundError."""
     with self.assertRaises(FileNotFoundError):
         # argv as a list keeps a model path containing whitespace intact.
         cli.main(["apply", "nonexistent.toml", str(self.nn_hdf5_file)])