Exemplo n.º 1
0
class TestMySQLdbBasic3(unittest.TestCase):
    """Tests for mysqldb_basic introspection helpers using an engine bound to DB.

    DB is created filled and DB2 is created empty once for the whole class,
    so the helpers can be checked both against the database the engine is
    connected to (DB) and against a different one (DB2).
    """

    @classmethod
    def setUpClass(cls):
        """Create the filled DB and the empty DB2 once for all tests."""
        test_db_utils.create_filled_test_db()
        test_db_utils.create_empty_test_db(db=DB2)

    @classmethod
    def tearDownClass(cls):
        """Remove both test databases after all tests have run."""
        test_db_utils.remove_db()
        test_db_utils.remove_db(db=DB2)

    def setUp(self):
        """Reset the version rows and build an engine connected to DB."""
        # DB and DB2 get different version values (1 and 0).
        test_db_utils.execute("UPDATE version SET Version = 1")
        test_db_utils.execute("UPDATE version SET Version = 0", db=DB2)

        self.alchemist = AlchemyHandler(database=DB, username=USER, password=PWD)
        self.alchemist.build_engine()
        self.engine = self.alchemist.engine

    def tearDown(self):
        """Dispose of the engine built in setUp."""
        self.engine.dispose()

    def test_get_mysql_dbs_2(self):
        """Verify set of databases is retrieved when engine
        is connected to a specific database."""
        databases = mysqldb_basic.get_mysql_dbs(self.engine)
        self.assertIn(DB, databases)

    def test_get_tables_2(self):
        """Verify set of tables is retrieved when engine
        is connected to the same database."""
        tables = mysqldb_basic.get_tables(self.engine, DB)
        self.assertIn(TABLE, tables)

    def test_get_tables_3(self):
        """Verify set of tables is retrieved when engine
        is connected to a different database."""
        tables = mysqldb_basic.get_tables(self.engine, DB2)
        self.assertIn(TABLE, tables)

    def test_get_columns_2(self):
        """Verify set of columns is retrieved when engine
        is connected to the same database."""
        columns = mysqldb_basic.get_columns(self.engine, DB, TABLE)
        self.assertIn(COLUMN, columns)

    def test_get_columns_3(self):
        """Verify set of columns is retrieved when engine
        is connected to a different database."""
        columns = mysqldb_basic.get_columns(self.engine, DB2, TABLE)
        self.assertIn(COLUMN, columns)
Exemplo n.º 2
0
class TestPhamerationFunctions(unittest.TestCase):
    """Integration tests for the phamerate pipeline's helper functions.

    Every test runs against a freshly created, filled test database
    (created in ``setUp`` and removed in ``tearDown``). Tests that exercise
    clustering write their intermediate files into ``self.temp_dir``.
    """

    def setUp(self):
        # Create test database that contains data for several phages.
        test_db_utils.create_filled_test_db()

        self.alchemist = AlchemyHandler(database=DB,
                                        username=USER,
                                        password=PWD)
        self.alchemist.build_engine()
        self.engine = self.alchemist.engine
        self.temp_dir = "/tmp/pdm_utils_tests_phamerate"

    def tearDown(self):
        self.engine.dispose()
        test_db_utils.remove_db()

        # Clustering with blastclust can leave an error.log in the current
        # working directory; remove it so test runs leave no stray files.
        err_file = Path.cwd().joinpath("error.log")
        if err_file.exists():
            print("Found leftover blastclust file... removing")
            err_file.unlink()

    def _write_input_fasta(self):
        """Refresh temp_dir and write the translation mapping as FASTA input.

        Returns:
            The translation -> GeneIDs mapping that was written.
        """
        refresh_tempdir(self.temp_dir)
        ts_to_gs = map_translations_to_geneids(self.engine)
        write_fasta(ts_to_gs, self.temp_dir)
        return ts_to_gs

    def test_1_get_pham_geneids(self):
        """Verify we get back a dictionary"""
        old_phams = get_pham_geneids(self.engine)
        self.assertIsInstance(old_phams, dict)

    def test_2_get_pham_colors(self):
        """Verify we get back a dictionary"""
        old_colors = get_pham_colors(self.engine)
        self.assertIsInstance(old_colors, dict)

    def test_3_get_pham_geneids_and_colors(self):
        """Verify both dictionaries have the same keys"""
        old_phams = get_pham_geneids(self.engine)
        old_colors = get_pham_colors(self.engine)

        # Can't have same keys without the same number of keys...
        with self.subTest():
            self.assertEqual(len(old_phams), len(old_colors))

        # ...and the key sets themselves must match exactly.
        with self.subTest():
            self.assertEqual(set(old_phams), set(old_colors))

    def test_4_get_unphamerated_genes(self):
        """Verify we get back a set of length 0"""
        unphamerated = get_new_geneids(self.engine)
        with self.subTest():
            self.assertIsInstance(unphamerated, set)
        # pdm_test_db has 0 unphamerated genes
        with self.subTest():
            self.assertEqual(len(unphamerated), 0)

    def test_5_map_geneids_to_translations(self):
        """Verify we get back a dictionary"""
        gs_to_ts = map_geneids_to_translations(self.engine)

        command = "SELECT distinct(GeneID) FROM gene"
        results = mysqldb_basic.query_dict_list(self.engine, command)

        with self.subTest():
            self.assertIsInstance(gs_to_ts, dict)
        # gs_to_ts should have one entry per distinct GeneID
        with self.subTest():
            self.assertEqual(len(gs_to_ts), len(results))

    def test_6_map_translations_to_geneids(self):
        """Verify we get back a dictionary"""
        ts_to_gs = map_translations_to_geneids(self.engine)

        command = "SELECT distinct(CONVERT(Translation USING utf8)) FROM gene"
        results = mysqldb_basic.query_dict_list(self.engine, command)

        with self.subTest():
            self.assertIsInstance(ts_to_gs, dict)
        # ts_to_gs should have one entry per distinct translation
        with self.subTest():
            self.assertEqual(len(ts_to_gs), len(results))

    def test_7_refresh_tempdir_1(self):
        """Verify if no temp_dir, refresh can make one"""
        if not os.path.exists(self.temp_dir):
            refresh_tempdir(self.temp_dir)
        self.assertTrue(os.path.exists(self.temp_dir))

    def test_8_refresh_tempdir_2(self):
        """Verify if temp_dir with something, refresh makes new empty one"""
        filename = f"{self.temp_dir}/test.txt"
        if not os.path.exists(self.temp_dir):
            refresh_tempdir(self.temp_dir)
        with open(filename, "w") as fh:
            fh.write("test\n")

        # Our test file should now exist
        with self.subTest():
            self.assertTrue(os.path.exists(filename))

        # Refresh temp_dir
        refresh_tempdir(self.temp_dir)

        # temp_dir should now exist, but test file should not
        with self.subTest():
            self.assertTrue(os.path.exists(self.temp_dir))
        with self.subTest():
            self.assertFalse(os.path.exists(filename))

    def test_9_write_fasta(self):
        """Verify file gets written properly"""
        filename = f"{self.temp_dir}/input.fasta"
        ts_to_gs = self._write_input_fasta()

        # Read fasta, make sure number of lines is 2x number of unique translations
        with open(filename, "r") as fh:
            lines = fh.readlines()

        with self.subTest():
            self.assertEqual(len(lines), 2 * len(ts_to_gs))

        # Every odd-index line should be one of the translations (the keys
        # of ts_to_gs).
        for index, line in enumerate(lines):
            if index % 2 == 1:
                with self.subTest(line_index=index):
                    self.assertIn(line.lstrip(">").rstrip(), ts_to_gs)

    # TODO: comment out this method if you don't have blast_2.2.14 binaries
    def test_10_create_blastdb(self):
        """Verify blastclust database gets made"""
        db_file = f"{self.temp_dir}/sequenceDB"
        self._write_input_fasta()

        create_clusterdb("blast", self.temp_dir)

        # Check that database files were made
        for ext in ["phr", "pin", "psd", "psi", "psq"]:
            with self.subTest(ext=ext):
                self.assertTrue(os.path.exists(f"{db_file}.{ext}"))

    def test_11_create_mmseqsdb(self):
        """Verify mmseqs database gets made"""
        db_file = f"{self.temp_dir}/sequenceDB"
        self._write_input_fasta()

        create_clusterdb("mmseqs", self.temp_dir)

        # Check that database file was made
        self.assertTrue(os.path.exists(db_file))

    def test_12_create_clusterdb(self):
        """Verify no database file gets made"""
        db_file = f"{self.temp_dir}/sequenceDB"
        self._write_input_fasta()

        create_clusterdb("unknown", self.temp_dir)

        # Check that database file was not made
        self.assertFalse(os.path.exists(db_file))

    # TODO: comment out this method if you don't have blast_2.2.14 binaries
    def test_13_phamerate_blast(self):
        """Verify we can phamerate with blastclust"""
        self._write_input_fasta()

        create_clusterdb("blast", self.temp_dir)
        phamerate(get_program_params("blast"), "blast", self.temp_dir)

        # Make sure clustering output file exists
        self.assertTrue(os.path.exists(f"{self.temp_dir}/output.txt"))

    def test_14_phamerate_mmseqs(self):
        """Verify we can phamerate with mmseqs2"""
        self._write_input_fasta()

        create_clusterdb("mmseqs", self.temp_dir)
        phamerate(get_program_params("mmseqs"), "mmseqs", self.temp_dir)

        # Make sure clustering output file exists
        self.assertTrue(os.path.exists(f"{self.temp_dir}/clusterDB.index"))

    def test_15_phamerate_unknown(self):
        """Verify we cannot phamerate with unknown"""
        self._write_input_fasta()

        create_clusterdb("unknown", self.temp_dir)
        phamerate(get_program_params("unknown"), "unknown", self.temp_dir)

        # Make sure clustering output file does not exist
        self.assertFalse(os.path.exists(f"{self.temp_dir}/clusterDB"))

    # TODO: comment out this method if you don't have blast_2.2.14 binaries
    def test_16_parse_blast_output(self):
        """Verify we can open and parse blastclust output"""
        ts_to_gs = self._write_input_fasta()

        create_clusterdb("blast", self.temp_dir)
        phamerate(get_program_params("blast"), "blast", self.temp_dir)

        phams = parse_output("blast", self.temp_dir)

        # The number of phams should be greater than 0 and less than or equal to
        # the number of distinct translations
        with self.subTest():
            self.assertIsInstance(phams, dict)
        with self.subTest():
            self.assertGreater(len(phams), 0)
        with self.subTest():
            self.assertLessEqual(len(phams), len(ts_to_gs))

    def test_17_parse_mmseqs_output(self):
        """Verify we can open and parse MMseqs2 output"""
        ts_to_gs = self._write_input_fasta()

        create_clusterdb("mmseqs", self.temp_dir)
        phamerate(get_program_params("mmseqs"), "mmseqs", self.temp_dir)

        phams = parse_output("mmseqs", self.temp_dir)

        # The number of phams should be greater than 0 and less than or equal to
        # the number of distinct translations
        with self.subTest():
            self.assertIsInstance(phams, dict)
        with self.subTest():
            self.assertGreater(len(phams), 0)
        with self.subTest():
            self.assertLessEqual(len(phams), len(ts_to_gs))

    def test_18_parse_unknown_output(self):
        """Verify we cannot open and parse unknown output"""
        self._write_input_fasta()

        create_clusterdb("unknown", self.temp_dir)
        phamerate(get_program_params("unknown"), "unknown", self.temp_dir)

        phams = parse_output("unknown", self.temp_dir)

        # An unknown program produces no output to parse, so we expect an
        # empty dictionary of phams.
        with self.subTest():
            self.assertIsInstance(phams, dict)
        with self.subTest():
            self.assertEqual(len(phams), 0)

    def test_19_reintroduce_duplicates(self):
        """Verify that we can put de-duplicated GeneIDs back together"""
        gs_to_ts = map_geneids_to_translations(self.engine)
        ts_to_gs = self._write_input_fasta()

        create_clusterdb("mmseqs", self.temp_dir)
        phamerate(get_program_params("mmseqs"), "mmseqs", self.temp_dir)

        new_phams = parse_output("mmseqs", self.temp_dir)

        re_duped_phams = reintroduce_duplicates(new_phams=new_phams,
                                                trans_groups=ts_to_gs,
                                                genes_and_trans=gs_to_ts)

        geneid_total = sum(len(geneids) for geneids in re_duped_phams.values())

        # All geneids should be represented in the re_duped_phams
        self.assertEqual(geneid_total, len(gs_to_ts))

    def test_20_preserve_phams(self):
        """Verify that pham preservation seems to be working"""
        old_phams = get_pham_geneids(self.engine)
        old_colors = get_pham_colors(self.engine)
        unphamerated = get_new_geneids(self.engine)

        gs_to_ts = map_geneids_to_translations(self.engine)
        ts_to_gs = self._write_input_fasta()

        create_clusterdb("mmseqs", self.temp_dir)
        phamerate(get_program_params("mmseqs"), "mmseqs", self.temp_dir)

        new_phams = parse_output("mmseqs", self.temp_dir)

        new_phams = reintroduce_duplicates(new_phams=new_phams,
                                           trans_groups=ts_to_gs,
                                           genes_and_trans=gs_to_ts)

        final_phams, new_colors = preserve_phams(old_phams=old_phams,
                                                 new_phams=new_phams,
                                                 old_colors=old_colors,
                                                 new_genes=unphamerated)

        # Final phams should be a dict with same number of keys as new_phams
        # since we aren't re-dimensioning, just renaming some keys
        with self.subTest():
            self.assertIsInstance(final_phams, dict)
        with self.subTest():
            self.assertEqual(len(final_phams), len(new_phams))

        # New colors should be a dict with the same number of keys as
        # final_phams
        with self.subTest():
            self.assertIsInstance(new_colors, dict)
        with self.subTest():
            self.assertEqual(len(new_colors), len(final_phams))

        # Can't compare the keys or phams since there's no guarantee that
        # any of the phams were preserved but we can make sure all genes are
        # accounted for
        genes_1_count = len(unphamerated) + sum(
            len(geneids) for geneids in old_phams.values())
        genes_2_count = sum(len(geneids) for geneids in new_phams.values())
        with self.subTest():
            self.assertEqual(genes_1_count, genes_2_count)
Exemplo n.º 3
0
class TestUpdate(unittest.TestCase):
    """Integration tests for the update_field pipeline's ``run.main``.

    The filled test database is created once for the class; ``setUp`` resets
    selected phage fields and the version to known values before each test,
    and ``AlchemyHandler`` inside the pipeline is patched to return the
    handler built here.
    """

    @classmethod
    def setUpClass(cls):
        """Create the filled test database once for all tests."""
        test_db_utils.create_filled_test_db()

    @classmethod
    def tearDownClass(cls):
        """Remove the test database after all tests have run."""
        test_db_utils.remove_db()

    def setUp(self):
        self.alchemist = AlchemyHandler(database=DB, username=USER, password=PWD)
        self.alchemist.build_engine()
        test_folder.mkdir()

        # Standardize values in certain fields to define the data
        statements = [
            create_update("phage", "Status", "unknown"),
            create_update("phage", "HostGenus", "unknown"),
            create_update("phage", "Accession", ""),
            create_update("phage", "Cluster", "Z"),
            create_update("phage", "Subcluster", "Z1"),
            "UPDATE version SET Version = 0",
        ]
        for statement in statements:
            test_db_utils.execute(statement)

    def tearDown(self):
        try:
            shutil.rmtree(test_folder)
        finally:
            # Dispose of the engine built in setUp (it was never released
            # before), as the other test classes do with their engines.
            self.alchemist.engine.dispose()

    @patch("pdm_utils.pipelines.update_field.AlchemyHandler")
    def test_main_1(self, alchemy_mock):
        """Verify update runs with empty ticket table."""
        alchemy_mock.return_value = self.alchemist
        create_update_table([], update_table)
        unparsed_args = get_unparsed_args(file=update_table)
        run.main(unparsed_args)
        version_table = test_db_utils.get_data(test_db_utils.version_table_query)
        phage_table = test_db_utils.get_data(test_db_utils.phage_table_query)
        data_dict = phage_id_dict(phage_table)
        alice = data_dict["Alice"]
        trixie = data_dict["Trixie"]
        # Nothing should be different.
        with self.subTest():
            self.assertEqual(alice["HostGenus"], "unknown")
        with self.subTest():
            self.assertEqual(trixie["HostGenus"], "unknown")
        with self.subTest():
            self.assertEqual(version_table[0]["Version"], 0)

    @patch("pdm_utils.pipelines.update_field.AlchemyHandler")
    def test_main_2(self, alchemy_mock):
        """Verify update runs with five tickets in ticket table."""
        alchemy_mock.return_value = self.alchemist
        host_genus = "Mycobacterium"
        cluster = "A"
        subcluster = "A2"
        status = "final"
        accession = "ABC123"
        tkts = [
            get_alice_ticket("HostGenus", host_genus),
            get_alice_ticket("Cluster", cluster),
            get_alice_ticket("Subcluster", subcluster),
            get_alice_ticket("Status", status),
            get_alice_ticket("Accession", accession),
        ]
        create_update_table(tkts, update_table)
        unparsed_args = get_unparsed_args(file=update_table)
        run.main(unparsed_args)
        version_table = test_db_utils.get_data(test_db_utils.version_table_query)
        phage_table = test_db_utils.get_data(test_db_utils.phage_table_query)
        data_dict = phage_id_dict(phage_table)
        alice = data_dict["Alice"]
        trixie = data_dict["Trixie"]
        with self.subTest():
            self.assertEqual(alice["HostGenus"], host_genus)
        with self.subTest():
            self.assertEqual(alice["Cluster"], cluster)
        with self.subTest():
            self.assertEqual(alice["Subcluster"], subcluster)
        with self.subTest():
            self.assertEqual(alice["Accession"], accession)
        with self.subTest():
            self.assertEqual(alice["Status"], status)
        # Just confirm that only Alice data was changed.
        with self.subTest():
            self.assertEqual(trixie["HostGenus"], "unknown")
        with self.subTest():
            self.assertEqual(version_table[0]["Version"], 0)

    @patch("pdm_utils.pipelines.update_field.AlchemyHandler")
    def test_main_3(self, alchemy_mock):
        """Verify version data is updated."""
        alchemy_mock.return_value = self.alchemist
        unparsed_args = get_unparsed_args(version=True)
        run.main(unparsed_args)
        version_table = test_db_utils.get_data(test_db_utils.version_table_query)
        phage_table = test_db_utils.get_data(test_db_utils.phage_table_query)
        data_dict = phage_id_dict(phage_table)
        alice = data_dict["Alice"]
        with self.subTest():
            self.assertEqual(version_table[0]["Version"], 1)
        # Just confirm that only version data was changed.
        with self.subTest():
            self.assertEqual(alice["HostGenus"], "unknown")

    @patch("pdm_utils.pipelines.update_field.AlchemyHandler")
    def test_main_4(self, alchemy_mock):
        """Verify version data and phage table data are updated."""
        alchemy_mock.return_value = self.alchemist
        host_genus = "Mycobacterium"
        tkt = get_alice_ticket("HostGenus", host_genus)
        create_update_table([tkt], update_table)
        unparsed_args = get_unparsed_args(file=update_table, version=True)
        run.main(unparsed_args)
        version_table = test_db_utils.get_data(test_db_utils.version_table_query)
        phage_table = test_db_utils.get_data(test_db_utils.phage_table_query)
        data_dict = phage_id_dict(phage_table)
        alice = data_dict["Alice"]
        with self.subTest():
            self.assertEqual(alice["HostGenus"], host_genus)
        with self.subTest():
            self.assertEqual(version_table[0]["Version"], 1)
Exemplo n.º 4
0
class TestMysqldbBasic6(unittest.TestCase):
    """Tests for get_distinct and retrieve_data on populated and empty DBs.

    DB is populated once for the class with three phages (L5, Trixie, D29)
    and four genes (three for Trixie, one for D29); DB2 is left empty.
    """

    @classmethod
    def setUpClass(cls):
        """Create both databases and insert the phage and gene fixtures."""
        test_db_utils.create_empty_test_db(db=DB2)
        test_db_utils.create_empty_test_db()

        # Per-phage overrides applied on top of the Trixie template data.
        # NOTE(review): "NULL" appears to be stored as SQL NULL, judging by
        # the null= expectations in test_get_distinct_2 - confirm in
        # insert_data.
        phage_overrides = [
            {"PhageID": "L5", "HostGenus": "Mycobacterium",
             "Accession": "ABC123", "Cluster": "A", "Subcluster": "A1",
             "Sequence": "atcg", "Length": 6},
            {"PhageID": "Trixie", "HostGenus": "Mycobacterium",
             "Accession": "XYZ456", "Cluster": "B", "Subcluster": "NULL",
             "Sequence": "AATT", "Length": 4},
            {"PhageID": "D29", "HostGenus": "Gordonia",
             "Accession": "", "Cluster": "NULL", "Subcluster": "NULL",
             "Sequence": "GGCC", "Length": 5},
        ]
        for overrides in phage_overrides:
            phage_data = test_data_utils.get_trixie_phage_data()
            phage_data.update(overrides)
            phage_data["DateLastModified"] = constants.EMPTY_DATE
            test_db_utils.insert_data(PHAGE, phage_data)

        # (PhageID, GeneID) pairs for the gene fixtures.
        gene_ids = [("Trixie", "Trixie_1"), ("Trixie", "Trixie_2"),
                    ("Trixie", "Trixie_3"), ("D29", "D29_1")]
        for phage_id, gene_id in gene_ids:
            gene_data = test_data_utils.get_trixie_gene_data()
            gene_data["PhageID"] = phage_id
            gene_data["GeneID"] = gene_id
            test_db_utils.insert_data(GENE, gene_data)

    @classmethod
    def tearDownClass(cls):
        """Remove both test databases after all tests have run."""
        test_db_utils.remove_db()
        test_db_utils.remove_db(db=DB2)

    def setUp(self):
        """Build one engine for the populated DB and one for the empty DB2."""
        self.alchemist1 = AlchemyHandler(database=DB, username=USER, password=PWD)
        self.alchemist1.build_engine()
        self.engine1 = self.alchemist1.engine

        self.alchemist2 = AlchemyHandler(database=DB2, username=USER, password=PWD)
        self.alchemist2.build_engine()
        self.engine2 = self.alchemist2.engine

    def tearDown(self):
        """Dispose of both engines built in setUp."""
        self.engine1.dispose()
        self.engine2.dispose()

    def test_get_distinct_1(self):
        """Retrieve a set of all distinct values when data is not present."""
        result = mysqldb_basic.get_distinct(self.engine2, "phage", "PhageID")
        self.assertEqual(result, set())

    def test_get_distinct_2(self):
        """Retrieve a set of all distinct values when data is present."""
        result1 = mysqldb_basic.get_distinct(
                        self.engine1, "phage", "PhageID")
        result2 = mysqldb_basic.get_distinct(
                        self.engine1, "phage", "HostGenus", null="test")
        result3 = mysqldb_basic.get_distinct(
                        self.engine1, "phage", "Accession")
        result4 = mysqldb_basic.get_distinct(
                        self.engine1, "phage", "Cluster", null="Singleton")
        result5 = mysqldb_basic.get_distinct(
                        self.engine1, "phage", "Subcluster", null="none")

        exp_phage_id = {"L5", "Trixie", "D29"}
        exp_host_genus = {"Mycobacterium", "Gordonia"}
        exp_accession = {"ABC123", "XYZ456", ""}
        exp_cluster = {"A", "B", "Singleton"}
        exp_subcluster = {"A1", "none"}

        with self.subTest():
            self.assertEqual(result1, exp_phage_id)
        with self.subTest():
            self.assertEqual(result2, exp_host_genus)
        with self.subTest():
            self.assertEqual(result3, exp_accession)
        with self.subTest():
            self.assertEqual(result4, exp_cluster)
        with self.subTest():
            self.assertEqual(result5, exp_subcluster)

    def test_retrieve_data_1(self):
        """Verify that a dictionary of data is retrieved for a valid PhageID."""
        result_list = mysqldb_basic.retrieve_data(
                        self.engine1, column="PhageID",
                        id_list=["L5"], query=PHAGE_QUERY)
        with self.subTest():
            self.assertEqual(len(result_list[0].keys()), 14)
        with self.subTest():
            self.assertEqual(result_list[0]["PhageID"], "L5")

    def test_retrieve_data_2(self):
        """Verify that an empty list is retrieved for an invalid PhageID."""
        result_list = mysqldb_basic.retrieve_data(
                        self.engine1, column="PhageID",
                        id_list=["EagleEye"], query=PHAGE_QUERY)
        self.assertEqual(len(result_list), 0)

    def test_retrieve_data_3(self):
        """Verify that dictionaries of data are retrieved for a list of two
        valid PhageIDs."""
        result_list = mysqldb_basic.retrieve_data(
                        self.engine1, column="PhageID",
                        id_list=["L5", "Trixie"], query=PHAGE_QUERY)
        self.assertEqual(len(result_list), 2)

    def test_retrieve_data_4(self):
        """Verify that dictionaries of data are retrieved for a list of three
        valid PhageIDs and one invalid PhageID."""
        result_list = mysqldb_basic.retrieve_data(
                        self.engine1, column="PhageID",
                        id_list=["L5", "Trixie", "EagleEye", "D29"],
                        query=PHAGE_QUERY)
        self.assertEqual(len(result_list), 3)

    def test_retrieve_data_5(self):
        """Verify that dictionaries of data are retrieved for multiple
        valid PhageIDs when no list is provided."""
        result_list = mysqldb_basic.retrieve_data(
                        self.engine1, query=PHAGE_QUERY)
        self.assertEqual(len(result_list), 3)

    def test_retrieve_data_6(self):
        """Verify that a list of CDS data is retrieved for a valid PhageID."""
        result_list = mysqldb_basic.retrieve_data(
                        self.engine1, column="PhageID",
                        id_list=["Trixie"], query=GENE_QUERY)
        with self.subTest():
            self.assertEqual(len(result_list), 3)
        with self.subTest():
            self.assertEqual(len(result_list[0].keys()), 13)
        with self.subTest():
            self.assertEqual(result_list[0]["PhageID"], "Trixie")

    def test_retrieve_data_7(self):
        """Verify that an empty list of CDS data is retrieved
        for a valid PhageID that has no genes."""
        result_list = mysqldb_basic.retrieve_data(
                        self.engine1, column="PhageID",
                        id_list=["L5"], query=GENE_QUERY)
        self.assertEqual(len(result_list), 0)

    def test_retrieve_data_8(self):
        """Verify that a list of all CDS data is retrieved when no
        PhageID is provided."""
        result_list = mysqldb_basic.retrieve_data(
                                self.engine1, query=GENE_QUERY)
        self.assertEqual(len(result_list), 4)
Exemplo n.º 5
0
class TestMysqldbBasic1(unittest.TestCase):
    """Tests for mysqldb_basic database create/drop helpers.

    The engine is built without a specific database. Tests may drop DB or
    leave it without tables, so ``setUp`` restores it before each test, and
    DB2 is removed after every test that may have created it.
    """

    @classmethod
    def setUpClass(cls):
        """Create the empty test database once for all tests."""
        test_db_utils.create_empty_test_db()

    @classmethod
    def tearDownClass(cls):
        """Remove whichever test databases still exist."""
        if test_db_utils.check_if_exists():
            test_db_utils.remove_db()

        if test_db_utils.check_if_exists(db=DB2):
            test_db_utils.remove_db(db=DB2)

    def setUp(self):
        # A previous test may have dropped DB (test_drop_db_1) or recreated
        # it without tables (test_drop_create_db_1); restore DB with its
        # schema so every test starts from the same state.
        if not test_db_utils.check_if_exists():
            test_db_utils.create_empty_test_db()
        elif not test_db_utils.execute(TABLES_QUERY.format(DB)):
            test_db_utils.install_db(test_db_utils.SCHEMA_FILEPATH)

        self.alchemist = AlchemyHandler(username=USER, password=PWD)
        self.alchemist.build_engine()
        self.engine = self.alchemist.engine

    def tearDown(self):
        if test_db_utils.check_if_exists(db=DB2):
            test_db_utils.remove_db(db=DB2)

        self.engine.dispose()

    def test_drop_db_1(self):
        """Verify existing database is dropped."""
        before = test_db_utils.check_if_exists()
        result = mysqldb_basic.drop_db(self.engine, DB)
        after = test_db_utils.check_if_exists()
        with self.subTest():
            self.assertTrue(before)
        with self.subTest():
            self.assertFalse(after)
        with self.subTest():
            self.assertEqual(result, 0)

    def test_drop_db_2(self):
        """Verify non-existing database is not dropped."""
        before = test_db_utils.check_if_exists(db=DB2)
        result = mysqldb_basic.drop_db(self.engine, DB2)
        after = test_db_utils.check_if_exists(db=DB2)
        with self.subTest():
            self.assertFalse(before)
        with self.subTest():
            self.assertFalse(after)
        with self.subTest():
            self.assertEqual(result, 1)

    def test_create_db_1(self):
        """Verify new database is created."""
        before = test_db_utils.check_if_exists(db=DB2)
        result = mysqldb_basic.create_db(self.engine, DB2)
        after = test_db_utils.check_if_exists(db=DB2)
        with self.subTest():
            self.assertFalse(before)
        with self.subTest():
            self.assertTrue(after)
        with self.subTest():
            self.assertEqual(result, 0)

    def test_create_db_2(self):
        """Verify already-existing database is not created."""
        before = test_db_utils.check_if_exists()
        result = mysqldb_basic.create_db(self.engine, DB)
        after = test_db_utils.check_if_exists()
        with self.subTest():
            self.assertTrue(before)
        with self.subTest():
            self.assertTrue(after)
        with self.subTest():
            self.assertEqual(result, 1)

    def test_drop_create_db_1(self):
        """Verify already-existing database is dropped and created."""
        before = test_db_utils.check_if_exists()
        before_tables = test_db_utils.execute(TABLES_QUERY.format(DB))
        result = mysqldb_basic.drop_create_db(self.engine, DB)
        after = test_db_utils.check_if_exists()
        after_tables = test_db_utils.execute(TABLES_QUERY.format(DB))
        with self.subTest():
            self.assertTrue(before)
        with self.subTest():
            self.assertTrue(after)
        with self.subTest():
            self.assertEqual(result, 0)
        with self.subTest():
            self.assertGreater(len(before_tables), 0)
        with self.subTest():
            self.assertEqual(len(after_tables), 0)

    def test_drop_create_db_2(self):
        """Verify non-existing database is not dropped but is created."""
        before = test_db_utils.check_if_exists(db=DB2)
        result = mysqldb_basic.drop_create_db(self.engine, DB2)
        after = test_db_utils.check_if_exists(db=DB2)
        after_tables = test_db_utils.execute(TABLES_QUERY.format(DB2), db=DB2)
        with self.subTest():
            self.assertFalse(before)
        with self.subTest():
            self.assertTrue(after)
        with self.subTest():
            self.assertEqual(result, 0)
        with self.subTest():
            self.assertEqual(len(after_tables), 0)

    @patch("pdm_utils.functions.mysqldb_basic.drop_db")
    def test_drop_create_db_3(self, drop_mock):
        """Verify database is not created if there is an error during drop."""
        drop_mock.return_value = 1
        before = test_db_utils.check_if_exists()
        before_tables = test_db_utils.execute(TABLES_QUERY.format(DB))

        result = mysqldb_basic.drop_create_db(self.engine, DB)
        after = test_db_utils.check_if_exists()
        after_tables = test_db_utils.execute(TABLES_QUERY.format(DB))

        with self.subTest():
            self.assertTrue(before)
        with self.subTest():
            self.assertTrue(after)
        with self.subTest():
            self.assertEqual(result, 1)
        with self.subTest():
            self.assertGreater(len(before_tables), 0)
        with self.subTest():
            self.assertGreater(len(after_tables), 0)
        with self.subTest():
            self.assertEqual(len(before_tables), len(after_tables))
        with self.subTest():
            drop_mock.assert_called()
Exemplo n.º 6
0
class TestMysqldbBasic5(unittest.TestCase):
    """Tests for row-level query helpers in mysqldb_basic.

    An empty test database is created once for the class. tearDown()
    deletes every phage row, so each test starts with an empty phage table.
    """

    @classmethod
    def setUpClass(cls):
        test_db_utils.create_empty_test_db()

    @classmethod
    def tearDownClass(cls):
        test_db_utils.remove_db()

    def setUp(self):
        self.alchemist = AlchemyHandler(database=DB, username=USER, password=PWD)
        self.alchemist.build_engine()
        self.engine = self.alchemist.engine

    def tearDown(self):
        try:
            # Clear rows inserted by the test so the next test starts empty.
            test_db_utils.execute("DELETE FROM phage")
        finally:
            # Always release the connection pool, even if the DELETE fails.
            self.engine.dispose()

    def _insert_two_phages(self):
        """Insert two phage rows, 'Trixie' and 'L5', built from Trixie data."""
        for phage_id in ("Trixie", "L5"):
            phage_data = test_data_utils.get_trixie_phage_data()
            phage_data["PhageID"] = phage_id
            test_db_utils.insert_data(PHAGE, phage_data)

    def test_get_table_count_1(self):
        """Verify the correct number of phages is returned when
        the database is empty."""
        count = mysqldb_basic.get_table_count(self.engine, TABLE)
        self.assertEqual(count, 0)

    def test_get_table_count_2(self):
        """Verify the correct number of phages is returned when
        the database contains one genome."""
        phage_data = test_data_utils.get_trixie_phage_data()
        test_db_utils.insert_data(PHAGE, phage_data)
        count = mysqldb_basic.get_table_count(self.engine, TABLE)
        self.assertEqual(count, 1)

    def test_get_first_row_data_1(self):
        """Verify empty dictionary is returned when there is no data."""
        data = mysqldb_basic.get_first_row_data(self.engine, TABLE)
        self.assertEqual(len(data.keys()), 0)

    def test_get_first_row_data_2(self):
        """Verify dictionary is returned when there is one row of data."""
        phage_data = test_data_utils.get_trixie_phage_data()
        test_db_utils.insert_data(PHAGE, phage_data)
        data = mysqldb_basic.get_first_row_data(self.engine, TABLE)
        self.assertIn(COLUMN, data.keys())

    def test_get_first_row_data_3(self):
        """Verify dictionary is returned when there are two rows of data."""
        self._insert_two_phages()
        # Get all data from table just to confirm there is more than one row.
        all_data = test_db_utils.get_data(test_db_utils.phage_table_query)
        data = mysqldb_basic.get_first_row_data(self.engine, TABLE)
        with self.subTest():
            self.assertEqual(len(all_data), 2)
        with self.subTest():
            self.assertIn(COLUMN, data.keys())

    def test_first_1(self):
        """Verify dictionary is returned when there are two rows of data."""
        self._insert_two_phages()
        data = mysqldb_basic.first(self.engine, PHAGE_QUERY, return_dict=True)
        self.assertIn(COLUMN, data.keys())

    def test_first_2(self):
        """Verify tuple is returned when there are two rows of data."""
        self._insert_two_phages()
        data = mysqldb_basic.first(self.engine, PHAGE_QUERY, return_dict=False)
        with self.subTest():
            self.assertIsInstance(data, tuple)
        with self.subTest():
            self.assertGreater(len(data), 1)

    def test_scalar_1(self):
        """Verify the scalar row count is returned when there are two
        rows of data."""
        self._insert_two_phages()
        count = mysqldb_basic.scalar(self.engine, COUNT_QUERY)
        self.assertEqual(count, 2)
Exemplo n.º 7
0
class TestMysqldbBasic4(unittest.TestCase):
    """Tests for copying and installing whole databases.

    Every test gets fresh databases: DB is filled with Version = 1 and DB2 is
    empty with Version = 0. Both are removed afterwards because these tests
    modify database contents.
    """

    def setUp(self):
        test_db_utils.create_filled_test_db()
        test_db_utils.create_empty_test_db(db=DB2)

        # Give the two databases distinct versions so a copy is detectable.
        stmt1 = "UPDATE version SET Version = 1"
        test_db_utils.execute(stmt1)
        stmt2 = "UPDATE version SET Version = 0"
        test_db_utils.execute(stmt2, db=DB2)

        self.alchemist = AlchemyHandler(database=DB, username=USER, password=PWD)
        self.alchemist.build_engine()
        self.engine = self.alchemist.engine

    def tearDown(self):
        self.engine.dispose()
        test_db_utils.remove_db()
        test_db_utils.remove_db(db=DB2)

    def test_copy_db_1(self):
        """Verify data from database1 is copied to database2."""
        before_v1 = test_db_utils.get_data(test_db_utils.version_table_query)
        before_v2 = test_db_utils.get_data(test_db_utils.version_table_query, db=DB2)
        result = mysqldb_basic.copy_db(self.engine, DB2)
        after_v1 = test_db_utils.get_data(test_db_utils.version_table_query)
        after_v2 = test_db_utils.get_data(test_db_utils.version_table_query, db=DB2)
        with self.subTest():
            self.assertNotEqual(before_v1[0]["Version"], before_v2[0]["Version"])
        with self.subTest():
            self.assertEqual(result, 0)
        with self.subTest():
            self.assertEqual(after_v1[0]["Version"], after_v2[0]["Version"])

    def test_copy_db_2(self):
        """Verify no data is copied since databases are the same."""
        before_v1 = test_db_utils.get_data(test_db_utils.version_table_query)
        result = mysqldb_basic.copy_db(self.engine, DB)
        after_v1 = test_db_utils.get_data(test_db_utils.version_table_query)
        with self.subTest():
            self.assertEqual(before_v1[0]["Version"], after_v1[0]["Version"])
        with self.subTest():
            self.assertEqual(result, 0)

    def test_copy_db_3(self):
        """Verify no data is copied since new database does not exist."""
        result = mysqldb_basic.copy_db(self.engine, DB3)
        self.assertEqual(result, 1)

    @patch("pdm_utils.functions.mysqldb_basic.pipe_commands")
    def test_copy_db_4(self, pc_mock):
        """Verify no data is copied if an error is encountered during copying."""
        # Raise an error instead of calling pipe_commands() so that
        # the exception block is entered.
        pc_mock.side_effect = ValueError("Error raised")
        result = mysqldb_basic.copy_db(self.engine, DB2)
        self.assertEqual(result, 1)

    def test_install_db_1(self):
        """Verify new database is installed."""
        stmt1 = "UPDATE version SET Version = 0"
        test_db_utils.execute(stmt1)
        before = test_db_utils.get_data(test_db_utils.version_table_query)
        result = mysqldb_basic.install_db(self.engine, test_db_utils.TEST_DB_FILEPATH)
        after = test_db_utils.get_data(test_db_utils.version_table_query)
        with self.subTest():
            self.assertNotEqual(before[0]["Version"], after[0]["Version"])
        with self.subTest():
            self.assertGreater(after[0]["Version"], 0)
        with self.subTest():
            self.assertEqual(result, 0)

    @patch("subprocess.check_call")
    def test_install_db_2(self, cc_mock):
        """Verify new database is not installed."""
        # Raise an error instead of calling check_call() so that
        # the exception block is entered.
        cc_mock.side_effect = ValueError("Error raised")
        stmt1 = "UPDATE version SET Version = 0"
        test_db_utils.execute(stmt1)
        before = test_db_utils.get_data(test_db_utils.version_table_query)
        result = mysqldb_basic.install_db(self.engine, test_db_utils.TEST_DB_FILEPATH)
        after = test_db_utils.get_data(test_db_utils.version_table_query)
        with self.subTest():
            self.assertEqual(before[0]["Version"], after[0]["Version"])
        with self.subTest():
            self.assertEqual(after[0]["Version"], 0)
        with self.subTest():
            self.assertEqual(result, 1)
Exemplo n.º 8
0
class TestGetGBRecords(unittest.TestCase):
    """Integration tests for the get_gb_records pipeline's main()."""

    @classmethod
    def setUpClass(cls):
        test_db_utils.create_filled_test_db()

    @classmethod
    def tearDownClass(cls):
        test_db_utils.remove_db()

    def setUp(self):
        test_folder.mkdir()
        self.alchemist = AlchemyHandler(database=DB,
                                        username=USER,
                                        password=PWD)
        self.alchemist.build_engine()
        # Keep a handle on the engine built here so tearDown() can release
        # its connection pool, as the other test classes in this file do.
        self.engine = self.alchemist.engine

        # Standardize values in certain fields to define the data.
        stmts = [
            create_update("phage", "Status", "draft"),
            create_update("phage", "HostGenus", "Mycobacterium"),
            create_update("phage", "Accession", ""),
            create_update("gene", "Notes", "repressor"),
            "UPDATE version SET Version = 1",
        ]
        for stmt in stmts:
            test_db_utils.execute(stmt)
        self.unparsed_args = get_unparsed_args()

    def tearDown(self):
        try:
            self.engine.dispose()
        finally:
            # Always remove the output folder so the next setUp() can
            # recreate it.
            shutil.rmtree(test_folder)

    @patch("pdm_utils.pipelines.get_gb_records"
           ".pipelines_basic.build_alchemist")
    def test_main_1(self, alchemy_mock):
        """Verify no GenBank record is retrieved."""
        alchemy_mock.return_value = self.alchemist
        run.main(self.unparsed_args)
        count = count_files(results_path)
        with self.subTest():
            self.assertTrue(results_path.exists())
        with self.subTest():
            self.assertEqual(count, 0)

    @patch("pdm_utils.pipelines.get_gb_records"
           ".pipelines_basic.build_alchemist")
    def test_main_2(self, alchemy_mock):
        """Verify one GenBank record is retrieved."""
        alchemy_mock.return_value = self.alchemist
        stmt = create_update("phage", "Accession", TRIXIE_ACC, "Trixie")
        test_db_utils.execute(stmt)
        run.main(self.unparsed_args)
        count = count_files(results_path)
        with self.subTest():
            self.assertTrue(results_path.exists())
        with self.subTest():
            self.assertEqual(count, 1)

    @patch("pdm_utils.pipelines.get_gb_records"
           ".pipelines_basic.build_alchemist")
    def test_main_3(self, alchemy_mock):
        """Verify no GenBank record is retrieved based on one filter."""
        alchemy_mock.return_value = self.alchemist
        stmt = create_update("phage", "Accession", TRIXIE_ACC, "Trixie")
        test_db_utils.execute(stmt)
        self.unparsed_args.extend(["-f", f"phage.Accession!={TRIXIE_ACC}"])
        run.main(self.unparsed_args)
        count = count_files(results_path)
        with self.subTest():
            self.assertTrue(results_path.exists())
        with self.subTest():
            self.assertEqual(count, 0)

    @patch("pdm_utils.pipelines.get_gb_records"
           ".pipelines_basic.build_alchemist")
    def test_main_4(self, alchemy_mock):
        """Verify one GenBank record is retrieved based on one filter."""
        alchemy_mock.return_value = self.alchemist
        stmt1 = create_update("phage", "Accession", TRIXIE_ACC, "Trixie")
        test_db_utils.execute(stmt1)
        stmt2 = create_update("phage", "Status", "final", "Trixie")
        test_db_utils.execute(stmt2)
        self.unparsed_args.extend(["-f", "phage.Status!=draft"])
        run.main(self.unparsed_args)
        count = count_files(results_path)
        with self.subTest():
            self.assertTrue(results_path.exists())
        with self.subTest():
            self.assertEqual(count, 1)
Exemplo n.º 9
0
class TestAlchemyHandler(unittest.TestCase):
    """Integration tests for AlchemyHandler against the filled test database."""

    @classmethod
    def setUpClass(cls):
        test_db_utils.create_filled_test_db()

    @classmethod
    def tearDownClass(cls):
        test_db_utils.remove_db()

    def setUp(self):
        self.alchemist = AlchemyHandler()

    def tearDown(self):
        # Release any engine a test built so pooled MySQL connections do not
        # accumulate across tests. The private ``_engine`` is read rather than
        # the ``engine`` property, which may build an engine on demand.
        engine = self.alchemist._engine
        if engine is not None:
            engine.dispose()

    def connect_to_pdm_test_db(self):
        """Sets alchemist credentials and database to connect to pdm_test_db.
        """
        self.alchemist.username = user
        self.alchemist.password = pwd
        self.alchemist.database = db

    def test_validate_database_1(self):
        """Verify validate_database() detects good database access.
        """
        self.alchemist.username = user
        self.alchemist.password = pwd
        self.alchemist.build_engine()

        self.alchemist.database = db
        # Passes as long as no exception is raised.
        self.alchemist.validate_database()

    def test_validate_database_2(self):
        """Verify validate_database() detects bad database access.
        """
        self.alchemist.username = user
        self.alchemist.password = pwd
        self.alchemist.build_engine()

        self.alchemist.database = "not_database"
        with self.assertRaises(MySQLDatabaseError):
            self.alchemist.validate_database()

    def test_build_engine_1(self):
        """Verify build_engine() creates and stores Engine object.
        """
        self.alchemist.username = user
        self.alchemist.password = pwd
        self.alchemist.build_engine()

        self.assertIsInstance(self.alchemist.engine, Engine)
        self.assertTrue(self.alchemist.connected)

    def test_build_engine_2(self):
        """Verify build_engine() connects to database if has_database.
        """
        self.alchemist.username = user
        self.alchemist.password = pwd
        self.alchemist.database = db

        self.alchemist.build_engine()

        self.assertTrue(self.alchemist.connected_database)

    def test_build_metadata_1(self):
        """Verify build_metadata() creates and stores MetaData and Engine.

        Accessing the ``metadata`` property is expected to build it; the
        other lazily-built attributes must remain unset.
        """
        self.connect_to_pdm_test_db()

        self.assertIsInstance(self.alchemist.metadata, MetaData)
        self.assertIsInstance(self.alchemist.engine, Engine)

        self.assertIsNone(self.alchemist._graph)
        self.assertIsNone(self.alchemist._session)
        self.assertIsNone(self.alchemist._mapper)

    def test_build_graph_1(self):
        """Verify build_graph() creates and stores, Graph, MetaData, and Engine.
        """
        self.connect_to_pdm_test_db()

        self.assertIsInstance(self.alchemist.graph, Graph)
        self.assertIsInstance(self.alchemist.metadata, MetaData)
        self.assertIsInstance(self.alchemist.engine, Engine)

        self.assertIsNone(self.alchemist._session)
        self.assertIsNone(self.alchemist._mapper)

    def test_build_session_1(self):
        """Verify build_session() creates and stores Engine and Session
        """
        self.connect_to_pdm_test_db()

        self.assertIsInstance(self.alchemist.session, Session)
        self.assertIsInstance(self.alchemist.engine, Engine)

        self.assertIsNone(self.alchemist._graph)
        self.assertIsNone(self.alchemist._metadata)
        self.assertIsNone(self.alchemist._mapper)

    def test_build_mapper_1(self):
        """Verify build_mapper() creates and stores, Mapper, MetaData, and Engine.
        """
        self.connect_to_pdm_test_db()

        self.assertIsInstance(self.alchemist.mapper, DeclarativeMeta)
        self.assertIsInstance(self.alchemist.metadata, MetaData)
        self.assertIsInstance(self.alchemist.engine, Engine)

        self.assertIsNone(self.alchemist._session)
        self.assertIsNone(self.alchemist._graph)

    def test_engine_1(self):
        """Verify AlchemyHandler extracts credentials from engine.
        """
        engine = create_engine(
            self.alchemist.construct_engine_string(username=user,
                                                   password=pwd,
                                                   database=db))

        self.alchemist.engine = engine

        self.assertEqual(self.alchemist.username, user)
        self.assertEqual(self.alchemist.password, pwd)
        self.assertEqual(self.alchemist.database, db)
Exemplo n.º 10
0
class TestMysqldbBasic1(unittest.TestCase):
    """Tests for database-level helpers in mysqldb_basic.

    The engine is built without a default database. Every test starts with a
    freshly filled DB. tearDown() removes DB and DB2 if a test left them
    behind.
    """

    def setUp(self):
        test_db_utils.create_filled_test_db()
        self.alchemist = AlchemyHandler(username=USER, password=PWD)
        self.alchemist.build_engine()
        self.engine = self.alchemist.engine

    def tearDown(self):
        self.engine.dispose()
        # Tests may drop or create databases, so only remove those that
        # still exist.
        # Remove 'pdm_test_db'
        if test_db_utils.check_if_exists():
            test_db_utils.remove_db()

        # Remove 'pdm_test_2'
        if test_db_utils.check_if_exists(db=DB2):
            test_db_utils.remove_db(db=DB2)

    def test_drop_db_1(self):
        """Verify existing database is dropped."""
        before = test_db_utils.check_if_exists()
        result = mysqldb_basic.drop_db(self.engine, DB)
        after = test_db_utils.check_if_exists()
        with self.subTest():
            self.assertTrue(before)
        with self.subTest():
            self.assertFalse(after)
        with self.subTest():
            self.assertEqual(result, 0)

    def test_drop_db_2(self):
        """Verify non-existing database is not dropped."""
        before = test_db_utils.check_if_exists(db=DB2)
        result = mysqldb_basic.drop_db(self.engine, DB2)
        after = test_db_utils.check_if_exists(db=DB2)
        with self.subTest():
            self.assertFalse(before)
        with self.subTest():
            self.assertFalse(after)
        with self.subTest():
            self.assertEqual(result, 1)

    def test_create_db_1(self):
        """Verify new database is created."""
        before = test_db_utils.check_if_exists(db=DB2)
        result = mysqldb_basic.create_db(self.engine, DB2)
        after = test_db_utils.check_if_exists(db=DB2)
        with self.subTest():
            self.assertFalse(before)
        with self.subTest():
            self.assertTrue(after)
        with self.subTest():
            self.assertEqual(result, 0)

    def test_create_db_2(self):
        """Verify already-existing database is not created."""
        before = test_db_utils.check_if_exists()
        result = mysqldb_basic.create_db(self.engine, DB)
        after = test_db_utils.check_if_exists()
        with self.subTest():
            self.assertTrue(before)
        with self.subTest():
            self.assertTrue(after)
        with self.subTest():
            self.assertEqual(result, 1)

    def test_drop_create_db_1(self):
        """Verify already-existing database is dropped and created."""
        before = test_db_utils.check_if_exists()
        before_tables = test_db_utils.execute(TABLES_QUERY.format(DB))
        result = mysqldb_basic.drop_create_db(self.engine, DB)
        after = test_db_utils.check_if_exists()
        after_tables = test_db_utils.execute(TABLES_QUERY.format(DB))
        with self.subTest():
            self.assertTrue(before)
        with self.subTest():
            self.assertTrue(after)
        with self.subTest():
            self.assertEqual(result, 0)
        with self.subTest():
            self.assertGreater(len(before_tables), 0)
        with self.subTest():
            self.assertEqual(len(after_tables), 0)

    def test_drop_create_db_2(self):
        """Verify non-existing database is not dropped but is created."""
        before = test_db_utils.check_if_exists(db=DB2)
        result = mysqldb_basic.drop_create_db(self.engine, DB2)
        after = test_db_utils.check_if_exists(db=DB2)
        after_tables = test_db_utils.execute(TABLES_QUERY.format(DB2), db=DB2)
        with self.subTest():
            self.assertFalse(before)
        with self.subTest():
            self.assertTrue(after)
        with self.subTest():
            self.assertEqual(result, 0)
        with self.subTest():
            self.assertEqual(len(after_tables), 0)

    @patch("pdm_utils.functions.mysqldb_basic.drop_db")
    def test_drop_create_db_3(self, drop_mock):
        """Verify database is not created if there is an error during drop."""
        # Simulate a failed drop so the create step must be skipped.
        drop_mock.return_value = 1
        before = test_db_utils.check_if_exists()
        before_tables = test_db_utils.execute(TABLES_QUERY.format(DB))

        result = mysqldb_basic.drop_create_db(self.engine, DB)
        after = test_db_utils.check_if_exists()
        after_tables = test_db_utils.execute(TABLES_QUERY.format(DB))
        with self.subTest():
            self.assertTrue(before)
        with self.subTest():
            self.assertTrue(after)
        with self.subTest():
            self.assertEqual(result, 1)
        with self.subTest():
            self.assertGreater(len(before_tables), 0)
        with self.subTest():
            self.assertGreater(len(after_tables), 0)
        with self.subTest():
            self.assertEqual(len(before_tables), len(after_tables))
        with self.subTest():
            drop_mock.assert_called()

    def test_get_mysql_dbs_1(self):
        """Verify set of databases is retrieved when engine
        is not connected to a specific database."""
        databases = mysqldb_basic.get_mysql_dbs(self.engine)
        self.assertIn(DB, databases)

    def test_get_tables_1(self):
        """Verify set of tables is retrieved when engine
        is not connected to a specific database."""
        tables = mysqldb_basic.get_tables(self.engine, DB)
        self.assertIn(TABLE, tables)

    def test_get_columns_1(self):
        """Verify set of columns is retrieved when engine
        is not connected to a specific database."""
        columns = mysqldb_basic.get_columns(self.engine, DB, TABLE)
        self.assertIn(COLUMN, columns)
Exemplo n.º 11
0
class TestCompare(unittest.TestCase):
    """Integration tests for the compare_db pipeline.

    PhagesDB access is mocked in each main() test because the real pipeline
    downloads and parses all sequenced genome data.
    """

    @classmethod
    def setUpClass(cls):
        test_db_utils.create_filled_test_db()

    @classmethod
    def tearDownClass(cls):
        # Remove 'pdm_test_db'
        test_db_utils.remove_db()

    def setUp(self):
        test_folder.mkdir()

        # Standardize values in certain fields to define the data.
        stmts = [
            create_update("phage", "Status", "draft"),
            create_update("phage", "Accession", ""),
            create_update("phage", "AnnotationAuthor", "0"),
        ]
        # NOTE(review): D29 receives TRIXIE_ACC here, so every test starts
        # with a duplicate accession; confirm this is intended.
        accessions = (("Trixie", TRIXIE_ACC),
                      ("Alice", ALICE_ACC),
                      ("L5", L5_ACC),
                      ("D29", TRIXIE_ACC))
        stmts.extend(create_update("phage", "Accession", accession, phage_id)
                     for phage_id, accession in accessions)
        final_ids = [phage_id for phage_id, _ in accessions]
        stmts.extend(create_update("phage", "Status", "final", phage_id)
                     for phage_id in final_ids)
        stmts.extend(create_update("phage", "AnnotationAuthor", "1", phage_id)
                     for phage_id in final_ids)

        for stmt in stmts:
            test_db_utils.execute(stmt)

        self.unparsed_args = get_unparsed_args()

        self.alchemist = AlchemyHandler(database=DB,
                                        username=USER,
                                        password=PWD)
        self.alchemist.build_engine()
        # Keep the engine built here so tearDown() can release its pool.
        self.engine = self.alchemist.engine

        self.pdb_data1 = get_pdb_dict()
        self.pdb_data2 = get_pdb_dict()
        self.pdb_data3 = get_pdb_dict()
        self.pdb_data1["phage_name"] = "Trixie"
        self.pdb_data2["phage_name"] = "L5"
        self.pdb_data3["phage_name"] = "unmatched"

        json_results = [self.pdb_data1, self.pdb_data2, self.pdb_data3]
        self.pdb_json_data = get_pdb_json_data()
        self.pdb_json_data["results"] = json_results
        self.pdb_json_results = json_results

    def tearDown(self):
        try:
            self.engine.dispose()
        finally:
            # Always remove the output folder so the next setUp() can
            # recreate it.
            shutil.rmtree(test_folder)

    def test_queries_1(self):
        """Verify hard-coded SQL queries are structured correctly."""

        version_data = test_db_utils.get_data(compare_db.VERSION_QUERY)
        phage_data = test_db_utils.get_data(compare_db.PHAGE_QUERY)
        gene_data = test_db_utils.get_data(compare_db.GENE_QUERY)

        ref_phage_data = test_db_utils.get_data(
            test_db_utils.phage_table_query)
        ref_gene_data = test_db_utils.get_data(test_db_utils.gene_table_query)

        with self.subTest():
            self.assertEqual(len(version_data), 1)
        with self.subTest():
            self.assertEqual(len(phage_data), len(ref_phage_data))
        with self.subTest():
            self.assertEqual(len(gene_data), len(ref_gene_data))

    # Calls to PhagesDB need to be mocked, since the pipeline downloads
    # and parses all sequenced genome data.
    @patch("pdm_utils.functions.phagesdb.retrieve_url_data")
    @patch("pdm_utils.functions.phagesdb.get_phagesdb_data")
    @patch("pdm_utils.pipelines.compare_db.AlchemyHandler")
    def test_main_1(self, alchemy_mock, gpd_mock, rud_mock):
        """Verify compare runs successfully with:
        MySQL, PhagesDB, and GenBank records saved,
        a duplicate MySQL accession (for D29 and Trixie),
        an invalid accession (for L5),
        a duplicate phage.Name (for Constance and Et2Brutus),
        a PhagesDB name unmatched to MySQL (for 'unmatched')."""
        alchemy_mock.return_value = self.alchemist
        gpd_mock.return_value = self.pdb_json_results
        rud_mock.return_value = FASTA_FILE

        # Make modifications to cause errors.
        stmts = [
            create_update("phage", "Accession", L5_ACC + "1", "L5"),
            create_update("phage", "Accession", TRIXIE_ACC, "D29"),

            create_update("phage", "Name", "Dupe", "Constance"),
            create_update("phage", "Name", "Dupe", "Et2Brutus"),

            create_update("phage", "PhageID", "Dupe", "Constance"),
            create_update("phage", "PhageID", "Dupe_Draft", "Et2Brutus"),
        ]
        for stmt in stmts:
            test_db_utils.execute(stmt)

        run.main(self.unparsed_args)
        count = count_files(test_folder)
        with self.subTest():
            self.assertGreater(count, 0)
        with self.subTest():
            gpd_mock.assert_called()
        with self.subTest():
            rud_mock.assert_called()

    # Calls to PhagesDB need to be mocked, since the pipeline downloads
    # and parses all sequenced genome data.
    @patch("pdm_utils.pipelines.compare_db.get_pdb_data")
    @patch("pdm_utils.pipelines.compare_db.AlchemyHandler")
    def test_main_2(self, alchemy_mock, gpd_mock):
        """Verify duplicate PhagesDB names are identified."""
        # Clear accessions so that GenBank is not queried. No need for that.
        stmt = create_update("phage", "Accession", "")
        test_db_utils.execute(stmt)

        alchemy_mock.return_value = self.alchemist
        gpd_mock.return_value = ({}, {"L5"}, {"D29"})
        run.main(self.unparsed_args)
        count = count_files(test_folder)
        with self.subTest():
            self.assertGreater(count, 0)
        with self.subTest():
            gpd_mock.assert_called()

    # Calls to PhagesDB need to be mocked, since the pipeline downloads
    # and parses all sequenced genome data.
    @patch("pdm_utils.pipelines.compare_db.get_pdb_data")
    @patch("pdm_utils.pipelines.compare_db.filter_mysql_genomes")
    @patch("pdm_utils.pipelines.compare_db.AlchemyHandler")
    def test_main_3(self, alchemy_mock, fmg_mock, gpd_mock):
        """Verify duplicate MySQL names are identified."""
        # Clear accessions so that GenBank is not queried. No need for that.
        stmt = create_update("phage", "Accession", "")
        test_db_utils.execute(stmt)

        alchemy_mock.return_value = self.alchemist
        fmg_mock.return_value = ({}, set(), {"L5"}, set(), set())
        gpd_mock.return_value = ({}, set(), set())

        run.main(self.unparsed_args)
        count = count_files(test_folder)
        with self.subTest():
            self.assertGreater(count, 0)
        with self.subTest():
            fmg_mock.assert_called()
        with self.subTest():
            gpd_mock.assert_called()
Exemplo n.º 12
0
class TestAlchemyHandler(unittest.TestCase):
    """Unit tests for AlchemyHandler with its collaborators mocked out."""

    def setUp(self):
        """Create a fresh, argument-less AlchemyHandler for each test."""
        self.alchemist = AlchemyHandler()

    def test_constructor_1(self):
        """Verify database, username, and password start as None."""
        self.assertEqual(self.alchemist._database, None)
        self.assertEqual(self.alchemist._username, None)
        self.assertEqual(self.alchemist._password, None)

    def test_constructor_2(self):
        """Verify engine, metadata, graph, and session start as None."""
        self.assertEqual(self.alchemist._engine, None)
        self.assertEqual(self.alchemist.metadata, None)
        self.assertEqual(self.alchemist.graph, None)
        self.assertEqual(self.alchemist.session, None)

    def test_constructor_3(self):
        """Verify connected, has_database, and has_credentials start False."""
        self.assertFalse(self.alchemist.connected)
        self.assertFalse(self.alchemist.has_database)
        self.assertFalse(self.alchemist.has_credentials)

    def test_database_1(self):
        """Verify setting database sets has_database only."""
        self.alchemist.database = "Test"
        self.assertTrue(self.alchemist.has_database)
        self.assertFalse(self.alchemist.connected_database)

    def test_username_1(self):
        """Verify a lone username sets neither has_credentials nor connected."""
        self.alchemist.username = "******"
        self.assertFalse(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    def test_username_2(self):
        """Verify a username after a password sets has_credentials."""
        self.alchemist.password = "******"
        self.alchemist.username = "******"
        self.assertTrue(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    def test_password_1(self):
        """Verify a lone password sets neither has_credentials nor connected."""
        self.alchemist.password = "******"
        self.assertFalse(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    def test_password_2(self):
        """Verify a password after a username sets has_credentials."""
        self.alchemist.username = "******"
        self.alchemist.password = "******"
        self.assertTrue(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    def test_engine_1(self):
        """Verify assigning None to engine clears connected."""
        self.alchemist.connected = True
        self.alchemist.engine = None

        self.assertFalse(self.alchemist.connected)

    def test_engine_2(self):
        """Verify assigning a string to engine raises TypeError."""
        with self.assertRaises(TypeError):
            self.alchemist.engine = "Test"

    @patch("pdm_utils.classes.alchemyhandler.input")
    def test_ask_database_1(self, Input):
        """Verify ask_database() calls input()."""
        self.alchemist.ask_database()
        Input.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.input")
    def test_ask_database_2(self, Input):
        """Verify ask_database() sets has_database and clears connected."""
        self.alchemist.has_database = False
        self.alchemist.connected = True

        self.alchemist.ask_database()

        self.assertTrue(self.alchemist.has_database)
        self.assertFalse(self.alchemist.connected)

    @patch("pdm_utils.classes.alchemyhandler.getpass")
    def test_ask_credentials_1(self, GetPass):
        """Verify ask_credentials() calls getpass()."""
        self.alchemist.ask_credentials()

        GetPass.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.getpass")
    def test_ask_credentials_2(self, GetPass):
        """Verify ask_credentials() sets has_credentials, clears connected."""
        self.alchemist.has_credentials = False
        self.alchemist.connected = True

        self.alchemist.ask_credentials()

        self.assertTrue(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    def test_validate_database_1(self):
        """Verify validate_database() runs SHOW DATABASES and fetches rows."""
        MockEngine = Mock()
        MockProxy = Mock()

        MockEngine.execute.return_value = MockProxy
        MockProxy.fetchall.return_value = [("test_db",),
                                           ("Actinobacteriophage",)]

        self.alchemist.database = "test_db"
        self.alchemist._engine = MockEngine

        self.alchemist.validate_database()

        MockEngine.execute.assert_called_with("SHOW DATABASES")
        MockProxy.fetchall.assert_called()

    def test_validate_database_2(self):
        """Verify validate_database() raises IndexError on a fresh handler."""
        with self.assertRaises(IndexError):
            self.alchemist.validate_database()

    def test_validate_database_3(self):
        """Verify validate_database() raises ValueError for an unlisted name."""
        MockEngine = Mock()
        MockProxy = Mock()

        MockEngine.execute.return_value = MockProxy
        MockProxy.fetchall.return_value = []

        self.alchemist.database = "test db"
        self.alchemist._engine = MockEngine

        with self.assertRaises(ValueError):
            self.alchemist.validate_database()

        MockEngine.execute.assert_called_with("SHOW DATABASES")
        MockProxy.fetchall.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_1(self, CreateEngine, AskCredentials):
        """Verify build_engine() returns early when already connected."""
        self.alchemist.engine = None
        self.alchemist.connected = True
        self.alchemist.build_engine()

        CreateEngine.assert_not_called()
        AskCredentials.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_2(self, CreateEngine, AskCredentials):
        """Verify build_engine() asks for credentials, then creates an engine."""
        # The credentials must match those embedded in login_string below;
        # a masked placeholder such as "******" could never match it.
        self.alchemist.username = "user"
        self.alchemist.password = "pass"
        self.alchemist.has_credentials = False

        self.alchemist.build_engine()

        AskCredentials.assert_called()
        login_string = "mysql+pymysql://user:pass@localhost/"
        CreateEngine.assert_called_with(login_string)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.validate_database")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_3(self, CreateEngine, ValidateDatabase):
        """Verify build_engine() creates server and database engines."""
        # The credentials must match those embedded in the login strings
        # below; a masked placeholder such as "******" could never match.
        self.alchemist.username = "user"
        self.alchemist.password = "pass"
        self.alchemist.database = "database"

        self.alchemist.build_engine()

        login_string = "mysql+pymysql://user:pass@localhost/"
        db_login_string = "mysql+pymysql://user:pass@localhost/database"

        CreateEngine.assert_any_call(login_string)
        CreateEngine.assert_any_call(db_login_string)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_4(self, CreateEngine, AskCredentials):
        """Verify build_engine() connects and resets metadata and graph."""
        self.alchemist.has_credentials = True
        self.alchemist.connected = False
        self.alchemist.metadata = "Test"
        self.alchemist.graph = "Test"

        self.alchemist.build_engine()

        # Assert the post-condition; a plain assignment here checks nothing.
        self.assertTrue(self.alchemist.connected)
        self.assertEqual(self.alchemist.metadata, None)
        self.assertEqual(self.alchemist.graph, None)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_1(self, BuildEngine, AskDatabase, AskCredentials):
        """Verify connect() only builds the engine when credentials exist."""
        self.alchemist.has_credentials = True
        self.alchemist.connect()
        BuildEngine.assert_called()
        AskDatabase.assert_not_called()
        AskCredentials.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_2(self, BuildEngine, AskDatabase, AskCredentials):
        """Verify connect(ask_database=True) asks for database and login."""
        self.alchemist.connect(ask_database=True)
        BuildEngine.assert_called()
        AskDatabase.assert_called()
        AskCredentials.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_3(self, BuildEngine, AskDatabase, AskCredentials):
        """Verify connect() raises ValueError when build_engine() fails."""
        self.alchemist.connected = False
        BuildEngine.side_effect = OperationalError("", "", "")

        with self.assertRaises(ValueError):
            self.alchemist.connect()

        BuildEngine.assert_called()
        AskDatabase.assert_not_called()
        AskCredentials.assert_called()

    def mock_build_engine(self, mock_engine):
        """Store mock_engine as the handler's engine (test helper)."""
        self.alchemist._engine = mock_engine

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_execute_1(self, BuildEngine):
        """Verify execute() uses the existing engine and fetches all rows."""
        MockEngine = Mock()
        MockProxy = Mock()

        MockEngine.execute.return_value = MockProxy
        MockProxy.fetchall.return_value = []

        self.alchemist._engine = MockEngine

        self.alchemist.execute("Executable")

        MockEngine.execute.assert_called_with("Executable")
        MockProxy.fetchall.assert_called()
        BuildEngine.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_scalar_1(self, BuildEngine):
        """Verify scalar() uses the existing engine and calls scalar()."""
        MockEngine = Mock()
        MockProxy = Mock()

        MockEngine.execute.return_value = MockProxy
        MockProxy.scalar.return_value = "Scalar"

        self.alchemist._engine = MockEngine

        self.alchemist.scalar("Executable")

        MockEngine.execute.assert_called_with("Executable")
        MockProxy.scalar.assert_called()
        BuildEngine.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.MetaData")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_metadata_1(self, AskDatabase, BuildEngine, MetaData):
        """Verify build_metadata() asks for a database and builds an engine."""
        self.alchemist.has_database = False
        self.alchemist.connected = False

        self.alchemist.build_metadata()

        AskDatabase.assert_called()
        BuildEngine.assert_called()
        MetaData.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.MetaData")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_metadata_2(self, AskDatabase, BuildEngine, MetaData):
        """Verify build_metadata() skips prompts when ready and connected."""
        self.alchemist.has_database = True
        self.alchemist.connected = True

        self.alchemist.build_metadata()

        AskDatabase.assert_not_called()
        BuildEngine.assert_not_called()
        MetaData.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.parsing.translate_table")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_translate_table_1(self, BuildMetadata, TranslateTable):
        """Verify translate_table() reuses existing metadata."""
        self.alchemist.metadata = "Metadata"

        self.alchemist.translate_table("Test")

        TranslateTable.assert_called_with("Metadata", "Test")
        BuildMetadata.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.parsing.translate_table")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_translate_table_2(self, BuildMetadata, TranslateTable):
        """Verify translate_table() builds metadata when it is missing."""
        self.alchemist.metadata = None

        self.alchemist.translate_table("Test")

        TranslateTable.assert_called_with(None, "Test")
        BuildMetadata.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.parsing.translate_column")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_translate_column_1(self, BuildMetadata, TranslateColumn):
        """Verify translate_column() reuses existing metadata."""
        self.alchemist.metadata = "Metadata"

        self.alchemist.translate_column("Test")

        TranslateColumn.assert_called_with("Metadata", "Test")
        BuildMetadata.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.parsing.translate_column")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_translate_column_2(self, BuildMetadata, TranslateColumn):
        """Verify translate_column() builds metadata when it is missing."""
        self.alchemist.metadata = None

        self.alchemist.translate_column("Test")

        TranslateColumn.assert_called_with(None, "Test")
        BuildMetadata.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.querying.get_table")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_table_1(self, BuildMetadata, GetTable):
        """Verify get_table() reuses existing metadata."""
        self.alchemist.metadata = "Metadata"

        self.alchemist.get_table("Test")

        GetTable.assert_called_with("Metadata", "Test")
        BuildMetadata.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.querying.get_table")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_table_2(self, BuildMetadata, GetTable):
        """Verify get_table() builds metadata when it is missing."""
        self.alchemist.metadata = None

        self.alchemist.get_table("Test")

        GetTable.assert_called_with(None, "Test")
        BuildMetadata.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.querying.get_column")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_column_1(self, BuildMetadata, GetColumn):
        """Verify get_column() reuses existing metadata."""
        self.alchemist.metadata = "Metadata"

        self.alchemist.get_column("Test")

        GetColumn.assert_called_with("Metadata", "Test")
        BuildMetadata.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.querying.get_column")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_column_2(self, BuildMetadata, GetColumn):
        """Verify get_column() builds metadata when it is missing."""
        self.alchemist.metadata = None

        self.alchemist.get_column("Test")

        GetColumn.assert_called_with(None, "Test")
        BuildMetadata.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.querying.build_graph")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_graph_1(self, BuildMetadata, BuildGraph):
        """Verify build_graph() reuses metadata and stores the graph."""
        BuildGraph.return_value = "Graph"

        self.alchemist.metadata = "Metadata"

        self.alchemist.build_graph()

        BuildMetadata.assert_not_called()
        BuildGraph.assert_called_with("Metadata")

        self.assertEqual(self.alchemist.graph, "Graph")

    @patch("pdm_utils.classes.alchemyhandler.querying.build_graph")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_graph_2(self, BuildMetadata, BuildGraph):
        """Verify build_graph() builds metadata first when it is missing."""
        BuildGraph.return_value = "Graph"

        self.alchemist.metadata = None

        self.alchemist.build_graph()

        BuildMetadata.assert_called()
        BuildGraph.assert_called_with(None)

        self.assertEqual(self.alchemist.graph, "Graph")

    @patch("pdm_utils.classes.alchemyhandler.cartography.get_map")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_map_1(self, BuildMetadata, GetMap):
        """Verify get_map() reuses existing metadata."""
        self.alchemist.metadata = "Metadata"

        self.alchemist.get_map("Test")

        BuildMetadata.assert_not_called()
        GetMap.assert_called_with("Metadata", "Test")

    @patch("pdm_utils.classes.alchemyhandler.cartography.get_map")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_map_2(self, BuildMetadata, GetMap):
        """Verify get_map() builds metadata when it is missing."""
        self.alchemist.metadata = None

        self.alchemist.get_map("Test")

        BuildMetadata.assert_called()
        GetMap.assert_called_with(None, "Test")
Exemplo n.º 13
0
class TestGetData(unittest.TestCase):
    """Run the get_data pipeline against a filled MySQL test database."""

    @classmethod
    def setUpClass(cls):
        """Create the filled test database once for the whole class."""
        test_db_utils.create_filled_test_db()

    @classmethod
    def tearDownClass(cls):
        """Remove the test database after every test has run."""
        test_db_utils.remove_db()

    def setUp(self):
        """Connect to the test database, create the output folder, and
        standardize phage fields so each test starts from known data."""
        self.alchemist = AlchemyHandler(database=DB, username=USER, password=PWD)
        self.alchemist.build_engine()
        # Keep the engine so tearDown can dispose of its connection pool;
        # otherwise each test leaves a fresh engine behind.
        self.engine = self.alchemist.engine
        test_folder.mkdir()

        # Standardize values in certain fields to define the data
        stmt1 = create_update("phage", "Status", "unknown")
        test_db_utils.execute(stmt1)
        stmt2 = create_update("phage", "HostGenus", "unknown")
        test_db_utils.execute(stmt2)
        stmt3 = create_update("phage", "Accession", "")
        test_db_utils.execute(stmt3)
        stmt4 = create_update("phage", "DateLastModified", "1900-01-01")
        test_db_utils.execute(stmt4)
        stmt5 = create_update("phage", "RetrieveRecord", "0")
        test_db_utils.execute(stmt5)

    def tearDown(self):
        """Release the engine's connections and delete the output folder."""
        self.engine.dispose()
        shutil.rmtree(test_folder)

    @patch("pdm_utils.pipelines.get_data.AlchemyHandler")
    def test_main_1(self, alchemy_mock):
        """Verify update data and final data are retrieved."""
        # Testing the update flag and final flag have been merged so that
        # PhagesDB is only queried once for all data in the genome, since
        # it is time-intensive.
        alchemy_mock.return_value = self.alchemist
        # If final=True, any genome in database will be checked on PhagesDB
        # regardless of AnnotationAuthor
        unparsed_args = get_unparsed_args(update=True, final=True)
        run.main(unparsed_args)
        count1 = count_files(updates_folder)
        count2 = count_files(phagesdb_genome_folder)
        count3 = count_rows(update_table)
        count4 = count_rows(phagesdb_import_table)
        with self.subTest():
            self.assertEqual(count1, 1)
        with self.subTest():
            # It's not clear how stable the storage of any particular final
            # flat file is on PhagesDB. There is possibility that for any
            # genome the associated flat file will be removed. So this is
            # one area of testing that could be improved. For now, simply
            # verify that 1 or more files have been retrieved.
            self.assertGreater(count2, 0)
        with self.subTest():
            # There should be several rows of updates.
            self.assertGreater(count3, 0)
        with self.subTest():
            self.assertEqual(count2, count4)

    @patch("pdm_utils.pipelines.get_data.AlchemyHandler")
    def test_main_2(self, alchemy_mock):
        """Verify genbank data is retrieved."""
        alchemy_mock.return_value = self.alchemist
        stmt1 = create_update("phage", "RetrieveRecord", "1", phage_id="Trixie")
        test_db_utils.execute(stmt1)
        stmt2 = create_update("phage", "Accession", TRIXIE_ACC, phage_id="Trixie")
        test_db_utils.execute(stmt2)
        unparsed_args = get_unparsed_args(genbank=True, genbank_results=True)
        run.main(unparsed_args)
        query = "SELECT COUNT(*) FROM phage"
        count1 = test_db_utils.get_data(query)[0]["COUNT(*)"]
        count2 = count_files(genbank_genomes_folder)
        count3 = count_rows(genbank_import_table)
        count4 = count_rows(genbank_results_table)
        with self.subTest():
            self.assertEqual(count2, 1)
        with self.subTest():
            self.assertEqual(count2, count3)
        with self.subTest():
            self.assertEqual(count1, count4)

    @patch("pdm_utils.pipelines.get_data.match_genomes")
    @patch("pdm_utils.pipelines.get_data.AlchemyHandler")
    def test_main_3(self, alchemy_mock, mg_mock):
        """Verify draft data is retrieved."""
        matched_genomes = create_matched_genomes()
        alchemy_mock.return_value = self.alchemist
        mg_mock.return_value = (matched_genomes, {"EagleEye"})
        unparsed_args = get_unparsed_args(draft=True)
        run.main(unparsed_args)
        count1 = count_files(pecaan_genomes_folder)
        count2 = count_rows(pecaan_import_table)
        with self.subTest():
            self.assertEqual(count1, 1)
        with self.subTest():
            self.assertEqual(count1, count2)

    @patch("pdm_utils.pipelines.get_data.AlchemyHandler")
    def test_main_4(self, alchemy_mock):
        """Verify final data with very recent date are retrieved
        with force_download."""
        alchemy_mock.return_value = self.alchemist
        stmt = create_update("phage", "DateLastModified", "2200-01-01")
        test_db_utils.execute(stmt)
        unparsed_args = get_unparsed_args(final=True, force_download=True)
        run.main(unparsed_args)
        count = count_files(phagesdb_genome_folder)
        with self.subTest():
            # It's not clear how stable the storage of any particular final
            # flat file is on PhagesDB. There is possibility that for any
            # genome the associated flat file will be removed. So this is
            # one area of testing that could be improved. For now, simply
            # verify that 1 or more files have been retrieved.
            self.assertGreater(count, 0)

    @patch("pdm_utils.pipelines.get_data.match_genomes")
    @patch("pdm_utils.pipelines.get_data.AlchemyHandler")
    def test_main_5(self, alchemy_mock, mg_mock):
        """Verify draft data already in database is retrieved
        with force_download."""
        # Create a list of 2 matched genomes, only one of which has
        # status = draft.
        matched_genomes = create_matched_genomes()
        alchemy_mock.return_value = self.alchemist
        mg_mock.return_value = (matched_genomes, {"EagleEye"})
        unparsed_args = get_unparsed_args(draft=True, force_download=True)
        run.main(unparsed_args)
        count = count_files(pecaan_genomes_folder)
        self.assertEqual(count, 2)
Exemplo n.º 14
0
class TestAlchemyHandler(unittest.TestCase):
    def setUp(self):
        """Give each test its own freshly constructed AlchemyHandler.
        """
        handler = AlchemyHandler()
        self.alchemist = handler

    def test_constructor_1(self):
        """Verify the database, username, and password slots start as None.
        """
        for attribute in ("_database", "_username", "_password"):
            self.assertEqual(getattr(self.alchemist, attribute), None)

    def test_constructor_2(self):
        """Verify the engine, metadata, graph, and session slots start as None.
        """
        for attribute in ("_engine", "_metadata", "_graph", "_session"):
            self.assertEqual(getattr(self.alchemist, attribute), None)

    def test_constructor_3(self):
        """Verify connected, has_database, and has_credentials start falsy.
        """
        for flag in ("connected", "has_database", "has_credentials"):
            self.assertFalse(getattr(self.alchemist, flag))

    def test_database_1(self):
        """Verify assigning a database name flags has_database while leaving
        connected_database False.
        """
        handler = self.alchemist
        handler.database = "Test"

        self.assertTrue(handler.has_database)
        self.assertFalse(handler.connected_database)

    def test_username_1(self):
        """Verify a username on its own sets neither has_credentials nor
        connected.
        """
        handler = self.alchemist
        handler.username = "******"

        self.assertFalse(handler.has_credentials)
        self.assertFalse(handler.connected)

    def test_username_2(self):
        """Verify a username assigned after a password sets has_credentials
        but does not connect.
        """
        handler = self.alchemist
        handler.password = "******"
        handler.username = "******"

        self.assertTrue(handler.has_credentials)
        self.assertFalse(handler.connected)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.clear")
    def test_username_3(self, clear_mock):
        """Verify that assigning a username triggers clear().
        """
        handler = self.alchemist
        handler.username = "******"

        clear_mock.assert_called()

    def test_password_1(self):
        """Verify a password on its own sets neither has_credentials nor
        connected.
        """
        handler = self.alchemist
        handler.password = "******"

        self.assertFalse(handler.has_credentials)
        self.assertFalse(handler.connected)

    def test_password_2(self):
        """Verify a password assigned after a username sets has_credentials
        but does not connect.
        """
        handler = self.alchemist
        handler.username = "******"
        handler.password = "******"

        self.assertTrue(handler.has_credentials)
        self.assertFalse(handler.connected)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.clear")
    def test_password_3(self, clear_mock):
        """Verify that assigning a password triggers clear().
        """
        handler = self.alchemist
        handler.password = "******"

        clear_mock.assert_called()

    def test_construct_engine_string_1(self):
        """Verify construct_engine_string generates an expected URI.
        """
        # The credentials must match those embedded in the expected URI;
        # a masked placeholder such as "******" could never produce it.
        uri = self.alchemist.construct_engine_string(username="pdm_user",
                                                     password="pdm_pass")
        self.assertEqual(uri, "mysql+pymysql://pdm_user:pdm_pass@localhost/")

    def test_construct_engine_string_2(self):
        """Verify construct_engine_string accepts use of different drivers.
        """
        # The credentials must match those embedded in the expected URI;
        # a masked placeholder such as "******" could never produce it.
        uri = self.alchemist.construct_engine_string(driver="mysqlconnector",
                                                     username="pdm_user",
                                                     password="pdm_pass")

        self.assertEqual(uri,
                         "mysql+mysqlconnector://pdm_user:pdm_pass@localhost/")

    def test_engine_1(self):
        """Verify that assigning None to engine resets connected.
        """
        handler = self.alchemist
        handler.connected = True

        handler.engine = None
        self.assertFalse(handler.connected)

    def test_engine_2(self):
        """Verify the engine setter rejects a plain string with TypeError.
        """
        bad_engine = "Test"
        with self.assertRaises(TypeError):
            self.alchemist.engine = bad_engine

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_engine_3(self, build_engine_mock):
        """Verify the engine getter calls build_engine() only when no engine
        is stored.
        """
        build_engine_mock.return_value = Mock()
        handler = self.alchemist

        # A stored engine is handed back as-is, without rebuilding.
        handler._engine = "Test"
        self.assertEqual(handler.engine, "Test")
        build_engine_mock.assert_not_called()

        # With nothing stored, reading the property triggers a build.
        handler._engine = None
        handler.engine
        build_engine_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler"
           ".extract_engine_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.get_mysql_dbs")
    def test_engine_4(self, get_mysql_dbs_mock,
                      extract_engine_credentials_mock):
        """Verify the engine setter calls get_mysql_dbs() and extracts
        credentials from the newly assigned engine.
        """
        engine_stub = Mock(spec=Engine)
        self.alchemist.engine = engine_stub

        get_mysql_dbs_mock.assert_called()
        extract_engine_credentials_mock.assert_called_with(engine_stub)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_metadata_1(self, build_metadata_mock):
        """Verify the metadata getter calls build_metadata() only when no
        metadata is stored.
        """
        handler = self.alchemist

        handler._metadata = "Test"
        handler.metadata
        build_metadata_mock.assert_not_called()

        handler._metadata = None
        handler.metadata
        build_metadata_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_graph")
    def test_graph_1(self, build_graph_mock):
        """Verify the graph getter calls build_graph() only when no graph is
        stored.
        """
        handler = self.alchemist

        handler._graph = "Test"
        handler.graph
        build_graph_mock.assert_not_called()

        handler._graph = None
        handler.graph
        build_graph_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_session")
    def test_session_1(self, build_session_mock):
        """Verify the session getter calls build_session() only when no
        session is stored.
        """
        handler = self.alchemist

        handler._session = "Test"
        handler.session
        build_session_mock.assert_not_called()

        handler._session = None
        handler.session
        build_session_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_mapper")
    def test_mapper_1(self, build_mapper_mock):
        """Verify the mapper getter calls build_mapper() only when no mapper
        is stored.
        """
        handler = self.alchemist

        handler._mapper = "Test"
        handler.mapper
        build_mapper_mock.assert_not_called()

        handler._mapper = None
        handler.mapper
        build_mapper_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.input")
    def test_ask_database_1(self, Input):
        """Verify ask_database() prompts through input().
        """
        handler = self.alchemist
        handler.ask_database()

        Input.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.input")
    def test_ask_database_2(self, Input):
        """Verify ask_database() sets has_database and clears connected.
        """
        handler = self.alchemist
        handler.has_database = False
        handler.connected = True

        handler.ask_database()

        self.assertTrue(handler.has_database)
        self.assertFalse(handler.connected)

    @patch("pdm_utils.classes.alchemyhandler.getpass")
    def test_ask_credentials_1(self, GetPass):
        """Verify ask_credentials() prompts through getpass().
        """
        handler = self.alchemist
        handler.ask_credentials()

        GetPass.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.getpass")
    def test_ask_credentials_2(self, GetPass):
        """Verify ask_credentials() sets has_credentials and clears connected.
        """
        handler = self.alchemist
        handler.has_credentials = False
        handler.connected = True

        handler.ask_credentials()

        self.assertTrue(handler.has_credentials)
        self.assertFalse(handler.connected)

    def test_validate_database_1(self):
        """Verify validate_database() issues exactly one query and fetches
        its rows when the database name is listed.
        """
        result_proxy = Mock()
        result_proxy.fetchall.return_value = [("pdm_test_db",),
                                              ("Actinobacteriophage",)]
        engine_stub = Mock()
        engine_stub.execute.return_value = result_proxy

        handler = self.alchemist
        handler.connected = True
        handler.database = "pdm_test_db"
        handler._engine = engine_stub

        handler.validate_database()

        engine_stub.execute.assert_called_once()
        result_proxy.fetchall.assert_called()

    def test_validate_database_2(self):
        """Verify validate_database() raises AttributeError when no database
        has been set.
        """
        handler = self.alchemist
        handler.connected = True

        with self.assertRaises(AttributeError):
            handler.validate_database()

    def test_validate_database_3(self):
        """Verify validate_database() raises MySQLDatabaseError when the
        database name is absent from the fetched rows.
        """
        result_proxy = Mock()
        result_proxy.fetchall.return_value = []
        engine_stub = Mock()
        engine_stub.execute.return_value = result_proxy

        handler = self.alchemist
        handler.connected = True
        handler.database = "test db"
        handler._engine = engine_stub

        with self.assertRaises(MySQLDatabaseError):
            handler.validate_database()

        engine_stub.execute.assert_called_once()
        result_proxy.fetchall.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_1(self, create_engine_mock, ask_credentials_mock):
        """Verify build_engine() does nothing when the handler is already
        connected.
        """
        handler = self.alchemist
        handler.engine = None
        handler.connected = True

        handler.build_engine()

        for collaborator in (create_engine_mock, ask_credentials_mock):
            collaborator.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_2(self, create_engine_mock, ask_credentials_mock):
        """Verify build_engine() raises AttributeError when has_credentials
        is False.
        """
        handler = self.alchemist
        handler.username = "******"
        handler.password = "******"
        # Override the flag the setters just raised.
        handler.has_credentials = False

        with self.assertRaises(AttributeError):
            handler.build_engine()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.validate_database")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_3(self, create_engine_mock, validate_database_mock):
        """Verify build_engine() calls create_engine() with db engine string.

        Both the server-level login string and the database-specific login
        string are expected. The credentials assigned here must therefore be
        the ones embedded in those expected strings.
        """
        self.alchemist.username = "user"
        self.alchemist.password = "pass"
        self.alchemist.database = "database"

        self.alchemist.build_engine()

        login_string = "mysql+pymysql://user:pass@localhost/"
        db_login_string = "mysql+pymysql://user:pass@localhost/database"

        create_engine_mock.assert_any_call(login_string, echo=False)
        create_engine_mock.assert_any_call(db_login_string, echo=False)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_4(self, create_engine_mock, ask_credentials_mock):
        """Verify build_engine() marks the handler connected and clears the
        cached metadata and graph objects.
        """
        self.alchemist.has_credentials = True
        self.alchemist.connected = False
        self.alchemist._metadata = "Test"
        self.alchemist._graph = "Test"

        self.alchemist.build_engine()

        # Check the flag rather than overwrite it; assigning here would make
        # the test unable to detect a build_engine() that never connects.
        self.assertTrue(self.alchemist.connected)
        self.assertIsNone(self.alchemist._metadata)
        self.assertIsNone(self.alchemist._graph)

    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_5(self, create_engine_mock):
        """Verify AlchemyHandler echo property controls create_engine()
        parameters.

        The engine is built once, with echo=False expected. It is then rebuilt
        after setting echo=True.
        """
        self.alchemist.username = "user"
        self.alchemist.password = "pass"
        self.alchemist.build_engine()

        # Must embed the credentials assigned above.
        login_string = "mysql+pymysql://user:pass@localhost/"

        create_engine_mock.assert_any_call(login_string, echo=False)

        self.alchemist.echo = True
        # Clear the flag so build_engine() does not return early.
        self.alchemist.connected = False
        self.alchemist.build_engine()

        create_engine_mock.assert_any_call(login_string, echo=True)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_1(self, build_engine_mock, ask_database_mock,
                       ask_credentials_mock):
        """Verify connect() returns if build_engine() does not complain,
        without asking for a database or credentials.
        """
        self.alchemist.has_credentials = True
        self.alchemist.connected = True
        self.alchemist.connect()
        build_engine_mock.assert_called()
        ask_database_mock.assert_not_called()
        ask_credentials_mock.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_2(self, build_engine_mock, ask_database_mock,
                       ask_credentials_mock):
        """Verify connect() AlchemyHandler properties control function calls.

        When already connected to a database with credentials, connect()
        skips ask_database() and ask_credentials(), even with
        ask_database=True.
        """
        self.alchemist.connected = True
        self.alchemist.connected_database = True
        self.alchemist.has_credentials = True
        self.alchemist.connect(ask_database=True)
        build_engine_mock.assert_called()
        ask_database_mock.assert_not_called()
        ask_credentials_mock.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_3(self, build_engine_mock, ask_database_mock,
                       ask_credentials_mock):
        """Verify connect() raises SQLCredentialsError when build_engine()
        raises OperationalError, and that ask_credentials() is called.
        """
        self.alchemist.connected = False
        # Simulate build_engine() failing with a database OperationalError.
        build_engine_mock.side_effect = OperationalError("", "", "")

        with self.assertRaises(SQLCredentialsError):
            self.alchemist.connect()

        build_engine_mock.assert_called()
        ask_database_mock.assert_not_called()
        ask_credentials_mock.assert_called()

    def build_engine_side_effect(self, mock_engine):
        """Install mock_engine as the handler's private engine.

        Meant to be used as a mock side_effect standing in for build_engine().
        """
        setattr(self.alchemist, "_engine", mock_engine)

    @patch("pdm_utils.classes.alchemyhandler.MetaData")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_metadata_1(self, ask_database_mock, build_engine_mock,
                              metadata_mock):
        """Verify build_metadata() calls ask_database() and build_engine()
        when no database is set and the handler is not connected.
        """
        self.alchemist.has_database = False
        self.alchemist.connected = False

        self.alchemist.build_metadata()

        ask_database_mock.assert_called()
        build_engine_mock.assert_called()
        metadata_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.MetaData")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_metadata_2(self, ask_database_mock, build_engine_mock,
                              metadata_mock):
        """Verify build_metadata() relies on AlchemyHandler properties.

        With a database set and the handler connected to it, build_metadata()
        skips ask_database() and build_engine() but still builds MetaData.
        """
        self.alchemist.has_database = True
        self.alchemist.connected = True
        self.alchemist.connected_database = True

        self.alchemist.build_metadata()

        ask_database_mock.assert_not_called()
        build_engine_mock.assert_not_called()
        metadata_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.querying.build_graph")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_graph_1(self, build_metadata_mock, build_graph_mock):
        """Verify build_graph() hands cached metadata to querying.build_graph()
        and stores the returned graph.
        """
        expected_graph = "Graph"
        cached_metadata = "Metadata"
        build_graph_mock.return_value = expected_graph
        self.alchemist._metadata = cached_metadata

        self.alchemist.build_graph()

        # Metadata is already cached, so it must not be rebuilt.
        build_metadata_mock.assert_not_called()
        build_graph_mock.assert_called_with(cached_metadata)
        self.assertEqual(self.alchemist._graph, expected_graph)

    @patch("pdm_utils.classes.alchemyhandler.querying.build_graph")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_graph_2(self, build_metadata_mock, build_graph_mock):
        """Verify build_graph() calls build_metadata() when no metadata is
        cached.
        """
        expected_graph = "Graph"
        build_graph_mock.return_value = expected_graph
        self.alchemist._metadata = None

        self.alchemist.build_graph()

        build_metadata_mock.assert_called()
        # build_metadata() is mocked, so _metadata is still None when passed on.
        build_graph_mock.assert_called_with(None)
        self.assertEqual(self.alchemist._graph, expected_graph)

    @patch("pdm_utils.classes.alchemyhandler.sessionmaker")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_session_1(self, ask_database_mock, build_engine_mock,
                             sessionmaker_mock):
        """Verify build_session() calls ask_database() and build_engine()
        when no database is set and the handler is not connected.
        """
        self.alchemist.has_database = False
        self.alchemist.connected = False

        self.alchemist.build_session()

        ask_database_mock.assert_called()
        build_engine_mock.assert_called()
        sessionmaker_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.sessionmaker")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_session_2(self, ask_database_mock, build_engine_mock,
                             sessionmaker_mock):
        """Verify build_session() relies on AlchemyHandler properties.

        With a database set and the handler connected, build_session() skips
        ask_database() and build_engine() but still calls sessionmaker().
        """
        self.alchemist.has_database = True
        self.alchemist.connected = True

        self.alchemist.build_session()

        ask_database_mock.assert_not_called()
        build_engine_mock.assert_not_called()
        sessionmaker_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.automap_base")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_mapper_1(self, build_metadata_mock, automap_base_mock):
        """Verify build_mapper() feeds cached metadata to automap_base() and
        stores the resulting base as the mapper.
        """
        expected_base = Mock()
        automap_base_mock.return_value = expected_base
        alchemist = self.alchemist
        alchemist._metadata = "Metadata"

        alchemist.build_mapper()

        # Metadata is already cached, so it must not be rebuilt.
        build_metadata_mock.assert_not_called()
        automap_base_mock.assert_called_with(metadata="Metadata")
        self.assertEqual(alchemist._mapper, expected_base)

    @patch("pdm_utils.classes.alchemyhandler.automap_base")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_mapper_2(self, build_metadata_mock, automap_base_mock):
        """Verify build_mapper() calls build_metadata() when no metadata is
        cached.
        """
        expected_base = Mock()
        automap_base_mock.return_value = expected_base
        alchemist = self.alchemist
        alchemist._metadata = None

        alchemist.build_mapper()

        build_metadata_mock.assert_called()
        # build_metadata() is mocked, so metadata stays None when passed on.
        automap_base_mock.assert_called_with(metadata=None)
        self.assertEqual(alchemist._mapper, expected_base)
Exemplo n.º 15
0
class TestPhamerationFunctions(unittest.TestCase):
    def setUp(self):
        """Create a filled test database and an engine connected to it."""
        # Test database pre-populated with data for several phages.
        test_db_utils.create_filled_test_db()

        handler = AlchemyHandler(database=DB, username=USER, password=PWD)
        handler.build_engine()

        self.alchemist = handler
        self.engine = handler.engine
        self.temp_dir = "/tmp/pdm_utils_tests_phamerate"

    def tearDown(self):
        """Dispose of the engine, drop the test database, and delete any
        error.log left in the current working directory."""
        self.engine.dispose()
        test_db_utils.remove_db()

        leftover = Path.cwd() / "error.log"
        if not leftover.exists():
            return
        print("Found leftover blastclust file... removing")
        leftover.unlink()

    def test_1_get_pham_geneids(self):
        """Verify get_pham_geneids() returns a dictionary."""
        old_phams = get_pham_geneids(self.engine)
        # isinstance-style check instead of comparing type objects.
        self.assertIsInstance(old_phams, dict)

    def test_2_get_pham_colors(self):
        """Verify get_pham_colors() returns a dictionary."""
        old_colors = get_pham_colors(self.engine)
        # isinstance-style check instead of comparing type objects.
        self.assertIsInstance(old_colors, dict)

    def test_3_get_pham_geneids_and_colors(self):
        """Verify get_pham_geneids() and get_pham_colors() are keyed by the
        same set of phams."""
        pham_keys = set(get_pham_geneids(self.engine).keys())
        color_keys = set(get_pham_colors(self.engine).keys())

        # Identical key sets require identical sizes; check that first.
        with self.subTest():
            self.assertEqual(len(pham_keys), len(color_keys))

        # Every pham key must also be a color key.
        with self.subTest():
            self.assertEqual(pham_keys & color_keys, pham_keys)

    def test_4_get_unphamerated_genes(self):
        """Verify get_new_geneids() returns an empty set for the test db."""
        unphamerated = get_new_geneids(self.engine)
        with self.subTest("returns a set"):
            self.assertIsInstance(unphamerated, set)
        # pdm_test_db has 0 unphamerated genes
        with self.subTest("no unphamerated genes"):
            self.assertEqual(len(unphamerated), 0)

    def test_5_get_geneids_and_translations(self):
        """Verify get_geneids_and_translations() returns a dictionary with
        one entry per distinct GeneID in the gene table."""
        gs_to_ts = get_geneids_and_translations(self.engine)

        command = "SELECT distinct(GeneID) FROM gene"
        results = mysqldb_basic.query_dict_list(self.engine, command)

        with self.subTest("returns a dict"):
            self.assertIsInstance(gs_to_ts, dict)
        with self.subTest("one entry per distinct GeneID"):
            self.assertEqual(len(gs_to_ts), len(results))

    def test_6_get_translation_groups(self):
        """Verify get_translation_groups() returns a dictionary with one
        entry per distinct translation in the gene table."""
        ts_to_gs = get_translation_groups(self.engine)

        command = "SELECT distinct(CONVERT(Translation USING utf8)) FROM gene"
        results = mysqldb_basic.query_dict_list(self.engine, command)

        with self.subTest("returns a dict"):
            self.assertIsInstance(ts_to_gs, dict)
        with self.subTest("one entry per distinct translation"):
            self.assertEqual(len(ts_to_gs), len(results))

    def test_7_refresh_tempdir_1(self):
        """Verify refresh_tempdir() creates temp_dir when it is missing."""
        missing = not os.path.exists(self.temp_dir)
        if missing:
            refresh_tempdir(self.temp_dir)
        self.assertTrue(os.path.exists(self.temp_dir))

    def test_8_refresh_tempdir_2(self):
        """Verify refresh_tempdir() removes existing contents of temp_dir
        while leaving the directory itself in place."""
        filename = f"{self.temp_dir}/test.txt"
        if not os.path.exists(self.temp_dir):
            refresh_tempdir(self.temp_dir)

        # Context manager closes the handle even if write() raises.
        with open(filename, "w") as fh:
            fh.write("test\n")

        # Our test file should now exist
        with self.subTest("test file created"):
            self.assertTrue(os.path.exists(filename))

        # Refresh temp_dir
        refresh_tempdir(self.temp_dir)

        # temp_dir should now exist, but test file should not
        with self.subTest("temp_dir still exists"):
            self.assertTrue(os.path.exists(self.temp_dir))
        with self.subTest("test file removed"):
            self.assertFalse(os.path.exists(filename))

    def test_9_write_fasta(self):
        """Verify write_fasta() writes two lines per unique translation, with
        each odd-index line holding a translation key."""
        filename = f"{self.temp_dir}/input.fasta"

        refresh_tempdir(self.temp_dir)

        # Get translations to geneid mappings
        ts_to_gs = get_translation_groups(self.engine)

        write_fasta(ts_to_gs, filename)

        # Read fasta, make sure number of lines is 2x number of unique translations
        with open(filename, "r") as fh:
            lines = fh.readlines()

        with self.subTest("line count"):
            self.assertEqual(len(lines), 2 * len(ts_to_gs))

        # All odd-index lines should map to a key in ts_to_gs. Slicing replaces
        # the index/modulo loop; the subTest parameter identifies failures.
        for line in lines[1::2]:
            translation = line.lstrip(">").rstrip()
            with self.subTest(translation=translation):
                # assertTrue avoids dumping the whole dict on failure.
                self.assertTrue(translation in ts_to_gs)

    def test_10_create_blastdb(self):
        """Verify create_blastdb() produces the protein BLAST database files."""
        fasta_path = f"{self.temp_dir}/input.fasta"
        blast_name = "blastdb"
        blast_prefix = f"{self.temp_dir}/blastdb"

        refresh_tempdir(self.temp_dir)
        write_fasta(get_translation_groups(self.engine), fasta_path)
        create_blastdb(fasta_path, blast_name, blast_prefix)

        # One file per extension is expected next to the database prefix.
        expected_files = [f"{blast_prefix}.{suffix}"
                          for suffix in ("phr", "pin", "pog",
                                         "psd", "psi", "psq")]
        for path in expected_files:
            with self.subTest():
                self.assertTrue(os.path.exists(path))

    def test_11_create_mmseqsdb(self):
        """Verify mmseqs_createdb() writes the sequence database file."""
        fasta_path = f"{self.temp_dir}/input.fasta"
        seq_db_path = f"{self.temp_dir}/sequenceDB"

        refresh_tempdir(self.temp_dir)
        write_fasta(get_translation_groups(self.engine), fasta_path)
        mmseqs_createdb(fasta_path, seq_db_path)

        self.assertTrue(os.path.exists(seq_db_path))