Code Example #1
File: get_db.py  Project: stjacqrm/pdm_utils
def install_db(database, db_filepath=None, schema_version=None):
    """Install database. If database already exists, it is first removed."""
    # No need to specify database yet, since it needs to first check if the
    # database exists.

    alchemist1 = AlchemyHandler(database="")
    alchemist1.connect(pipeline=True)
    engine1 = alchemist1.engine
    result = mysqldb_basic.drop_create_db(engine1, database)
    if result != 0:
        print("Unable to create new, empty database.")
    else:
        alchemist2 = AlchemyHandler(database=database,
                                    username=engine1.url.username,
                                    password=engine1.url.password)
        alchemist2.connect(pipeline=True)
        engine2 = alchemist2.engine
        if engine2 is None:
            print(f"No connection to the {database} database due "
                  "to invalid credentials or database.")
        else:
            if db_filepath is not None:
                mysqldb_basic.install_db(engine2, db_filepath)
            else:
                mysqldb.execute_transaction(engine2, db_schema_0.STATEMENTS)
                convert_args = [
                    "pdm_utils.run", "convert", database, "-s",
                    str(schema_version)
                ]
                convert_db.main(convert_args, engine2)
            # Close up all connections in the connection pool.
            engine2.dispose()
    # Close up all connections in the connection pool.
    engine1.dispose()
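
A hypothetical invocation of install_db() above (the database name, file path, and schema version here are illustrative, not taken from the source):

from pathlib import Path

# Install from a local SQL dump file:
install_db("Actinobacteriophage", db_filepath=Path("./Actinobacteriophage.sql"))

# Or build an empty database and convert it to a chosen schema version:
install_db("test_db", schema_version=5)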
Code Example #2
    def setUp(self):
        if not test_db_utils.check_if_exists():
            test_db_utils.create_empty_test_db()

        alchemist = AlchemyHandler(username=USER, password=PWD, database=DB)
        alchemist.connect()
        self.engine = alchemist.engine
Code Example #3
def establish_database_connection(database_name: str):
    if not isinstance(database_name, str):
        raise TypeError(
            "establish_database_connection requires a string database name")
    alchemist = AlchemyHandler(database=database_name)
    alchemist.connect()

    return alchemist
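
A usage sketch for the helper above, assuming AlchemyHandler prompts for any missing credentials when connect() is called:

alchemist = establish_database_connection("test_db")
engine = alchemist.engine  # ready for queries once connected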
Code Example #4
File: get_gb_records.py  Project: stjacqrm/pdm_utils
def main(unparsed_args_list):
    """Run main get_gb_records pipeline."""
    # Parse command line arguments
    args = parse_args(unparsed_args_list)

    # Filters input: phage.Status=draft AND phage.HostGenus=Mycobacterium
    # Args structure: [['phage.Status=draft'], ['phage.HostGenus=Mycobacterium']]
    filters = args.filters
    ncbi_cred_dict = ncbi.get_ncbi_creds(args.ncbi_credentials_file)
    output_folder = basic.set_path(args.output_folder, kind="dir", expect=True)
    working_dir = pathlib.Path(RESULTS_FOLDER)
    working_path = basic.make_new_dir(output_folder, working_dir, attempt=50)
    if working_path is None:
        print(f"Invalid working directory '{working_dir}'")
        sys.exit(1)

    # Verify database connection and schema compatibility.
    print("Connecting to the MySQL database...")
    alchemist = AlchemyHandler(database=args.database)
    alchemist.connect(pipeline=True)
    engine = alchemist.engine
    mysqldb.check_schema_compatibility(engine, "the get_gb_records pipeline")

    # Get SQLAlchemy metadata Table object
    # table_obj.primary_key.columns is a
    # SQLAlchemy ColumnCollection iterable object
    # Set primary key = 'phage.PhageID'
    alchemist.build_metadata()
    table = querying.get_table(alchemist.metadata, TARGET_TABLE)
    for column in table.primary_key.columns:
        primary_key = column

    # Create filter object and then add command line filter strings
    db_filter = Filter(alchemist=alchemist, key=primary_key)
    db_filter.values = []

    # Attempt to add filters and exit if needed.
    add_filters(db_filter, filters)

    # Performs the query
    db_filter.update()

    # db_filter.values now contains list of PhageIDs that pass the filters.
    # Get the accessions associated with these PhageIDs.
    keep_set = set(db_filter.values)

    # Create data sets
    print("Retrieving accessions from the database...")
    query = construct_accession_query(keep_set)
    list_of_dicts = mysqldb_basic.query_dict_list(engine, query)
    id_acc_dict = get_id_acc_dict(list_of_dicts)
    acc_id_dict = get_acc_id_dict(id_acc_dict)
    engine.dispose()
    if len(acc_id_dict.keys()) > 0:
        get_data(working_path, acc_id_dict, ncbi_cred_dict)
    else:
        print("There are no records to retrieve.")
Code Example #5
def main(unparsed_args):
    """Runs the complete update pipeline."""
    args = parse_args(unparsed_args[2:])

    # Verify database connection and schema compatibility.
    print("Connecting to the MySQL database...")

    # Create config object with data obtained from file and/or defaults.
    config = configfile.build_complete_config(args.config_file)
    mysql_creds = config["mysql"]
    alchemist = AlchemyHandler(database=args.database,
                               username=mysql_creds["user"],
                               password=mysql_creds["password"])
    alchemist.connect(pipeline=True)
    engine = alchemist.engine
    mysqldb.check_schema_compatibility(engine, "the update pipeline")

    if args.version:
        mysqldb.change_version(engine)
        print("Database version updated.")

    if args.ticket_table is not None:
        update_table_path = basic.set_path(args.ticket_table,
                                           kind="file",
                                           expect=True)

        # Iterate through the tickets and process them sequentially.
        list_of_update_tickets = []
        with update_table_path.open(mode='r') as f:
            file_reader = csv.DictReader(f)
            for ticket in file_reader:
                list_of_update_tickets.append(ticket)

        # Variables to be used for end summary
        processed = 0
        succeeded = 0
        failed = 0

        for ticket in list_of_update_tickets:
            status = update_field(alchemist, ticket)

            if status == 1:
                processed += 1
                succeeded += 1
            else:
                processed += 1
                failed += 1

        engine.dispose()
        print("\nDone iterating through tickets.")
        if succeeded > 0:
            print(f"{succeeded} / {processed} tickets successfully handled.")
        if failed > 0:
            print(f"{failed} / {processed} tickets failed to be handled.")
Code Example #6
    def setUp(self):
        alchemist = AlchemyHandler()
        alchemist.username = "******"
        alchemist.password = "******"
        alchemist.database = "test_db"
        alchemist.connect()
        alchemist.build_graph()
        self.alchemist = alchemist

        self.db_filter = Filter(alchemist=self.alchemist)

        phageid = self.alchemist.get_column("phage.PhageID")
        self.phageid = phageid
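
This setUp assigns credentials through attributes after construction; examples #2 and #5 show the equivalent constructor form. A minimal sketch of the latter for comparison (USER, PWD, and the database name are placeholders):

alchemist = AlchemyHandler(username=USER, password=PWD, database="test_db")
alchemist.connect()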
Code Example #7
def main(unparsed_args_list):
    """Uses parsed args to run the entirety of the review pipeline.

    :param unparsed_args_list: Input a list of command line args.
    :type unparsed_args_list: list[str]
    """
    args = parse_review(unparsed_args_list)

    alchemist = AlchemyHandler(database=args.database)
    alchemist.connect(ask_database=True, pipeline=True)

    values = export_db.parse_value_input(args.input)
    
    execute_review(alchemist, args.folder_path, args.folder_name,
                   review=args.review, values=values,
                   filters=args.filters, groups=args.groups, sort=args.sort,
                   g_reports=args.gene_reports, s_report=args.summary_report,
                   verbose=args.verbose)
Code Example #8
def main(unparsed_args_list):
    """Uses parsed args to run the entirety of the file export pipeline.

    :param unparsed_args_list: Input a list of command line args.
    :type unparsed_args_list: list[str]
    """
    # Returns after printing appropriate error message from parsing/connecting.
    args = parse_export(unparsed_args_list)

    alchemist = AlchemyHandler(database=args.database)
    alchemist.connect(ask_database=True, pipeline=True)
    alchemist.build_graph()

    # Exporting as a SQL file is not constrained by schema version.
    if args.pipeline != "sql":
        mysqldb.check_schema_compatibility(alchemist.engine, "export")

    values = []
    if args.pipeline in FILTERABLE_PIPELINES:
        values = parse_value_input(args.input)

    if args.pipeline not in PIPELINES:
        print("ABORTED EXPORT: Unknown pipeline option discrepancy.\n"
              "Pipeline parsed from command line args is not supported.")
        sys.exit(1)

    if args.pipeline != "I":
        execute_export(alchemist,
                       args.folder_path,
                       args.folder_name,
                       args.pipeline,
                       table=args.table,
                       values=values,
                       filters=args.filters,
                       groups=args.groups,
                       sort=args.sort,
                       include_columns=args.include_columns,
                       exclude_columns=args.exclude_columns,
                       sequence_columns=args.sequence_columns,
                       raw_bytes=args.raw_bytes,
                       concatenate=args.concatenate,
                       verbose=args.verbose)
Code Example #9
    def setUp(self):
        alchemist = AlchemyHandler()
        alchemist.username = user
        alchemist.password = pwd
        alchemist.database = db
        alchemist.connect()
        self.alchemist = alchemist

        self.db_filter = Filter(alchemist=self.alchemist)

        self.phage = self.alchemist.metadata.tables["phage"]
        self.gene = self.alchemist.metadata.tables["gene"]
        self.trna = self.alchemist.metadata.tables["trna"]

        self.PhageID = self.phage.c.PhageID
        self.Cluster = self.phage.c.Cluster
        self.Subcluster = self.phage.c.Subcluster
        
        self.Notes = self.gene.c.Notes
Code Example #10
def main(unparsed_args_list):
    """Uses parsed args to run the entirety of the resubmit pipeline.

    :param unparsed_args_list: Input a list of command line args.
    :type unparsed_args_list: list[str]
    """
    args = parse_resubmit(unparsed_args_list)

    alchemist = AlchemyHandler(database=args.database)
    alchemist.connect(ask_database=True, pipeline=True)

    revisions_data_dicts = basic.retrieve_data_dict(args.revisions_file)

    execute_resubmit(alchemist,
                     revisions_data_dicts,
                     args.folder_path,
                     args.folder_name,
                     filters=args.filters,
                     groups=args.groups,
                     verbose=args.verbose)
Code Example #11
    def connect(self, alchemist=None):
        """Connect Filter object to a database with an AlchemyHandler.

        :param alchemist: An AlchemyHandler object.
        :type alchemist: AlchemyHandler
        """
        if alchemist is not None:
            self.link(alchemist)
            return

        if self._connected:
            return

        alchemist = AlchemyHandler()
        alchemist.connect(ask_database=True)

        self._engine = alchemist.engine
        self._graph = alchemist.graph
        self._session = alchemist.session
        self._mapper = alchemist.mapper

        self._connected = True
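
A sketch of both connection paths through the method above, assuming Filter can be constructed without arguments (as the fallback branch implies):

# Link an existing, already-connected AlchemyHandler (no prompting):
db_filter = Filter()
db_filter.connect(alchemist=alchemist)

# Or let the Filter build its own handler, which prompts for
# credentials and a database:
db_filter = Filter()
db_filter.connect()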
Code Example #12
def build_alchemist(database, ask_database=True, config=None, dialect="mysql"):
    if config is not None:
        username = config["mysql"].get("user")
        password = config["mysql"].get("password")
        if not (username is None or password is None):
            alchemist = AlchemyHandler(username=username, password=password,
                                       dialect=dialect)
            alchemist.connect(login_attempts=0, pipeline=True)

            alchemist.database = database
            alchemist.connect(ask_database=ask_database, pipeline=True)

            return alchemist

    alchemist = AlchemyHandler(database=database)
    alchemist.connect(ask_database=ask_database, pipeline=True)

    return alchemist
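
A usage sketch for build_alchemist(), assuming a config mapping shaped like the one configfile.build_complete_config() returns in other examples on this page (a "mysql" section with "user" and "password" keys; the credential values are placeholders):

config = {"mysql": {"user": "user", "password": "pass"}}
alchemist = build_alchemist("Actinobacteriophage", config=config)
engine = alchemist.engine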
Code Example #13
class TestAlchemyHandler(unittest.TestCase):
    def setUp(self):
        self.alchemist = AlchemyHandler()

    def test_constructor_1(self):
        self.assertEqual(self.alchemist._database, None)
        self.assertEqual(self.alchemist._username, None)
        self.assertEqual(self.alchemist._password, None)

    def test_constructor_2(self):
        self.assertEqual(self.alchemist._engine, None)
        self.assertEqual(self.alchemist.metadata, None)
        self.assertEqual(self.alchemist.graph, None)
        self.assertEqual(self.alchemist.session, None)

    def test_constructor_3(self):
        self.assertFalse(self.alchemist.connected)
        self.assertFalse(self.alchemist.has_database)
        self.assertFalse(self.alchemist.has_credentials)

    def test_database_1(self):
        self.alchemist.database = "Test"
        self.assertTrue(self.alchemist.has_database)
        self.assertFalse(self.alchemist.connected_database)

    def test_username_1(self):
        self.alchemist.username = "******"
        self.assertFalse(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    def test_username_2(self):
        self.alchemist.password = "******"
        self.alchemist.username = "******"
        self.assertTrue(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    def test_password_1(self):
        self.alchemist.password ="******"
        self.assertFalse(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    def test_password_2(self):
        self.alchemist.username = "******"
        self.alchemist.password = "******"
        self.assertTrue(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)
   
    def test_engine_1(self):
        self.alchemist.connected = True
        self.alchemist.engine = None

        self.assertFalse(self.alchemist.connected)
       
    def test_engine_2(self):
        with self.assertRaises(TypeError):
            self.alchemist.engine = "Test"

    @patch("pdm_utils.classes.alchemyhandler.input")
    def test_ask_database_1(self, Input):
        self.alchemist.ask_database()
        Input.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.input")
    def test_ask_database_2(self, Input):
        self.alchemist.has_database = False
        self.alchemist.connected = True

        self.alchemist.ask_database()
 
        self.assertTrue(self.alchemist.has_database)
        self.assertFalse(self.alchemist.connected)

    @patch("pdm_utils.classes.alchemyhandler.getpass")
    def test_ask_credentials_1(self, GetPass):
        self.alchemist.ask_credentials()

        GetPass.assert_called()
 
    @patch("pdm_utils.classes.alchemyhandler.getpass")
    def test_ask_credentials_2(self, GetPass):
        self.alchemist.has_credentials = False
        self.alchemist.connected = True

        self.alchemist.ask_credentials()

        self.assertTrue(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    
    def test_validate_database_1(self):
        MockEngine = Mock()
        MockProxy = Mock()

        MockEngine.execute.return_value = MockProxy 
        MockProxy.fetchall.return_value = [("test_db",), 
                                           ("Actinobacteriophage",)]
 
        self.alchemist.database = "test_db"
        self.alchemist._engine = MockEngine

        self.alchemist.validate_database()

        MockEngine.execute.assert_called_with("SHOW DATABASES")
        MockProxy.fetchall.assert_called()

    def test_validate_database_2(self):
        with self.assertRaises(IndexError):
            self.alchemist.validate_database()

    def test_validate_database_3(self):
        MockEngine = Mock()
        MockProxy = Mock()

        MockEngine.execute.return_value = MockProxy
        MockProxy.fetchall.return_value = []

        self.alchemist.database = "test db"
        self.alchemist._engine = MockEngine

        with self.assertRaises(ValueError):
            self.alchemist.validate_database()

        MockEngine.execute.assert_called_with("SHOW DATABASES")
        MockProxy.fetchall.assert_called()


    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
                                                        "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_1(self, CreateEngine, AskCredentials):
        self.alchemist.engine = None
        self.alchemist.connected = True
        self.alchemist.build_engine()

        CreateEngine.assert_not_called()
        AskCredentials.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
                                                        "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_2(self, CreateEngine, AskCredentials):
        self.alchemist.username = "******"
        self.alchemist.password = "******"
        self.alchemist.has_credentials = False

        self.alchemist.build_engine()

        AskCredentials.assert_called()
        login_string = "mysql+pymysql://user:pass@localhost/"
        CreateEngine.assert_called_with(login_string)
   
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.validate_database")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_3(self, CreateEngine, ValidateDatabase): 
        self.alchemist.username = "******"
        self.alchemist.password = "******"
        self.alchemist.database = "database"

        self.alchemist.build_engine()

        login_string = "mysql+pymysql://user:pass@localhost/"
        db_login_string = "mysql+pymysql://user:pass@localhost/database"

        CreateEngine.assert_any_call(login_string)
        CreateEngine.assert_any_call(db_login_string)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
                                                        "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_4(self, CreateEngine, AskCredentials):
        self.alchemist.has_credentials = True
        self.alchemist.connected = False
        self.alchemist.metadata = "Test"
        self.alchemist.graph = "Test"

        self.alchemist.build_engine()

        self.assertTrue(self.alchemist.connected)
        self.assertEqual(self.alchemist.metadata, None)
        self.assertEqual(self.alchemist.graph, None)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
                                                        "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_1(self, BuildEngine, AskDatabase, AskCredentials):
        self.alchemist.has_credentials = True
        self.alchemist.connect()
        BuildEngine.assert_called()
        AskDatabase.assert_not_called()
        AskCredentials.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
                                                        "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_2(self, BuildEngine, AskDatabase, AskCredentials):
        self.alchemist.connect(ask_database=True)
        BuildEngine.assert_called()
        AskDatabase.assert_called()
        AskCredentials.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
                                                        "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_3(self, BuildEngine, AskDatabase, AskCredentials):
        self.alchemist.connected = False
        BuildEngine.side_effect = OperationalError("", "", "")
        
        with self.assertRaises(ValueError):
            self.alchemist.connect()

        BuildEngine.assert_called()
        AskDatabase.assert_not_called()
        AskCredentials.assert_called()

    def mock_build_engine(self, mock_engine):
        self.alchemist._engine = mock_engine

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_execute_1(self, BuildEngine):
        MockEngine = Mock()
        MockProxy = Mock()

        MockEngine.execute.return_value = MockProxy 
        MockProxy.fetchall.return_value = []

        self.alchemist._engine = MockEngine

        self.alchemist.execute("Executable")

        MockEngine.execute.assert_called_with("Executable")
        MockProxy.fetchall.assert_called()
        BuildEngine.assert_not_called() 
   
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_scalar_1(self, BuildEngine):
        MockEngine = Mock()
        MockProxy = Mock()
        
        MockEngine.execute.return_value = MockProxy
        MockProxy.scalar.return_value = "Scalar"

        self.alchemist._engine = MockEngine
       
        self.alchemist.scalar("Executable")

        MockEngine.execute.assert_called_with("Executable")
        MockProxy.scalar.assert_called()
        BuildEngine.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.MetaData")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_metadata_1(self, AskDatabase, BuildEngine, MetaData):
        self.alchemist.has_database = False
        self.alchemist.connected = False

        self.alchemist.build_metadata()

        AskDatabase.assert_called()
        BuildEngine.assert_called()
        MetaData.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.MetaData")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_metadata_2(self, AskDatabase, BuildEngine, MetaData):
        self.alchemist.has_database = True
        self.alchemist.connected = True
        
        self.alchemist.build_metadata()

        AskDatabase.assert_not_called()
        BuildEngine.assert_not_called()
        MetaData.assert_called() 

    @patch("pdm_utils.classes.alchemyhandler.parsing.translate_table")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_translate_table_1(self, BuildMetadata, TranslateTable):
        self.alchemist.metadata = "Metadata"

        self.alchemist.translate_table("Test")

        TranslateTable.assert_called_with("Metadata", "Test")
        BuildMetadata.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.parsing.translate_table")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_translate_table_2(self, BuildMetadata, TranslateTable):
        self.alchemist.metadata = None

        self.alchemist.translate_table("Test")

        TranslateTable.assert_called_with(None, "Test")
        BuildMetadata.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.parsing.translate_column") 
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_translate_column_1(self, BuildMetadata, TranslateColumn):
        self.alchemist.metadata = "Metadata"

        self.alchemist.translate_column("Test")

        TranslateColumn.assert_called_with("Metadata", "Test")
        BuildMetadata.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.parsing.translate_column") 
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_translate_column_2(self, BuildMetadata, TranslateColumn):
        self.alchemist.metadata = None

        self.alchemist.translate_column("Test")

        TranslateColumn.assert_called_with(None, "Test")
        BuildMetadata.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.querying.get_table")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_table_1(self, BuildMetadata, GetTable):
        self.alchemist.metadata = "Metadata"

        self.alchemist.get_table("Test")

        GetTable.assert_called_with("Metadata", "Test")
        BuildMetadata.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.querying.get_table")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_table_2(self, BuildMetadata, GetTable):
        self.alchemist.metadata = None

        self.alchemist.get_table("Test")

        GetTable.assert_called_with(None, "Test")
        BuildMetadata.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.querying.get_column") 
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_column_1(self, BuildMetadata, GetColumn):
        self.alchemist.metadata = "Metadata"

        self.alchemist.get_column("Test")

        GetColumn.assert_called_with("Metadata", "Test")
        BuildMetadata.assert_not_called()
        
    @patch("pdm_utils.classes.alchemyhandler.querying.get_column") 
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_column_2(self, BuildMetadata, GetColumn):
        self.alchemist.metadata = None

        self.alchemist.get_column("Test")

        GetColumn.assert_called_with(None, "Test")
        BuildMetadata.assert_called()
 
    @patch("pdm_utils.classes.alchemyhandler.querying.build_graph")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_graph_1(self, BuildMetadata, BuildGraph):
        BuildGraph.return_value = "Graph"

        self.alchemist.metadata = "Metadata"

        self.alchemist.build_graph()

        BuildMetadata.assert_not_called()
        BuildGraph.assert_called_with("Metadata")
        
        self.assertEqual(self.alchemist.graph, "Graph")

    @patch("pdm_utils.classes.alchemyhandler.querying.build_graph")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_graph_2(self, BuildMetadata, BuildGraph):
        BuildGraph.return_value = "Graph"

        self.alchemist.metadata = None

        self.alchemist.build_graph()

        BuildMetadata.assert_called()
        BuildGraph.assert_called_with(None)
        
        self.assertEqual(self.alchemist.graph, "Graph")

    @patch("pdm_utils.classes.alchemyhandler.cartography.get_map")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_map_1(self, BuildMetadata, GetMap): 
        self.alchemist.metadata = "Metadata"

        self.alchemist.get_map("Test")

        BuildMetadata.assert_not_called()
        GetMap.assert_called_with("Metadata", "Test")

    @patch("pdm_utils.classes.alchemyhandler.cartography.get_map")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_get_map_2(self, BuildMetadata, GetMap):
        self.alchemist.metadata = None

        self.alchemist.get_map("Test")

        BuildMetadata.assert_called()
        GetMap.assert_called_with(None, "Test") 
Code Example #14
File: test_review.py  Project: stjacqrm/pdm_utils
class TestPhamReview(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        test_db_utils.create_filled_test_db()

        self.test_dir = Path(TEST_DIR)
        if self.test_dir.is_dir():
            shutil.rmtree(TEST_DIR)

        self.test_dir.mkdir()

    @classmethod
    def tearDownClass(self):
        test_db_utils.remove_db()
        shutil.rmtree(TEST_DIR)

    def setUp(self):
        self.review_test_dir = self.test_dir.joinpath("review_test_dir")

        self.alchemist = AlchemyHandler()
        self.alchemist.username = USER
        self.alchemist.password = PWD
        self.alchemist.database = DB
        self.alchemist.connect(ask_database=True, login_attempts=0)

        self.db_filter = Filter(alchemist=self.alchemist)
        self.db_filter.add(review.BASE_CONDITIONALS)
        self.db_filter.key = "gene.PhamID"

    def tearDown(self):
        if self.review_test_dir.is_dir():
            shutil.rmtree(str(self.review_test_dir))

    def test_execute_review_1(self):
        """Verify execute_review() creates new directory as expected.
        """
        review.execute_review(self.alchemist, self.test_dir,
                              self.review_test_dir.name)

        self.assertTrue(self.review_test_dir.is_dir())

    def test_execute_review_2(self):
        """Verify execute_review() filter parameter functions as expected.
        """
        review.execute_review(self.alchemist,
                              self.test_dir,
                              self.review_test_dir.name,
                              filters=("phage.Cluster='A' "
                                       "AND phage.Subcluster='A2'"))

        self.assertTrue(self.review_test_dir.is_dir())

    def test_execute_review_3(self):
        """Verify execute_review() group parameter functions as expected.
        """
        review.execute_review(self.alchemist,
                              self.test_dir,
                              self.review_test_dir.name,
                              groups=["phage.Cluster"])

        self.assertTrue(self.review_test_dir.is_dir())

        clusterA_dir = self.review_test_dir.joinpath("A")
        self.assertTrue(clusterA_dir.is_dir())

    def test_execute_review_4(self):
        """Verify execute_review() sort parameter functions as expected.
        """
        review.execute_review(self.alchemist,
                              self.test_dir,
                              self.review_test_dir.name,
                              sort=["gene.Name"])

        self.assertTrue(self.review_test_dir.is_dir())

    def test_execute_review_5(self):
        """Verify execute_review() review parameter functions as expected.
        """
        review.execute_review(self.alchemist,
                              self.test_dir,
                              self.review_test_dir.name,
                              review=False)

        self.assertTrue(self.review_test_dir.is_dir())

    def test_execute_review_6(self):
        """Verify execute_review() g_reports parameter functions as expected.
        """
        review.execute_review(self.alchemist,
                              self.test_dir,
                              self.review_test_dir.name,
                              g_reports=True)

        self.assertTrue(self.review_test_dir.is_dir())

        gene_report_dir = self.review_test_dir.joinpath("GeneReports")
        self.assertTrue(gene_report_dir.is_dir())

    def test_execute_review_7(self):
        """Verify execute_review() s_report parameter functions as expected.
        """
        review.execute_review(self.alchemist,
                              self.test_dir,
                              self.review_test_dir.name,
                              s_report=True)

        self.assertTrue(self.review_test_dir.is_dir())

        summary_report_file = self.review_test_dir.joinpath(
            "SummaryReport.txt")
        self.assertTrue(summary_report_file.is_file())

    def test_review_phams_1(self):
        """Verify review_phams() correctly identifies disrepencies.
        """
        self.db_filter.values = self.db_filter.build_values(
            where=self.db_filter.build_where_clauses())

        review.review_phams(self.db_filter)

        self.assertFalse(39854 in self.db_filter.values)
        self.assertTrue(40481 in self.db_filter.values)

    def test_get_pf_data_1(self):
        """Verify get_pf_data() retrieves and returns data as expected.
        """
        self.db_filter.values = [40481]

        pf_data = review.get_pf_data(self.alchemist, self.db_filter)

        self.assertTrue(isinstance(pf_data, list))

        for header in review.PF_HEADER:
            with self.subTest(header=header):
                self.assertTrue(header in pf_data[0].keys())
                self.assertFalse(isinstance(pf_data[0][header], list))

    def test_get_g_data_1(self):
        """Verify get_g_data() retreives and retrusn data as expected.
        """
        self.db_filter.values = [40481]

        g_data = review.get_g_data(self.alchemist, self.db_filter, 40481)

        self.assertTrue(isinstance(g_data, list))

        for header in review.PG_HEADER:
            with self.subTest(header=header):
                self.assertTrue(header in g_data[0].keys())
                self.assertFalse(isinstance(g_data[0][header], list))
Code Example #15
File: test_revise.py  Project: cdshaffer/pdm_utils
class TestGenbankRevise(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        test_db_utils.create_filled_test_db()

        self.test_dir = Path(TEST_DIR)
        if self.test_dir.is_dir():
            shutil.rmtree(TEST_DIR)

        self.test_dir.mkdir()
        self.revise_form = self.test_dir.joinpath("revise_form.txt")

        fileio.export_data_dict(TEST_FR_DATA,
                                self.revise_form,
                                REVIEW_HEADER,
                                include_headers=True)

    @classmethod
    def tearDownClass(self):
        test_db_utils.remove_db()
        shutil.rmtree(TEST_DIR)

    def setUp(self):
        self.alchemist = AlchemyHandler()
        self.alchemist.username = USER
        self.alchemist.password = PWD
        self.alchemist.database = DB
        self.alchemist.connect(ask_database=True, login_attempts=0)

        self.revise_test_dir = self.test_dir.joinpath("revise_test_dir")
        self.fr_input_file_path = self.test_dir.joinpath("FunctionReport.csv")
        self.csv_input_file_path = self.revise_test_dir.joinpath("gene.csv")

        fileio.export_data_dict(TEST_FR_DATA,
                                self.fr_input_file_path,
                                REVIEW_HEADER,
                                include_headers=True)

        self.assertTrue(self.fr_input_file_path.is_file())

    def tearDown(self):
        if self.revise_test_dir.is_dir():
            shutil.rmtree(str(self.revise_test_dir))

    def test_execute_local_revise_1(self):
        """Verify execute_local_revise() creates new directory as expected.
        """
        revise.execute_local_revise(self.alchemist, self.fr_input_file_path,
                                    self.test_dir, self.revise_test_dir.name)

        self.assertTrue(self.revise_test_dir.is_dir())

    def test_execute_local_revise_2(self):
        """Verify execute_local_revise() filters parameter functions as expected.
        """
        revise.execute_local_revise(self.alchemist,
                                    self.fr_input_file_path,
                                    self.test_dir,
                                    self.revise_test_dir.name,
                                    filters="phage.Cluster=A")

        self.assertTrue(self.revise_test_dir.is_dir())

    def test_execute_local_revise_3(self):
        """Verify execute_local_revise() group parameter functions as expected.
        """
        revise.execute_local_revise(self.alchemist,
                                    self.fr_input_file_path,
                                    self.test_dir,
                                    self.revise_test_dir.name,
                                    groups=["phage.Cluster"])

        self.assertTrue(self.revise_test_dir.is_dir())

        cluster_A_dir = self.revise_test_dir.joinpath("A")
        no_cluster_dir = self.revise_test_dir.joinpath("None")

        cluster_N_dir = self.revise_test_dir.joinpath("N")

        cluster_directories = [cluster_A_dir, no_cluster_dir]
        for dir_path in cluster_directories:
            with self.subTest(cluster=dir_path.name):
                self.assertTrue(dir_path.is_dir())

        with self.subTest(cluster="N"):
            self.assertFalse(cluster_N_dir.is_dir())

    def test_execute_local_revise_4(self):
        """Verify execute_local_revise() removes directory lacking needed revisions.
        """
        revise.execute_local_revise(self.alchemist,
                                    self.fr_input_file_path,
                                    self.test_dir,
                                    self.revise_test_dir.name,
                                    filters="phage.Cluster=N")

        self.assertFalse(self.revise_test_dir.is_dir())

    def test_execute_local_revise_5(self):
        """Verify execute_local_revise() exports expected data.
        """
        revise.execute_local_revise(self.alchemist,
                                    self.fr_input_file_path,
                                    self.test_dir,
                                    self.revise_test_dir.name,
                                    groups=["gene.PhamID"])

        self.assertTrue(self.revise_test_dir.is_dir())

        pham_40481_dir = self.revise_test_dir.joinpath("40481")
        pham_25050_dir = self.revise_test_dir.joinpath("25050")
        pham_40880_dir = self.revise_test_dir.joinpath("40880")
        pham_39529_dir = self.revise_test_dir.joinpath("39529")

        pham_directories = [
            pham_40481_dir, pham_25050_dir, pham_40880_dir, pham_39529_dir
        ]

        for dir_path in pham_directories:
            with self.subTest(cluster=dir_path.name):
                self.assertTrue(dir_path.is_dir())

        with self.subTest(cluster=40481):
            pham_40481_file = pham_40481_dir.joinpath("revise.csv")
            data_dicts = fileio.retrieve_data_dict(pham_40481_file)

            phages = []
            functions = []
            for data_dict in data_dicts:
                phages.append(data_dict["Phage"])
                functions.append(data_dict["Product"])

            self.assertTrue("D29" in phages)
            self.assertTrue("L5" in phages)
            self.assertFalse("Et2Brutus" in phages)
            self.assertFalse("Trixie" in phages)

            for function in functions:
                self.assertEqual(function, "terminase")

        with self.subTest(cluster=25050):
            pham_25050_file = pham_25050_dir.joinpath("revise.csv")
            data_dicts = fileio.retrieve_data_dict(pham_25050_file)

            phages = []
            functions = []
            for data_dict in data_dicts:
                phages.append(data_dict["Phage"])
                functions.append(data_dict["Product"])

            self.assertTrue("Et2Brutus" in phages)
            self.assertTrue("Sparky" in phages)
            self.assertFalse("MichelleMyBell" in phages)

            for function in functions:
                self.assertEqual(function, "minor tail protein")

        with self.subTest(cluster=40880):
            pham_40880_file = pham_40880_dir.joinpath("revise.csv")
            data_dicts = fileio.retrieve_data_dict(pham_40880_file)

            phages = []
            functions = []
            for data_dict in data_dicts:
                phages.append(data_dict["Phage"])
                functions.append(data_dict["Product"])

            self.assertTrue("D29" in phages)
            self.assertTrue("L5" in phages)
            self.assertTrue("Et2Brutus" in phages)
            self.assertFalse("MichelleMyBell" in phages)

            for function in functions:
                self.assertEqual(function, "holin")

        with self.subTest(cluster=39529):
            pham_39529_file = pham_39529_dir.joinpath("revise.csv")
            data_dicts = fileio.retrieve_data_dict(pham_39529_file)

            phages = []
            functions = []
            for data_dict in data_dicts:
                phages.append(data_dict["Phage"])
                functions.append(data_dict["Product"])

            self.assertTrue("D29" in phages)
            self.assertTrue("L5" in phages)
            self.assertTrue("Et2Brutus" in phages)
            self.assertTrue("Trixie" in phages)
            self.assertFalse("Yvonnetastic" in phages)

            for function in functions:
                self.assertEqual(function, "endonuclease VII")
Code Example #16
File: test_fileio.py  Project: cdshaffer/pdm_utils
class TestFeatureTableParser(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        base_dir = Path(TMPDIR_BASE)
        self.test_dir = base_dir.joinpath(TMPDIR_PREFIX)
        test_db_utils.create_filled_test_db()

        if self.test_dir.is_dir():
            shutil.rmtree(self.test_dir)

        self.test_dir.mkdir()

        self.alchemist = AlchemyHandler()
        self.alchemist.username = USER
        self.alchemist.password = PWD
        self.alchemist.database = DB
        self.alchemist.connect(ask_database=True, login_attempts=0)
        self.acc_id_dict = get_acc_id_dict(self.alchemist)

        accession_list = list(self.acc_id_dict.keys())
        ncbi_handle = Entrez.efetch(db="nucleotide",
                                    rettype="ft",
                                    id=",".join(accession_list),
                                    retmode="text")

        copy_gb_ft_files(ncbi_handle, self.acc_id_dict, self.test_dir)

    @classmethod
    def tearDownClass(self):
        test_db_utils.remove_db()
        shutil.rmtree(self.test_dir)

    def setUp(self):
        self.indv_test_dir = self.test_dir.joinpath("RewrittenFeatureTables")
        self.indv_test_dir.mkdir()

    def tearDown(self):
        shutil.rmtree(self.indv_test_dir)

    def test_read_feature_table_1(self):
        """Verify read_feature_table() can read a number of differing tbl files.
        """
        for tbl_file in self.test_dir.iterdir():
            if tbl_file.is_file():
                file_name = tbl_file.name
                if file_name == "TOTAL.tbl":
                    continue

                with self.subTest(file_name=file_name):
                    with tbl_file.open(mode="r") as filehandle:
                        record = fileio.read_feature_table(filehandle)

                        self.assertFalse(record is None)
                        self.assertTrue(record.id in self.acc_id_dict.keys())

                        record_name = self.acc_id_dict[record.id][0]
                        self.assertEqual(".".join([record_name, "tbl"]),
                                         file_name)

    def test_read_feature_table_2(self):
        """Verify read_feature_table() returns None from an empty file.
        """
        empty_file = self.indv_test_dir.joinpath("empty.tbl")
        with empty_file.open(mode="w") as filehandle:
            filehandle.write("Nonsense\n")

        with empty_file.open(mode="r") as filehandle:
            record = fileio.read_feature_table(filehandle)

        self.assertTrue(record is None)

    def test_parse_feature_table_1(self):
        """Verify parse_feature_table() can read a concatenated feature table.
        """
        total_file_path = self.test_dir.joinpath("TOTAL.tbl")
        with total_file_path.open(mode="r") as filehandle:
            records = fileio.parse_feature_table(filehandle)

            for record in records:
                with self.subTest(accession=record.id):
                    self.assertFalse(record is None)
                    self.assertTrue(record.id in self.acc_id_dict.keys())

                    record_name = self.acc_id_dict[record.id][0]
                    record_path = self.test_dir.joinpath(".".join(
                        [record_name, "tbl"]))

                    self.assertTrue(record_path.is_file())

    def test_parse_write_feature_table_1(self):
        """Verify reading the feature table in and rewriting it exactly mimics
        copied feature table files.
        """
        total_file_path = self.test_dir.joinpath("TOTAL.tbl")

        records = []
        file_names = []
        with total_file_path.open(mode="r") as filehandle:
            parser = fileio.parse_feature_table(filehandle)
            for record in parser:
                record_name = self.acc_id_dict[record.id][0]
                record.name = record_name

                file_names.append(".".join([record_name, "tbl"]))

                records.append(record)

        fileio.write_feature_table(records, self.indv_test_dir)

        file_diffs = filecmp.cmpfiles(self.test_dir, self.indv_test_dir,
                                      file_names)

        for discrepant_file in file_diffs[1]:
            print(f"TEST ERROR: data in {discrepant_file} is incorrect:")
            source_file = self.test_dir.joinpath(discrepant_file)
            rewritten_file = self.indv_test_dir.joinpath(discrepant_file)

            source_handle = source_file.open(mode="r")
            rewritten_handle = rewritten_file.open(mode="r")

            source_data = source_handle.readlines()
            rewritten_data = rewritten_handle.readlines()

            source_handle.close()
            rewritten_handle.close()

            for i in range(len(source_data)):
                source_line = source_data[i]
                print(source_line.rstrip("\n"))
                self.assertTrue(i <= len(rewritten_data) - 1)

                rewritten_line = rewritten_data[i]

                self.assertEqual(source_line, rewritten_line)

        self.assertTrue(len(file_diffs[1]) == 0)
Code Example #17
File: find_domains.py  Project: stjacqrm/pdm_utils
def main(argument_list):
    """
    :param argument_list:
    :return:
    """
    # Setup argument parser
    cdd_parser = setup_argparser()

    # Use argument parser to parse argument_list
    args = cdd_parser.parse_args(argument_list)

    # Store arguments in more easily accessible variables
    database = args.database
    cdd_dir = expand_path(args.cdd)
    cdd_name = learn_cdd_name(cdd_dir)
    threads = args.threads
    evalue = args.evalue
    rpsblast = args.rpsblast
    tmp_dir = args.tmp_dir
    output_folder = args.output_folder
    reset = args.reset
    batch_size = args.batch_size

    # Set up directory.
    output_folder = basic.set_path(output_folder, kind="dir", expect=True)
    results_folder = pathlib.Path(RESULTS_FOLDER)
    results_path = basic.make_new_dir(output_folder, results_folder,
                                      attempt=50)
    if results_path is None:
        print("Unable to create output_folder.")
        sys.exit(1)

    log_file = pathlib.Path(results_path, MAIN_LOG_FILE)

    # Set up root logger.
    logging.basicConfig(filename=log_file, filemode="w", level=logging.DEBUG,
                        format="pdm_utils find_domains: %(levelname)s: %(message)s")
    logger.info(f"pdm_utils version: {VERSION}")
    logger.info(f"CDD run date: {constants.CURRENT_DATE}")
    logger.info(f"Command line arguments: {' '.join(argument_list)}")
    logger.info(f"Results directory: {results_path}")

    # Early exit if either 1) cdd_name == "" or 2) no rpsblast given and we are
    # unable to find one
    if cdd_name == "":
        msg = (f"Unable to learn CDD database name. Make sure the files in "
              f"{cdd_dir} all have the same basename.")
        logger.error(msg)
        print(msg)
        return

    # Get the rpsblast command and path.
    if rpsblast == "":
        command = get_rpsblast_command()
        rpsblast = get_rpsblast_path(command)

    # Verify database connection and schema compatibility.
    alchemist = AlchemyHandler(database=database)
    alchemist.connect(pipeline=True)
    engine = alchemist.engine
    logger.info(f"Connected to database: {database}.")
    mysqldb.check_schema_compatibility(engine, "the find_domains pipeline")
    logger.info(f"Schema version is compatible.")
    logger.info("Command line arguments verified.")

    if reset:
        logger.info("Clearing all domain data currently in the database.")
        clear_domain_data(engine)

    # Get gene data that needs to be processed
    # in dict format where key = column name, value = stored value.
    cdd_genes = mysqldb_basic.query_dict_list(engine, GET_GENES_FOR_CDD)
    msg = f"{len(cdd_genes)} genes to search for conserved domains..."
    logger.info(msg)
    print(msg)

    # Only run the pipeline if there are genes returned that need it
    if len(cdd_genes) > 0:

        log_gene_ids(cdd_genes)
        make_tempdir(tmp_dir)

        # Identify unique translations to process mapped to GeneIDs.
        cds_trans_dict = create_cds_translation_dict(cdd_genes)

        unique_trans = list(cds_trans_dict.keys())
        msg = (f"{len(unique_trans)} unique translations "
               "to search for conserved domains...")
        logger.info(msg)
        print(msg)

        # Process translations in batches. Otherwise, searching could take
        # so long that MySQL connection closes resulting in 1 or more
        # transaction errors.
        batch_indices = basic.create_indices(unique_trans, batch_size)
        total_rolled_back = 0
        for indices in batch_indices:
            start = indices[0]
            stop = indices[1]
            msg = f"Processing translations {start + 1} to {stop}..."
            logger.info(msg)
            print(msg)
            sublist = unique_trans[start:stop]
            batch_rolled_back = search_translations(
                                    rpsblast, cdd_name, tmp_dir, evalue,
                                    threads, engine, sublist, cds_trans_dict)
            total_rolled_back += batch_rolled_back

        search_summary(total_rolled_back)
        engine.dispose()

    return
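
basic.create_indices() is not shown here. A sketch of an equivalent batching helper, assuming it returns (start, stop) index pairs covering the list in batch_size chunks, as the loop above implies:

def create_indices(input_list, batch_size):
    # Hypothetical equivalent of pdm_utils basic.create_indices().
    indices = []
    for start in range(0, len(input_list), batch_size):
        stop = min(start + batch_size, len(input_list))
        indices.append((start, stop))
    return indices

# create_indices(list(range(10)), 4) -> [(0, 4), (4, 8), (8, 10)]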
Code Example #18
class TestAnnotationRetrieval(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        test_db_utils.create_filled_test_db()

    @classmethod
    def tearDownClass(self):
        test_db_utils.remove_db()

    def setUp(self):
        self.alchemist = AlchemyHandler()
        self.alchemist.username = test_db_utils.USER
        self.alchemist.password = test_db_utils.PWD
        self.alchemist.database = test_db_utils.DB
        self.alchemist.connect(ask_database=True, login_attempts=0)

    def test_get_relative_gene_1(self):
        """Verify get_relative_gene() returns GeneID string as expected."""

        rel_geneid = annotation.get_relative_gene(self.alchemist,
                                                  "Trixie_CDS_2", -1)

        self.assertEqual(rel_geneid, "Trixie_CDS_1")

    def test_get_relative_gene_2(self):
        """Verify get_relative_gene() returns None when expected."""

        rel_geneid = annotation.get_relative_gene(self.alchemist,
                                                  "Trixie_CDS_1", -1)

        self.assertEqual(rel_geneid, None)

    def test_get_relative_gene_3(self):
        """Verify get_relative_gene() raises ValueError from bad GeneID."""
        with self.assertRaises(ValueError):
            annotation.get_relative_gene(self.alchemist, "NOT A GENE", 8675309)

    def test_get_adjacent_genes_1(self):
        """Verify get_adjacent_phams() returns get_relative_gene() results."""
        adjacent_genes = annotation.get_adjacent_genes(self.alchemist,
                                                       "Trixie_CDS_2")

        self.assertEqual(adjacent_genes[0], "Trixie_CDS_1")
        self.assertEqual(adjacent_genes[1], "Trixie_CDS_3")

    def test_get_adjacent_phams_1(self):
        """Verify get_adjacent_phams() returns expected data type."""

        adjacent_phams = annotation.get_distinct_adjacent_phams(
            self.alchemist, 42006)

        self.assertTrue(isinstance(adjacent_phams, tuple))
        self.assertTrue(len(adjacent_phams) == 2)

        self.assertTrue(isinstance(adjacent_phams[0], list))
        self.assertTrue(isinstance(adjacent_phams[1], list))

        for left_pham in adjacent_phams[0]:
            with self.subTest(pham=left_pham):
                self.assertTrue(isinstance(left_pham, int))

        for right_pham in adjacent_phams[1]:
            with self.subTest(pham=right_pham):
                self.assertTrue(isinstance(right_pham, int))

    def test_get_count_annotations_in_pham_1(self):
        """Verify get_count_annotations_in_pham() returns expected data type."""

        annotation_counts = annotation.get_count_annotations_in_pham(
            self.alchemist, 42006)

        self.assertTrue(isinstance(annotation_counts, dict))

        for key in annotation_counts.keys():
            with self.subTest(annotation=key):
                self.assertTrue(isinstance(key, str))

                self.assertTrue(isinstance(annotation_counts[key], int))
Code Example #19
File: get_data.py  Project: cdshaffer/pdm_utils
def main(unparsed_args_list):
    """Run main retrieve_updates pipeline."""
    # Parse command line arguments
    args = parse_args(unparsed_args_list)
    force = args.force_download
    args.output_folder = basic.set_path(args.output_folder,
                                        kind="dir",
                                        expect=True)
    working_dir = pathlib.Path(RESULTS_FOLDER)
    working_path = basic.make_new_dir(args.output_folder,
                                      working_dir,
                                      attempt=50)

    if working_path is None:
        print(f"Invalid working directory '{working_dir}'")
        sys.exit(1)

    # Create config object with data obtained from file and/or defaults.
    config = configfile.build_complete_config(args.config_file)
    mysql_creds = config["mysql"]
    ncbi_creds = config["ncbi"]

    # Verify database connection and schema compatibility.
    print("Preparing genome data sets from the MySQL database...")
    alchemist = AlchemyHandler(database=args.database,
                               username=mysql_creds["user"],
                               password=mysql_creds["password"])
    alchemist.connect(pipeline=True)
    engine = alchemist.engine
    mysqldb.check_schema_compatibility(engine, "the get_data pipeline")

    # Get existing data from MySQL to determine what needs to be updated.
    query = ("SELECT PhageID, Name, HostGenus, Status, Cluster, "
             "DateLastModified, Accession, RetrieveRecord, Subcluster, "
             "AnnotationAuthor FROM phage")

    mysqldb_genome_list = mysqldb.parse_genome_data(engine=engine,
                                                    phage_query=query,
                                                    gnm_type="mysqldb")
    engine.dispose()
    mysqldb_genome_dict = {}
    for gnm in mysqldb_genome_list:
        # With default date, the date of all records retrieved will be newer.
        if force:
            gnm.date = constants.EMPTY_DATE
        mysqldb_genome_dict[gnm.id] = gnm

    # Get data from PhagesDB
    if args.updates or args.final or args.draft:
        print("Retrieving data from PhagesDB...")
        phagesdb_phages = phagesdb.get_phagesdb_data(constants.API_SEQUENCED)
        phagesdb_phages_dict = basic.convert_list_to_dict(
            phagesdb_phages, "phage_name")
        phagesdb_genome_dict = phagesdb.parse_genomes_dict(
            phagesdb_phages_dict, gnm_type="phagesdb", seq=False)

        # Exit if all phage data wasn't retrieved.
        if len(phagesdb_genome_dict) == 0:
            sys.exit(1)

        # Returns a list of tuples.
        tup = match_genomes(mysqldb_genome_dict, phagesdb_genome_dict)
        matched_genomes = tup[0]
        unmatched_phagesdb_ids = tup[1]

    if args.updates:
        get_update_data(working_path, matched_genomes)
    if args.final:
        get_final_data(working_path, matched_genomes)
    if args.genbank:
        get_genbank_data(working_path,
                         mysqldb_genome_dict,
                         ncbi_creds,
                         args.genbank_results,
                         force=force)
    if args.draft:
        if force:
            # Add all draft genomes currently in database to the list of
            # draft genomes to be downloaded.
            drafts = get_matched_drafts(matched_genomes)
            unmatched_phagesdb_ids |= drafts
        get_draft_data(working_path, unmatched_phagesdb_ids)
Code Example #20
def main(unparsed_args_list):
    """Run main conversion pipeline."""
    # Parse command line arguments
    args = parse_args(unparsed_args_list)
    config = configfile.build_complete_config(args.config_file)
    mysql_creds = config["mysql"]
    alchemist1 = AlchemyHandler(database=args.database,
                                username=mysql_creds["user"],
                                password=mysql_creds["password"])
    alchemist1.connect(pipeline=True)
    engine1 = alchemist1.engine

    target = args.schema_version
    actual = mysqldb.get_schema_version(engine1)
    steps, direction = get_conversion_direction(actual, target)

    # Iterate through list of versions and implement SQL files.
    if dir == "none":
        if args.verbose == True:
            print("No schema conversion is needed.")
        convert = False
    else:
        convert = True

    if convert == True:
        if (args.new_database_name is not None and
                args.new_database_name != args.database):
            result = mysqldb_basic.drop_create_db(engine1, args.new_database_name)
            if result == 0:
                result = mysqldb_basic.copy_db(engine1, args.new_database_name)
                if result == 0:
                    # Create a new connection to the new database.
                    alchemist2 = AlchemyHandler(database=args.new_database_name,
                                                username=engine1.url.username,
                                                password=engine1.url.password)
                    alchemist2.connect(pipeline=True)
                    engine2 = alchemist2.engine

                else:
                    print("Error: Unable to copy the database for conversion.")
                    convert = False
            else:
                print("Error: Unable to create the new database for conversion.")
                convert = False
        else:
            engine2 = engine1

        if convert:
            stop_step, summary = convert_schema(engine2, actual, direction,
                                                steps, verbose=args.verbose)
            engine2.dispose()
            if stop_step == target:
                if args.verbose:
                    print("\n\nThe database schema conversion was successful.")
            else:
                print("\n\nError: "
                      "The database schema conversion was not successful. "
                      f"Unable to proceed past schema version {stop_step}.")
            if args.verbose:
                print_summary(summary)
    engine1.dispose()
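The conversion decision above reduces to comparing the live schema version with the requested target. A short sketch of that check in isolation; get_conversion_direction() is internal to this pipeline, so the branches below only restate what its steps/direction result encodes.

from pdm_utils.functions import mysqldb

target = 10  # hypothetical target schema version
actual = mysqldb.get_schema_version(engine1)  # engine1 built as above
if actual == target:
    print("No schema conversion is needed.")
elif actual < target:
    print(f"Upgrade needed: schema {actual} -> {target}.")
else:
    print(f"Downgrade needed: schema {actual} -> {target}.")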
Code example #21
File: phamerate.py  Project: cdshaffer/pdm_utils
def main(argument_list):
    # Set up the argument parser
    parser = setup_argparser()

    # Parse arguments
    args = vars(parser.parse_args(argument_list[2:]))

    # Temporary directory gets its own variable because we'll use it a lot
    tmp = args["tmp_dir"]

    # Create config object with data obtained from file and/or defaults.
    config = build_complete_config(args["config_file"])
    mysql_creds = config["mysql"]

    # Note which workflow we're using: the mmseqs workflow exposes more CLI
    # options than blast-mcl, so the parsed argument count distinguishes them.
    if len(args) > 10:
        program = "mmseqs"
    else:
        program = "blast-mcl"

    # Record start time
    start_time = datetime.now()

    # Initialize SQLAlchemy engine with database provided at CLI
    alchemist = AlchemyHandler(database=args["db"],
                               username=mysql_creds["user"],
                               password=mysql_creds["password"])
    alchemist.connect(login_attempts=5, pipeline=True)
    engine = alchemist.engine

    # Refresh temp_dir
    refresh_tempdir(tmp)

    # Get old pham data and un-phamerated genes
    old_phams = get_pham_geneids(engine)
    old_colors = get_pham_colors(engine)
    new_genes = get_new_geneids(engine)

    # Get GeneIDs & translations, and translation groups
    # gene_x: translation_x
    genes_and_translations = get_geneids_and_translations(engine)
    # translation_x: [gene_x, ..., gene_z]
    translation_groups = get_translation_groups(engine)

    # Print initial state
    initial_summary = f"""
Initial database summary:
=============================
 {len(old_phams)} total phams
 {sum([len(x) == 1 for x in old_phams.values()])} orphams
 {sum([len(x) for x in old_phams.values()])} genes in phams
 {len(new_genes)} genes not in phams
 {len(genes_and_translations)} total genes
 {len(translation_groups)} non-redundant genes
=============================
"""
    print(initial_summary)

    # Write input fasta file
    print("Writing non-redundant sequences to input fasta...")
    infile = f"{tmp}/input.fasta"
    write_fasta(translation_groups, infile)

    # Here is where the workflow selection comes into play
    if program == "mmseqs":
        seq_db = f"{tmp}/sequenceDB"            # MMseqs2 sequence database
        clu_db = f"{tmp}/clusterDB"             # MMseqs2 cluster database
        psf_db = f"{tmp}/seqfileDB"             # pre-pham seqfile database
        p_out = f"{tmp}/pre_out.fasta"          # pre-pham output (FASTA)

        print("Creating MMseqs2 sequence database...")
        mmseqs_createdb(infile, seq_db)

        print("Clustering sequence database...")
        mmseqs_cluster(seq_db, clu_db, args)

        print("Storing sequence-based phamilies...")
        mmseqs_createseqfiledb(seq_db, clu_db, psf_db)
        mmseqs_result2flat(seq_db, seq_db, psf_db, p_out)
        pre_phams = parse_mmseqs_output(p_out)      # Parse pre-pham output

        # Proceed with profile clustering, if allowed
        if not args["skip_hmm"]:
            con_lookup = dict()
            for name, geneids in pre_phams.items():
                for geneid in geneids:
                    con_lookup[geneid] = name

            pro_db = f"{tmp}/profileDB"         # MMseqs2 profile database
            con_db = f"{tmp}/consensusDB"       # Consensus sequence database
            aln_db = f"{tmp}/alignDB"           # Alignment database
            res_db = f"{tmp}/resultDB"          # Cluster database
            hsf_db = f"{tmp}/hmmSeqfileDB"      # hmm-pham seqfile database
            h_out = f"{tmp}/hmm_out.fasta"      # hmm-pham output (FASTA)

            print("Creating HMM profiles from sequence-based phamilies...")
            mmseqs_result2profile(seq_db, clu_db, pro_db)

            print("Extracting consensus sequences from HMM profiles...")
            mmseqs_profile2consensus(pro_db, con_db)

            print("Searching for profile-profile hits...")
            mmseqs_search(pro_db, con_db, aln_db, args)

            print("Clustering based on profile-profile alignments...")
            mmseqs_clust(con_db, aln_db, res_db)

            print("Storing profile-based phamilies...")
            mmseqs_createseqfiledb(seq_db, res_db, hsf_db)
            mmseqs_result2flat(pro_db, con_db, hsf_db, h_out)
            hmm_phams = parse_mmseqs_output(h_out)

            print("Merging sequence and profile-based phamilies...")
            new_phams = merge_pre_and_hmm_phams(
                hmm_phams, pre_phams, con_lookup)
        else:
            new_phams = pre_phams
    else:
        blast_db = "blastdb"
        blast_path = f"{tmp}/{blast_db}"

        print("Creating blast protein database...")
        create_blastdb(infile, blast_db, blast_path)

        print("Splitting non-redundant sequences into multiple blastp query "
              "files...")
        chunks = chunk_translations(translation_groups)

        jobs = []
        for key, chunk in chunks.items():
            jobs.append((key, chunk, tmp, blast_path, 
                         args["e_value"], args["query_cov"]))

        print("Running blastp...")
        parallelize(jobs, args["threads"], blastp)

        print("Converting blastp output into adjacency matrix for mcl...")
        results = [x for x in os.listdir(tmp) if x.endswith(".tsv")]
        adjacency = f"{tmp}/blast_adjacency.abc"
        with open(adjacency, "w") as fh:
            for result in results:
                f = open(f"{tmp}/{result}", "r")
                for line in f:
                    fh.write(line)
                f.close()

        print("Running mcl on adjacency matrix...")
        outfile = markov_cluster(adjacency, args["inflate"], tmp)

        print("Storing blast-mcl phamilies...")
        new_phams = parse_mcl_output(outfile)

        # Collect the translations that appear in the mcl output. Proteins
        # with no blastp hits (not even self-hits) are absent from it.
        mcl_genes = set()
        for name, pham in new_phams.items():
            for gene in pham:
                mcl_genes.add(genes_and_translations[gene])

        all_trans = set(translation_groups.keys())

        # Re-insert each missing translation as a single-member "orpham".
        missing = all_trans - mcl_genes
        for translation in missing:
            new_phams[len(new_phams) + 1] = \
                [translation_groups[translation][0]]

    # Reintroduce duplicates
    print("Propagating phamily assignments to duplicate genes...")
    new_phams = reintroduce_duplicates(new_phams, translation_groups,
                                       genes_and_translations)

    # Preserve old pham names and colors
    print("Preserving old phamily names/colors where possible...")
    new_phams, new_colors = preserve_phams(old_phams, new_phams,
                                           old_colors, new_genes)

    # Early exit if we don't have new phams or new colors - avoids
    # overwriting the existing pham data with incomplete new data
    if len(new_phams) == 0 or len(new_colors) == 0:
        print("Failed to parse new pham/color data properly... Terminating "
              "pipeline")
        return

    # Update gene/pham tables with new pham data. Pham colors need to be done
    # first, because gene.PhamID is a foreign key to pham.PhamID.
    print("Updating pham data in database...")
    update_pham_table(new_colors, engine)
    update_gene_table(new_phams, engine)

    # Fix miscolored phams/orphams
    print("Phixing phalsely phlagged orphams...", end=" ")
    fix_white_phams(engine)
    print("Phixing phalsely hued phams...", end=" ")
    fix_colored_orphams(engine)

    # Close all connections in the connection pool.
    engine.dispose()

    # Print final state
    final_summary = f"""
Final database summary:
=============================
 {len(new_phams)} total phams
 {sum([len(x) == 1 for x in new_phams.values()])} orphams
 {sum([len(x) for x in new_phams.values()])} genes in phams
 {len(genes_and_translations) - sum([len(x) for x in new_phams.values()])} genes not in phams
 {len(genes_and_translations)} total genes
 {len(translation_groups)} non-redundant genes
=============================
"""
    print(final_summary)

    # Record stop time
    stop_time = datetime.now()
    elapsed_time = str(stop_time - start_time)

    # Report phameration elapsed time
    print(f"Elapsed time: {elapsed_time}")
Code example #22
class TestAlchemyHandler(unittest.TestCase):
    def setUp(self):
        self.alchemist = AlchemyHandler()

    def test_constructor_1(self):
        """Verify AlchemyHandler credentials are initialized as None.
        """
        self.assertEqual(self.alchemist._database, None)
        self.assertEqual(self.alchemist._username, None)
        self.assertEqual(self.alchemist._password, None)

    def test_constructor_2(self):
        """Verify AlchemyHandler data objects are initialized as None.
        """
        self.assertEqual(self.alchemist._engine, None)
        self.assertEqual(self.alchemist._metadata, None)
        self.assertEqual(self.alchemist._graph, None)
        self.assertEqual(self.alchemist._session, None)

    def test_constructor_3(self):
        """Verify AlchemyHandler data booleans are initialized as False.
        """
        self.assertFalse(self.alchemist.connected)
        self.assertFalse(self.alchemist.has_database)
        self.assertFalse(self.alchemist.has_credentials)

    def test_database_1(self):
        """Verify database property sets has_database.
        """
        self.alchemist.database = "Test"
        self.assertTrue(self.alchemist.has_database)
        self.assertFalse(self.alchemist.connected_database)

    def test_username_1(self):
        """Verify username property conserves has_credentials and connected.
        """
        self.alchemist.username = "******"
        self.assertFalse(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    def test_username_2(self):
        """Verify username property sets has_credentials with valid password.
        """
        self.alchemist.password = "******"
        self.alchemist.username = "******"
        self.assertTrue(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.clear")
    def test_username_3(self, clear_mock):
        """Verify changing usrename property calls clear().
        """
        self.alchemist.username = "******"

        clear_mock.assert_called()

    def test_password_1(self):
        """Verify password property conserves has_credentials and connected.
        """
        self.alchemist.password = "******"
        self.assertFalse(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    def test_password_2(self):
        """Verify password property sets has_credentials with valid username.
        """
        self.alchemist.username = "******"
        self.alchemist.password = "******"
        self.assertTrue(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.clear")
    def test_password_3(self, clear_mock):
        """Verify changing password property calls clear().
        """
        self.alchemist.password = "******"

        clear_mock.assert_called()

    def test_construct_engine_string_1(self):
        """Verify construct_engine_string generates an expected URI.
        """
        URI = self.alchemist.construct_engine_string(username="pdm_user",
                                                     password="pdm_pass")
        self.assertEqual(URI, "mysql+pymysql://pdm_user:pdm_pass@localhost/")

    def test_construct_engine_string_2(self):
        """Verify construct_engine_string accepts use of different drivers.
        """
        URI = self.alchemist.construct_engine_string(driver="mysqlconnector",
                                                     username="pdm_user",
                                                     password="pdm_pass")

        self.assertEqual(URI,
                         "mysql+mysqlconnector://pdm_user:pdm_pass@localhost/")

    def test_engine_1(self):
        """Verify engine property sets connected.
        """
        self.alchemist.connected = True
        self.alchemist.engine = None

        self.assertFalse(self.alchemist.connected)

    def test_engine_2(self):
        """Verify engine property raises TypeError on bad engine input.
        """
        with self.assertRaises(TypeError):
            self.alchemist.engine = "Test"

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_engine_3(self, build_engine_mock):
        """Verify engine property calls build_engine() selectively.
        """
        mock_engine = Mock()
        build_engine_mock.return_value = mock_engine

        self.alchemist._engine = "Test"
        self.assertEqual(self.alchemist.engine, "Test")

        build_engine_mock.assert_not_called()

        self.alchemist._engine = None
        self.alchemist.engine

        build_engine_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler"
           ".extract_engine_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.get_mysql_dbs")
    def test_engine_4(self, get_mysql_dbs_mock,
                      extract_engine_credentials_mock):
        """Verify call structure of engine property setter.
        """
        mock_engine = Mock(spec=Engine)

        self.alchemist.engine = mock_engine

        get_mysql_dbs_mock.assert_called()
        extract_engine_credentials_mock.assert_called_with(mock_engine)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_metadata_1(self, build_metadata_mock):
        """Verify metadata property calls build_metadata() selectively.
        """
        self.alchemist._metadata = "Test"
        self.alchemist.metadata

        build_metadata_mock.assert_not_called()

        self.alchemist._metadata = None
        self.alchemist.metadata

        build_metadata_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_graph")
    def test_graph_1(self, build_graph_mock):
        """Verify graph property calls build_graph() selectively.
        """
        self.alchemist._graph = "Test"
        self.alchemist.graph

        build_graph_mock.assert_not_called()

        self.alchemist._graph = None
        self.alchemist.graph

        build_graph_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_session")
    def test_session_1(self, build_session_mock):
        """Verify session property calls build_session() selectively.
        """
        self.alchemist._session = "Test"
        self.alchemist.session

        build_session_mock.assert_not_called()

        self.alchemist._session = None
        self.alchemist.session

        build_session_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_mapper")
    def test_mapper_1(self, build_mapper_mock):
        """Verify mapper property calls build_mapper() selectively.
        """
        self.alchemist._mapper = "Test"
        self.alchemist.mapper

        build_mapper_mock.assert_not_called()

        self.alchemist._mapper = None
        self.alchemist.mapper

        build_mapper_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.input")
    def test_ask_database_1(self, Input):
        """Verify ask_database() calls input().
        """
        self.alchemist.ask_database()
        Input.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.input")
    def test_ask_database_2(self, Input):
        """Verify ask_database() sets has_database.
        """
        self.alchemist.has_database = False
        self.alchemist.connected = True

        self.alchemist.ask_database()

        self.assertTrue(self.alchemist.has_database)
        self.assertFalse(self.alchemist.connected)

    @patch("pdm_utils.classes.alchemyhandler.getpass")
    def test_ask_credentials_1(self, GetPass):
        """Verify ask_credentials() calls getpass().
        """
        self.alchemist.ask_credentials()

        GetPass.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.getpass")
    def test_ask_credentials_2(self, GetPass):
        """Verify ask_credentials() sets has_credentials.
        """
        self.alchemist.has_credentials = False
        self.alchemist.connected = True

        self.alchemist.ask_credentials()

        self.assertTrue(self.alchemist.has_credentials)
        self.assertFalse(self.alchemist.connected)

    def test_validate_database_1(self):
        """Verify function structure of validate_database().
        """
        mock_engine = Mock()
        mock_proxy = Mock()

        mock_engine.execute.return_value = mock_proxy
        mock_proxy.fetchall.return_value = [("pdm_test_db",),
                                            ("Actinobacteriophage",)]

        self.alchemist.connected = True
        self.alchemist.database = "pdm_test_db"
        self.alchemist._engine = mock_engine

        self.alchemist.validate_database()

        mock_engine.execute.assert_called_once()
        mock_proxy.fetchall.assert_called()

    def test_validate_database_2(self):
        """Verify validate_database() raises IndexError without database.
        """
        self.alchemist.connected = True

        with self.assertRaises(AttributeError):
            self.alchemist.validate_database()

    def test_validate_database_3(self):
        """Verify validate_database() raises ValueError from bad database input.
        """
        mock_engine = Mock()
        mock_proxy = Mock()

        mock_engine.execute.return_value = mock_proxy
        mock_proxy.fetchall.return_value = []

        self.alchemist.connected = True
        self.alchemist.database = "test db"
        self.alchemist._engine = mock_engine

        with self.assertRaises(MySQLDatabaseError):
            self.alchemist.validate_database()

        mock_engine.execute.assert_called_once()
        mock_proxy.fetchall.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_1(self, create_engine_mock, ask_credentials_mock):
        """Verify build_engine() returns if connected already.
        """
        self.alchemist.engine = None
        self.alchemist.connected = True
        self.alchemist.build_engine()

        create_engine_mock.assert_not_called()
        ask_credentials_mock.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_2(self, create_engine_mock, ask_credentials_mock):
        """Verify build_engine() raises attribute error without credentials.
        """
        self.alchemist.username = "******"
        self.alchemist.password = "******"
        self.alchemist.has_credentials = False

        with self.assertRaises(AttributeError):
            self.alchemist.build_engine()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.validate_database")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_3(self, create_engine_mock, validate_database_mock):
        """Verify build_engine() calls create_engine() with db engine string.
        """
        self.alchemist.username = "******"
        self.alchemist.password = "******"
        self.alchemist.database = "database"

        self.alchemist.build_engine()

        login_string = "mysql+pymysql://user:pass@localhost/"
        db_login_string = "mysql+pymysql://user:pass@localhost/database"

        create_engine_mock.assert_any_call(login_string, echo=False)
        create_engine_mock.assert_any_call(db_login_string, echo=False)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_4(self, create_engine_mock, ask_credentials_mock):
        """Verify build_engine() sets has_credentials.
        """
        self.alchemist.has_credentials = True
        self.alchemist.connected = False
        self.alchemist._metadata = "Test"
        self.alchemist._graph = "Test"

        self.alchemist.build_engine()

        self.assertTrue(self.alchemist.connected)
        self.assertEqual(self.alchemist._metadata, None)
        self.assertEqual(self.alchemist._graph, None)

    @patch("pdm_utils.classes.alchemyhandler.sqlalchemy.create_engine")
    def test_build_engine_5(self, create_engine_mock):
        """Verify AlchemyHandler echo property controls create_engine()
        parameters.
        """
        self.alchemist.username = "******"
        self.alchemist.password = "******"
        self.alchemist.build_engine()

        login_string = "mysql+pymysql://user:pass@localhost/"

        create_engine_mock.assert_any_call(login_string, echo=False)

        self.alchemist.echo = True
        self.alchemist.connected = False
        self.alchemist.build_engine()

        create_engine_mock.assert_any_call(login_string, echo=True)

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_1(self, build_engine_mock, ask_database_mock,
                       AskCredentials):
        """Verify connect() returns if build_engine() does not complain.
        """
        self.alchemist.has_credentials = True
        self.alchemist.connected = True
        self.alchemist.connect()
        build_engine_mock.assert_called()
        ask_database_mock.assert_not_called()
        AskCredentials.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_2(self, build_engine_mock, ask_database_mock,
                       AskCredentials):
        """Verify connect() AlchemyHandler properties control function calls.
        """
        self.alchemist.connected = True
        self.alchemist.connected_database = True
        self.alchemist.has_credentials = True
        self.alchemist.connect(ask_database=True)
        build_engine_mock.assert_called()
        ask_database_mock.assert_not_called()
        AskCredentials.assert_not_called()

    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler."
           "ask_credentials")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    def test_connect_3(self, build_engine_mock, ask_database_mock,
                       AskCredentials):
        """Verify connect() depends on build_engine() to raise ValueError.
        """
        self.alchemist.connected = False
        build_engine_mock.side_effect = OperationalError("", "", "")

        with self.assertRaises(SQLCredentialsError):
            self.alchemist.connect()

        build_engine_mock.assert_called()
        ask_database_mock.assert_not_called()
        AskCredentials.assert_called()

    def build_engine_side_effect(self, mock_engine):
        """Helper function for side effect usage.
        """
        self.alchemist._engine = mock_engine

    @patch("pdm_utils.classes.alchemyhandler.MetaData")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_metadata_1(self, ask_database_mock, build_engine_mock,
                              metadata_mock):
        """Verify build_metadata() relies on AlchemyHandler properties.
        """
        self.alchemist.has_database = False
        self.alchemist.connected = False

        self.alchemist.build_metadata()

        ask_database_mock.assert_called()
        build_engine_mock.assert_called()
        metadata_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.MetaData")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_metadata_2(self, ask_database_mock, build_engine_mock,
                              metadata_mock):
        """Verify build_metadata() calls ask_database() and build_engine().
        """
        self.alchemist.has_database = True
        self.alchemist.connected = True
        self.alchemist.connected_database = True

        self.alchemist.build_metadata()

        ask_database_mock.assert_not_called()
        build_engine_mock.assert_not_called()
        metadata_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.querying.build_graph")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_graph_1(self, build_metadata_mock, build_graph_mock):
        """Verify build_graph() calls querying.build_graph().
        """
        build_graph_mock.return_value = "Graph"

        self.alchemist._metadata = "Metadata"

        self.alchemist.build_graph()

        build_metadata_mock.assert_not_called()
        build_graph_mock.assert_called_with("Metadata")

        self.assertEqual(self.alchemist._graph, "Graph")

    @patch("pdm_utils.classes.alchemyhandler.querying.build_graph")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_graph_2(self, build_metadata_mock, build_graph_mock):
        """Verify build_graph() calls build_metadata().
        """
        build_graph_mock.return_value = "Graph"

        self.alchemist._metadata = None

        self.alchemist.build_graph()

        build_metadata_mock.assert_called()
        build_graph_mock.assert_called_with(None)

        self.assertEqual(self.alchemist._graph, "Graph")

    @patch("pdm_utils.classes.alchemyhandler.sessionmaker")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_session_1(self, ask_database_mock, build_engine_mock,
                             sessionmaker_mock):
        """Verify build_session() relies on AlchemyHandler properties.
        """
        self.alchemist.has_database = False
        self.alchemist.connected = False

        self.alchemist.build_session()

        ask_database_mock.assert_called()
        build_engine_mock.assert_called()
        sessionmaker_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.sessionmaker")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_engine")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.ask_database")
    def test_build_session_2(self, ask_database_mock, build_engine_mock,
                             sessionmaker_mock):
        """Verify build_session() calls ask_database() and build_engine().
        """
        self.alchemist.has_database = True
        self.alchemist.connected = True

        self.alchemist.build_session()

        ask_database_mock.assert_not_called()
        build_engine_mock.assert_not_called()
        sessionmaker_mock.assert_called()

    @patch("pdm_utils.classes.alchemyhandler.automap_base")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_mapper_1(self, build_metadata_mock, automap_base_mock):
        """Verify build_mapper() calls automap_base().
        """
        base_mock = Mock()
        automap_base_mock.return_value = base_mock

        self.alchemist._metadata = "Metadata"

        self.alchemist.build_mapper()

        build_metadata_mock.assert_not_called()
        automap_base_mock.assert_called_with(metadata="Metadata")

        self.assertEqual(self.alchemist._mapper, base_mock)

    @patch("pdm_utils.classes.alchemyhandler.automap_base")
    @patch("pdm_utils.classes.alchemyhandler.AlchemyHandler.build_metadata")
    def test_build_mapper_2(self, build_metadata_mock, automap_base_mock):
        """Verify build_mapper() calls build_metadata().
        """
        base_mock = Mock()
        automap_base_mock.return_value = base_mock

        self.alchemist._metadata = None

        self.alchemist.build_mapper()

        build_metadata_mock.assert_called()
        automap_base_mock.assert_called_with(metadata=None)

        self.assertEqual(self.alchemist._mapper, base_mock)
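The metadata, graph, session, and mapper tests above all pin down the same lazy-build property contract: return the cached object if one exists, otherwise build it first. A generic sketch of that shape; this is not the pdm_utils source, just the behavior the mocks assert.

class LazyExample:
    def __init__(self):
        self._metadata = None

    def build_metadata(self):
        # Stand-in for reflecting table metadata over a live connection.
        self._metadata = "reflected-metadata"

    @property
    def metadata(self):
        # Build only when nothing is cached, which is exactly what
        # test_metadata_1 verifies by mocking build_metadata().
        if self._metadata is None:
            self.build_metadata()
        return self._metadata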
Code example #23
File: freeze_db.py  Project: cdshaffer/pdm_utils
def main(unparsed_args_list):
    """Run main freeze database pipeline."""
    args = parse_args(unparsed_args_list)
    ref_database = args.database
    reset = args.reset
    new_database = args.new_database_name
    prefix = args.prefix

    # Filters input: phage.Status=draft AND phage.HostGenus=Mycobacterium
    # Args structure: [['phage.Status=draft'], ['phage.HostGenus=Mycobacterium']]
    filters = args.filters

    # Create config object with data obtained from file and/or defaults.
    config = configfile.build_complete_config(args.config_file)
    mysql_creds = config["mysql"]

    # Verify database connection and schema compatibility.
    print("Connecting to the MySQL database...")
    alchemist1 = AlchemyHandler(database=ref_database,
                                username=mysql_creds["user"],
                                password=mysql_creds["password"])
    alchemist1.connect(pipeline=True)
    engine1 = alchemist1.engine
    mysqldb.check_schema_compatibility(engine1, "the freeze pipeline")

    # Get SQLAlchemy metadata Table object
    # table_obj.primary_key.columns is a
    # SQLAlchemy ColumnCollection iterable object
    # Set primary key = 'phage.PhageID'
    alchemist1.build_metadata()
    table = querying.get_table(alchemist1.metadata, TARGET_TABLE)
    for column in table.primary_key.columns:
        primary_key = column

    # Create filter object and then add command line filter strings
    db_filter = Filter(alchemist=alchemist1, key=primary_key)
    db_filter.values = []

    # Attempt to add filters and exit if needed.
    add_filters(db_filter, filters)

    # Performs the query
    db_filter.update()

    # db_filter.values now contains list of PhageIDs that pass the filters.
    # Get the number of genomes that will be retained and build the
    # MYSQL DELETE statement.
    keep_set = set(db_filter.values)
    delete_stmt = construct_delete_stmt(TARGET_TABLE, primary_key, keep_set)
    count_query = construct_count_query(TARGET_TABLE, primary_key, keep_set)
    phage_count = mysqldb_basic.scalar(alchemist1.engine, count_query)

    # Determine the name of the new database.
    if new_database is None:
        if prefix is None:
            prefix = get_prefix()
        new_database = f"{prefix}_{phage_count}"

    # Create the new database, but prevent overwriting of current database.
    if engine1.url.database != new_database:
        result = mysqldb_basic.drop_create_db(engine1, new_database)
    else:
        print(
            "Error: names of the reference and frozen databases are the same.")
        print("No database will be created.")
        result = 1

    # Copy database.
    if result == 0:
        print(f"Reference database: {ref_database}")
        print(f"New database: {new_database}")
        result = mysqldb_basic.copy_db(engine1, new_database)
        if result == 0:
            print(f"Deleting genomes...")
            alchemist2 = AlchemyHandler(database=new_database,
                                        username=engine1.url.username,
                                        password=engine1.url.password)
            alchemist2.connect(pipeline=True)
            engine2 = alchemist2.engine
            engine2.execute(delete_stmt)
            if reset:
                engine2.execute(RESET_VERSION)

            # Close up all connections in the connection pool.
            engine2.dispose()
        else:
            print("Unable to copy the database.")
    else:
        print(f"Error creating new database: {new_database}.")
    # Close up all connections in the connection pool.
    engine1.dispose()
    print("Freeze database script completed.")
Code example #24
File: test_fileio.py  Project: cdshaffer/pdm_utils
class TestFileIO(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        base_dir = Path(TMPDIR_BASE)
        self.test_dir = base_dir.joinpath(TMPDIR_PREFIX)
        test_db_utils.create_filled_test_db()

        if self.test_dir.is_dir():
            shutil.rmtree(self.test_dir)

        self.test_dir.mkdir()

        self.test_import_table_1 = Path(test_file_dir,
                                        "test_import_table_1.csv")
        self.tkt_dict1 = {"phage_id": "L5", "host_genus": "Mycobacterium"}
        self.tkt_dict2 = {"phage_id": "Trixie", "host_genus": "Mycobacterium"}

        self.test_fasta_file_1 = Path(test_file_dir, "test_fasta_file_1.fasta")
        self.test_fasta_file_2 = Path(test_file_dir, "test_fasta_file_2.fasta")

        self.test_fa_1_gs_to_ts = {}
        with self.test_fasta_file_1.open(mode="r") as filehandle:
            for record in SeqIO.parse(filehandle, "fasta"):
                self.test_fa_1_gs_to_ts[record.id] = str(record.seq)

        self.test_fa_2_gs_to_ts = {}
        with self.test_fasta_file_2.open(mode="r") as filehandle:
            for record in SeqIO.parse(filehandle, "fasta"):
                self.test_fa_2_gs_to_ts[record.id] = str(record.seq)

        self.test_fa_1_ts_to_gs = {}
        for seq_id, trans in self.test_fa_1_gs_to_ts.items():
            seq_ids = self.test_fa_1_ts_to_gs.get(trans, [])
            seq_ids.append(seq_id)
            self.test_fa_1_ts_to_gs[trans] = seq_ids

        self.test_fa_2_ts_to_gs = {}
        for seq_id, trans in self.test_fa_2_gs_to_ts.items():
            seq_ids = self.test_fa_2_ts_to_gs.get(trans, [])
            seq_ids.append(seq_id)
            self.test_fa_2_ts_to_gs[trans] = seq_ids

        self.fasta_dict_1 = {
            "Trixie_CDS_11": ("MASIQGKLIALVLKYGISYLRKHPELLKEI"
                              "SKHIPGKVDDLVLEVLAKLLGV")
        }
        self.fasta_dict_2 = {
            "TRIXIE_CDS_3": ("MSGFDDKIVDQAQAIVPADDYDALPLAGPGR"
                             "WAHVPGGLTLYTNDDTVLFAQGDMSTIESSY"
                             "LFQAMEKLRLAGKTASQAFDILRLEADAISG"
                             "DLSELAEE"),
            "L5_CDS_3": ("MAQMQATHTIEGFLAVEVAPRAFVAENGHVL"
                         "TRLSATKWGGGEGLEILNYEGPGTVEVSDEK"
                         "LAEAQRASEVEAELRREVGKE")
        }

    @classmethod
    def tearDownClass(self):
        test_db_utils.remove_db()
        shutil.rmtree(self.test_dir)

    def setUp(self):
        self.alchemist = AlchemyHandler()
        self.alchemist.username = USER
        self.alchemist.password = PWD
        self.alchemist.database = DB
        self.alchemist.connect(ask_database=True, login_attempts=0)
        self.alchemist.build_graph()

        self.fileio_test_dir = self.test_dir.joinpath("fileio_test_dir")
        self.fileio_test_dir.mkdir()
        self.data_dict_file = self.fileio_test_dir.joinpath("table.csv")
        self.fasta_file = self.fileio_test_dir.joinpath("translations.fasta")

    def tearDown(self):
        shutil.rmtree(self.fileio_test_dir)

    def test_retrieve_data_dict_1(self):
        """Verify a correctly structured file can be opened."""
        list_of_data_dicts = \
            fileio.retrieve_data_dict(self.test_import_table_1)
        self.assertEqual(len(list_of_data_dicts), 2)

    def test_export_data_dict_1(self):
        """Verify data is exported correctly."""

        list_of_data = [self.tkt_dict1, self.tkt_dict2]
        headers = ["type", "phage_id", "host_genus", "cluster"]
        fileio.export_data_dict(list_of_data,
                                self.data_dict_file,
                                headers,
                                include_headers=True)

        exp_success_tkts = []
        with open(self.data_dict_file, 'r') as file:
            file_reader = csv.DictReader(file)
            for row in file_reader:
                exp_success_tkts.append(row)

        with self.subTest():
            self.assertEqual(len(exp_success_tkts), 2)
        with self.subTest():
            self.assertEqual(set(exp_success_tkts[0].keys()), set(headers))

    def test_write_fasta_1(self):
        """Verify write_fasta() creates readable fasta formatted file"""
        fileio.write_fasta(self.fasta_dict_1, self.fasta_file)

        record = SeqIO.read(self.fasta_file, "fasta")
        seq_id = list(self.fasta_dict_1.keys())[0]
        seq = self.fasta_dict_1[seq_id]

        self.assertEqual(record.id, seq_id)
        self.assertEqual(str(record.seq), seq)

    def test_write_fasta_2(self):
        """Verify write_fasta() can properly concatenate fasta files"""
        fileio.write_fasta(self.fasta_dict_2, self.fasta_file)

        records = SeqIO.parse(self.fasta_file, "fasta")

        keys = list(self.fasta_dict_2.keys())

        for record in records:
            self.assertTrue(record.id in keys)

            seq = self.fasta_dict_2[record.id]
            self.assertEqual(str(record.seq), seq)

    def test_reintroduce_fasta_duplicates_1(self):
        """Verify reintroduce_duplicates() copies fastas without duplicates"""
        fileio.write_fasta(self.test_fa_1_gs_to_ts, self.fasta_file)

        fileio.reintroduce_fasta_duplicates(self.test_fa_1_ts_to_gs,
                                            self.fasta_file)

        with self.fasta_file.open(mode="r") as filehandle:
            for record in SeqIO.parse(filehandle, "fasta"):
                with self.subTest(seq_id=record.id):
                    translation = self.test_fa_1_gs_to_ts[record.id]
                    self.assertTrue(translation is not None)
                    self.assertEqual(str(record.seq), translation)

    def test_reintroduce_fasta_duplicates_2(self):
        """Verify reintroduce_duplicates() recognizes duplicate sequences"""
        fileio.write_fasta(self.test_fa_1_gs_to_ts, self.fasta_file)

        fileio.reintroduce_fasta_duplicates(self.test_fa_2_ts_to_gs,
                                            self.fasta_file)

        with self.fasta_file.open(mode="r") as filehandle:
            for record in SeqIO.parse(filehandle, "fasta"):
                with self.subTest(seq_id=record.id):
                    translation = self.test_fa_2_gs_to_ts[record.id]
                    self.assertTrue(translation is not None)
                    self.assertEqual(str(record.seq), translation)

    def test_reintroduce_fasta_duplicates_3(self):
        """Verify reintroduce_duplicates() preserves unrecognized translations
        """
        fileio.write_fasta(self.test_fa_2_gs_to_ts, self.fasta_file)

        fileio.reintroduce_fasta_duplicates(self.test_fa_1_ts_to_gs,
                                            self.fasta_file)

        with self.fasta_file.open(mode="r") as filehandle:
            for record in SeqIO.parse(filehandle, "fasta"):
                with self.subTest(seq_id=record.id):
                    translation = self.test_fa_2_gs_to_ts[record.id]
                    self.assertTrue(translation is not None)
                    self.assertEqual(str(record.seq), translation)
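For orientation, the fileio helpers exercised above operate on plain dicts. A quick hedged usage sketch with made-up IDs and paths; the import path is assumed to follow the pdm_utils.functions layout used elsewhere in this document.

from pathlib import Path

from pdm_utils.functions import fileio

fasta_path = Path("/tmp/translations.fasta")  # hypothetical output path

# geneID -> translation; one record is written per key.
fileio.write_fasta({"Trixie_CDS_11": "MASIQGKLIALVLKYG"}, fasta_path)

# translation -> [geneIDs]; records with duplicates are expanded in place.
ts_to_gs = {"MASIQGKLIALVLKYG": ["Trixie_CDS_11", "D29_CDS_5"]}
fileio.reintroduce_fasta_duplicates(ts_to_gs, fasta_path)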
Code example #25
class TestFileExport(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        test_db_utils.create_filled_test_db()

        self.test_dir = Path(TEST_DIR)
        if self.test_dir.is_dir():
            shutil.rmtree(TEST_DIR)

        self.test_dir.mkdir()

    @classmethod
    def tearDownClass(self):
        test_db_utils.remove_db()
        shutil.rmtree(TEST_DIR)

    def setUp(self):
        self.alchemist = AlchemyHandler()
        self.alchemist.username = USER
        self.alchemist.password = PWD
        self.alchemist.database = DB
        self.alchemist.connect(ask_database=True, login_attempts=0)
        self.alchemist.build_graph()

        self.db_filter = Filter(alchemist=self.alchemist)
        
        self.export_test_dir = self.test_dir.joinpath("export_test_dir")

    def tearDown(self):
        if self.export_test_dir.is_dir():
            shutil.rmtree(str(self.export_test_dir))

    def test_execute_export_1(self):
        """Verify execute_export() creates new directory as expected.
        """
        for pipeline in export_db.PIPELINES:
            with self.subTest(pipeline=pipeline):
                export_db.execute_export(self.alchemist, self.test_dir, 
                                         self.export_test_dir.name, pipeline)
                self.assertTrue(self.export_test_dir.is_dir())
                shutil.rmtree(str(self.export_test_dir))

    def test_execute_export_2(self):
        """Verify execute_export() 'sql' pipeline functions as expected.
        """
        export_db.execute_export(self.alchemist, self.test_dir,
                                 self.export_test_dir.name, "sql")

        self.assertTrue(self.export_test_dir.is_dir())
        
        sql_file_path = self.export_test_dir.joinpath(
                                            f"{self.alchemist.database}.sql")
        self.assertTrue(sql_file_path.is_file())

    def test_execute_export_3(self):
        """Verify execute_export() 'csv' pipeline functions as expected.
        """
        for table in export_db.TABLES:
            with self.subTest(table=table):
                export_db.execute_export(self.alchemist, self.test_dir,
                                         self.export_test_dir.name, "csv",
                                         table=table)
                self.assertTrue(self.export_test_dir.is_dir())

                csv_file_path = self.export_test_dir.joinpath(
                                            f"{table}.csv")

                self.assertTrue(csv_file_path.is_file())

                shutil.rmtree(str(self.export_test_dir))

    def test_execute_export_4(self):
        """Verify execute_export() SeqRecord pipelines function as expected.
        """
        for file_type in export_db.BIOPYTHON_PIPELINES:
            with self.subTest(file_type=file_type):
                export_db.execute_export(self.alchemist, self.test_dir,
                                         self.export_test_dir.name, file_type)
                self.assertTrue(self.export_test_dir.is_dir())

                flat_file_path = self.export_test_dir.joinpath(
                                            f"Trixie.{file_type}")
                self.assertTrue(flat_file_path.is_file())

                shutil.rmtree(str(self.export_test_dir))
   
    def test_execute_export_5(self):
        """Verify execute_export() filter parameter functions as expected.
        """
        filters = "phage.PhageID!=Trixie AND phage.Cluster=A"
        export_db.execute_export(self.alchemist, self.test_dir,
                                 self.export_test_dir.name, "fasta",
                                 filters=filters)

        D29_file_path = self.export_test_dir.joinpath("D29.fasta")
        Trixie_file_path = self.export_test_dir.joinpath("Trixie.fasta")

        self.assertTrue(D29_file_path.is_file())
        self.assertFalse(Trixie_file_path.is_file())
    
    def test_execute_export_6(self):
        """Verify execute_export() group parameter functions as expected.
        """
        groups = ["phage.Cluster", "phage.Subcluster"]
        export_db.execute_export(self.alchemist, self.test_dir,
                                 self.export_test_dir.name, "fasta",
                                 groups=groups)

        A_path = self.export_test_dir.joinpath("A")
        C_path = self.export_test_dir.joinpath("C")

        A2_path = A_path.joinpath("A2")
        C1_path = C_path.joinpath("C1")
        C2_path = C_path.joinpath("C2")

        Trixie_path = A2_path.joinpath("Trixie.fasta")
        D29_path = A2_path.joinpath("D29.fasta")
        Alice_path = C1_path.joinpath("Alice.fasta")
        Myrna_path = C2_path.joinpath("Myrna.fasta")

        self.assertTrue(A_path.is_dir())
        self.assertTrue(C_path.is_dir())

        self.assertTrue(A2_path.is_dir())
        self.assertTrue(C1_path.is_dir())
        self.assertTrue(C2_path.is_dir())

        self.assertTrue(Trixie_path.is_file())
        self.assertTrue(D29_path.is_file())
        self.assertTrue(Alice_path.is_file())
        self.assertTrue(Myrna_path.is_file())
        
    def test_execute_export_7(self):
        """Verify execute_export() sort parameter is functional.
        """
        sort_columns = ["phage.Subcluster"]
        export_db.execute_export(self.alchemist, self.test_dir,
                                 self.export_test_dir.name, "csv",
                                 sort=sort_columns)
    
    def test_execute_export_8(self):
        """Verify execute_export() concatenate parameter functions as expected.
        """
        export_db.execute_export(self.alchemist, self.test_dir,
                                 self.export_test_dir.name, "fasta",
                                 concatenate=True)

        fasta_path = self.export_test_dir.joinpath(
                                        f"{self.export_test_dir.name}.fasta")

        self.assertTrue(fasta_path.is_file())

    def test_execute_export_9(self):
        """Verify execute_export() include_columns functions as expected.
        """
        include_columns = ["phage.Cluster"]
        export_db.execute_export(self.alchemist, self.test_dir,
                                 self.export_test_dir.name, "csv", table="gene",
                                 include_columns=include_columns)

        csv_path = self.export_test_dir.joinpath("gene.csv")

        with open(csv_path) as csv_handle:
            reader = csv.reader(csv_handle)
            headers = next(reader)

        self.assertTrue("Cluster" in headers)
        self.assertEqual("GeneID", headers[0])
        self.assertFalse("Translation" in headers)
    
    def test_execute_export_10(self):
        """Verify execute_export() exclude_columns functions as expected.
        """
        exclude_columns = ["phage.Subcluster"]
        export_db.execute_export(self.alchemist, self.test_dir,
                                 self.export_test_dir.name, "csv",
                                 exclude_columns=exclude_columns)
        
        csv_path = self.export_test_dir.joinpath("phage.csv")

        with open(csv_path) as csv_handle:
            reader = csv.reader(csv_handle)
            headers = next(reader)

        self.assertTrue("Cluster" in headers)
        self.assertEqual("PhageID", headers[0])
        self.assertFalse("Subcluster" in headers)
        self.assertFalse("Sequence" in headers)

    def test_execute_export_11(self):
        """Verify execute_export() sequence_columns functions as expected.
        """
        export_db.execute_export(self.alchemist, self.test_dir,
                                 self.export_test_dir.name, "csv",
                                 sequence_columns=True)

        csv_path = self.export_test_dir.joinpath("phage.csv")

        with open(csv_path) as csv_handle:
            reader = csv.reader(csv_handle)
            headers = next(reader)

        self.assertTrue("Cluster" in headers)
        self.assertEqual("PhageID", headers[0])
        self.assertTrue("Sequence" in headers)

    def test_execute_export_12(self):
        """Verify execute_export() SeqRecord pipeline functions as expected.
        """
        for file_type in export_db.BIOPYTHON_PIPELINES:
            with self.subTest(file_type=file_type):
                export_db.execute_export(self.alchemist, self.test_dir,
                                         self.export_test_dir.name, file_type,
                                         table="gene")

                self.assertTrue(self.export_test_dir.is_dir())

                flat_file_path = self.export_test_dir.joinpath(
                                            f"Trixie_CDS_70.{file_type}")
                self.assertTrue(flat_file_path.is_file())

                shutil.rmtree(str(self.export_test_dir))
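Pulling together the parameters the tests above exercise, a hedged sketch of a single execute_export() call; alchemist is a connected AlchemyHandler, export_db is imported as in the test module, and the output names are placeholders.

from pathlib import Path

out_dir = Path("exports")  # hypothetical export location
out_dir.mkdir(exist_ok=True)

# Grouped fasta export, mirroring the filter/group parameters tested above.
export_db.execute_export(alchemist, out_dir, "my_export", "fasta",
                         filters="phage.Cluster=A",
                         groups=["phage.Cluster", "phage.Subcluster"],
                         concatenate=False)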
Code example #26
File: phamerate.py  Project: stjacqrm/pdm_utils
def main(argument_list):
    # Set up the argument parser
    phamerate_parser = setup_argparser()

    # Parse arguments
    args = phamerate_parser.parse_args(argument_list)
    program = args.program
    temp_dir = args.temp_dir

    # Initialize SQLAlchemy engine with database provided at CLI
    alchemist = AlchemyHandler(database=args.db)
    alchemist.connect(pipeline=True)
    engine = alchemist.engine

    # If we made it past the connect() call above, database access works
    # (the user at least has SELECT privileges on the indicated database).
    # We'll assume that they also have UPDATE, INSERT, and TRUNCATE privileges.

    # Record start time
    start = datetime.datetime.now()

    # Refresh temp_dir
    if os.path.exists(temp_dir):
        try:
            shutil.rmtree(temp_dir)
        except OSError:
            print(f"Failed to delete existing temp directory '{temp_dir}'")
            return
    try:
        os.makedirs(temp_dir)
    except OSError:
        print(f"Failed to create new temp directory '{temp_dir}")
        return

    # Get old pham data and un-phamerated genes
    old_phams = get_pham_geneids(engine)
    old_colors = get_pham_colors(engine)
    unphamerated = get_new_geneids(engine)

    # Get GeneIDs & translations, and translation groups
    genes_and_trans = map_geneids_to_translations(engine)
    translation_groups = map_translations_to_geneids(engine)

    # Write input fasta file
    write_fasta(translation_groups, temp_dir)

    # Create clusterdb and perform clustering
    program_params = get_program_params(program, args)
    create_clusterdb(program, temp_dir)
    phamerate(program_params, program, temp_dir)

    # Parse phameration output
    new_phams = parse_output(program, temp_dir)
    new_phams = reintroduce_duplicates(new_phams, translation_groups,
                                       genes_and_trans)

    # Preserve old pham names and colors
    new_phams, new_colors = preserve_phams(old_phams, new_phams, old_colors,
                                           unphamerated)

    # Early exit if we don't have new phams or new colors - avoids
    # overwriting the existing pham data with potentially incomplete new data
    if len(new_phams) == 0 or len(new_colors) == 0:
        print("Failed to parse new pham/color data... Terminating pipeline")
        return

    # If we got past the early exit, we are probably safe to truncate the
    # pham table, and insert the new pham data
    # Clear old pham data - auto commits at end of transaction - this will also
    # set all PhamID values in gene table to NULL
    commands = ["DELETE FROM pham"]
    mysqldb.execute_transaction(engine, commands)

    # Insert new pham/color data
    reinsert_pham_data(new_phams, new_colors, engine)

    # Fix miscolored phams/orphams
    fix_miscolored_phams(engine)

    # Close all connections in the connection pool.
    engine.dispose()

    # Record stop time
    stop = datetime.datetime.now()

    # Report phameration elapsed time
    print("Elapsed time: {}".format(str(stop - start)))