Exemple #1
0
class BestScoreSolverTest(unittest.TestCase):
    '''
    Runs BestScoreSolver on a small mock data set and checks the
    resulting CDS alignment container with the consistency check
    provided by Read2CDSSolver.
    '''

    # setUp is executed before each test method
    def setUp(self):
        '''
        Builds the containers from the test data and runs the solver.

        Attributes set on the test case:
        mock_db_fpath (str) path to synthetically created CDSs which serve
        to fill up the mock database of records
        input_aln_fpath (str) path to input alignment file
        results_fpath (str) path to file with generated correct results
        (stored for parity with the other solver tests; not read here)
        '''
        self.mock_db_fpath = './test/solver/read2cds/.test_data/cds.fa'
        self.input_aln_fpath = './test/solver/read2cds/.test_data/lisa.in'
        self.results_fpath = './test/solver/read2cds/.test_data/cds_ordering.txt'
        #       Initialize read container
        self.read_cont = ReadContainer()
        self.read_cont.populate_from_aln_file(self.input_aln_fpath)
        #       Initialize and fill record container
        self.db_query = MockDbQuery(self.mock_db_fpath)
        self.record_cont = RecordContainer()
        self.record_cont.set_db_access(self.db_query)
        self.record_cont.populate(self.read_cont.fetch_all_reads_versions())
        self.read_cont.populate_cdss(self.record_cont)
        #       Initialize and fill up cds aln container
        self.cds_aln_cont = CdsAlnContainer()
        self.cds_aln_cont.populate(self.read_cont.fetch_all_reads())

        self.bs_solver = BestScoreSolver()
        self.bs_solver.map_reads_2_cdss(self.cds_aln_cont)

    def testCdsAlignmentContainerConsistency(self):
        '''The container must pass the solver's own consistency check.'''
        # assertTrue (rather than a bare assert) still runs under
        # "python -O" and reports a proper unittest failure.
        self.assertTrue(
            Read2CDSSolver.test_cds_alignment_container_consistency(
                self.cds_aln_cont))
Exemple #2
0
    def setUp(self):
        '''
        Builds every container the tests rely on and runs GreedySolver
        over the CDS alignment container.

        Attributes set on the test case:
        mock_db_fpath (str) path to synthetically created CDSs which serve
        to fill up the mock database of records
        input_aln_fpath (str) path to input alignment file
        results_fpath (str) path to file with generated correct results
        greedy solver should generate
        '''
        # Locations of the fixture files
        self.mock_db_fpath = './test/solver/read2cds/.test_data/cds.fa'
        self.input_aln_fpath = './test/solver/read2cds/.test_data/lisa.in'
        self.results_fpath = './test/solver/read2cds/.test_data/cds_ordering.txt'

        # Read container: filled from the alignment file
        reads = self.read_cont = ReadContainer()
        reads.populate_from_aln_file(self.input_aln_fpath)

        # Record container: backed by the mock database
        mock_db = self.db_query = MockDbQuery(self.mock_db_fpath)
        records = self.record_cont = RecordContainer()
        records.set_db_access(mock_db)
        records.populate(reads.fetch_all_reads_versions())
        reads.populate_cdss(records)

        # CDS alignment container: filled from all reads
        cds_alns = self.cds_aln_cont = CdsAlnContainer()
        cds_alns.populate(reads.fetch_all_reads())

        # Finally let the solver process the alignments
        solver = self.greedy_solver = GreedySolver()
        solver.map_reads_2_cdss(cds_alns)
Exemple #3
0
    def setUp(self):
        '''
        Test fixture: loads reads, records and CDS alignments from the
        test data, then applies GreedySolver to the CDS alignments.

        Attributes set on the test case:
        mock_db_fpath (str) path to synthetically created CDSs which serve
        to fill up the mock database of records
        input_aln_fpath (str) path to input alignment file
        results_fpath (str) path to file with generated correct results
        greedy solver should generate
        '''
        self.mock_db_fpath = './test/solver/read2cds/.test_data/cds.fa'
        self.input_aln_fpath = './test/solver/read2cds/.test_data/lisa.in'
        self.results_fpath = './test/solver/read2cds/.test_data/cds_ordering.txt'

        # 1) reads from the alignment file
        read_cont = self.read_cont = ReadContainer()
        read_cont.populate_from_aln_file(self.input_aln_fpath)

        # 2) records looked up through the mock database
        self.db_query = MockDbQuery(self.mock_db_fpath)
        record_cont = self.record_cont = RecordContainer()
        record_cont.set_db_access(self.db_query)
        versions = read_cont.fetch_all_reads_versions()
        record_cont.populate(versions)
        read_cont.populate_cdss(record_cont)

        # 3) CDS alignments derived from the reads
        self.cds_aln_cont = CdsAlnContainer()
        all_reads = read_cont.fetch_all_reads()
        self.cds_aln_cont.populate(all_reads)

        # 4) run the solver under test
        self.greedy_solver = GreedySolver()
        self.greedy_solver.map_reads_2_cdss(self.cds_aln_cont)
Exemple #4
0
class StatisticsTest(unittest.TestCase):
    '''
    Checks the statistics helpers against a small mock data set, both
    before and after BestScoreSolver has processed the alignments.
    '''

    # setUp is executed before each test method
    def setUp(self):
        '''
        Builds the read, record and CDS alignment containers.

        Attributes set on the test case:
        mock_db_fpath (str) path to synthetically created CDSs which serve
        to fill up the mock database of records
        input_aln_fpath (str) path to input alignment file
        '''
        self.mock_db_fpath = './test/statistics/.test_data/cds.fa'
        self.input_aln_fpath = './test/statistics/.test_data/lisa.in'
        #       Initialize read container
        self.read_cont = ReadContainer()
        self.read_cont.populate_from_aln_file(self.input_aln_fpath)
        #       Initialize and fill record container
        self.db_query = MockDbQuery(self.mock_db_fpath)
        self.record_cont = RecordContainer()
        self.record_cont.set_db_access(self.db_query)
        self.record_cont.populate(self.read_cont.fetch_all_reads_versions())
        self.read_cont.populate_cdss(self.record_cont)
        #       Initialize and fill up cds aln container
        self.cds_aln_cont = CdsAlnContainer()
        self.cds_aln_cont.populate(self.read_cont.fetch_all_reads())

    def testStatistics(self):
        '''Compares the counters with the values expected for the fixture.'''
        # Before the solver runs: 22 alignments, all of them active
        self.assertEqual(num_read_alns(self.read_cont), 22)
        self.assertEqual(num_active_aligned_regions(self.cds_aln_cont), 22)
        self.assertEqual(num_inactive_read_alns(self.read_cont), 0)

        self.bs_solver = BestScoreSolver()
        self.bs_solver.map_reads_2_cdss(self.cds_aln_cont)

        records_stats = count_alns_to_record_and_cds(self.read_cont)
        # Parenthesised single-argument print gives the same output on
        # Python 2 and is valid syntax on Python 3.
        print("Number of records for which we  have stats: %d\n" % len(
            records_stats))
        for rec_stat in records_stats.values():
            rec_stat.print_data()

        # After the solver runs only 16 aligned regions remain active
        self.assertEqual(num_active_aligned_regions(self.cds_aln_cont), 16)

        self.assertEqual(num_cdss(self.cds_aln_cont), 4)
        self.assertEqual(num_cdss_with_no_alns(self.cds_aln_cont), 0)
Exemple #5
0
class StatisticsTest(unittest.TestCase):
    '''
    Verifies the alignment/CDS counters on the statistics test data,
    first on the freshly loaded containers and then after running
    BestScoreSolver.
    '''

    # setUp is executed before each test method
    def setUp(self):
        '''
        Loads the test data into the containers used by the test.

        Attributes set on the test case:
        mock_db_fpath (str) path to synthetically created CDSs which serve
        to fill up the mock database of records
        input_aln_fpath (str) path to input alignment file
        '''
        self.mock_db_fpath = './test/statistics/.test_data/cds.fa'
        self.input_aln_fpath = './test/statistics/.test_data/lisa.in'
        #       Initialize read container
        self.read_cont = ReadContainer()
        self.read_cont.populate_from_aln_file(self.input_aln_fpath)
        #       Initialize and fill record container
        self.db_query = MockDbQuery(self.mock_db_fpath)
        self.record_cont = RecordContainer()
        self.record_cont.set_db_access(self.db_query)
        self.record_cont.populate(self.read_cont.fetch_all_reads_versions())
        self.read_cont.populate_cdss(self.record_cont)
        #       Initialize and fill up cds aln container
        self.cds_aln_cont = CdsAlnContainer()
        self.cds_aln_cont.populate(self.read_cont.fetch_all_reads())

    def testStatistics(self):
        '''Checks the counters before and after the solver has run.'''
        # unittest assertions are used instead of bare asserts so the
        # checks survive "python -O" and report both compared values.
        self.assertEqual(num_read_alns(self.read_cont), 22)
        self.assertEqual(num_active_aligned_regions(self.cds_aln_cont), 22)
        self.assertEqual(num_inactive_read_alns(self.read_cont), 0)

        self.bs_solver = BestScoreSolver()
        self.bs_solver.map_reads_2_cdss(self.cds_aln_cont)

        records_stats = count_alns_to_record_and_cds(self.read_cont)
        # print(...) with one argument behaves identically on Python 2
        # and Python 3; the bare print statement only parses on Python 2.
        print("Number of records for which we  have stats: %d\n" % len(records_stats))
        for rec_stat in records_stats.values():
            rec_stat.print_data()

        self.assertEqual(num_active_aligned_regions(self.cds_aln_cont), 16)

        self.assertEqual(num_cdss(self.cds_aln_cont), 4)
        self.assertEqual(num_cdss_with_no_alns(self.cds_aln_cont), 0)
class RecordContainerTest(unittest.TestCase):
    '''Tests populating and querying of RecordContainer.'''

    def setUp(self):
        '''Creates a new read container and record container.'''
        self.read_container = ReadContainer()
        self.record_container = RecordContainer()

    def tearDown(self):
        '''
        Nothing to clean up. The hook was previously spelled "tearUp",
        a name unittest never calls; "tearDown" is the hook it invokes
        after each test method.
        '''
        pass

    def testFillRecordContainer(self):
        '''Method to test whether record container populating works.
        Uses mock database access to test whether record container
        has correct number of items.'''
        aln_file = './test/solver/read2cds/.test_data/lisa.in'
        cds_fasta = './test/solver/read2cds/.test_data/cds.fa'
        db_access = MockDbQuery(cds_fasta)
        self.record_container.set_db_access(db_access)

        self.read_container.populate_from_aln_file(aln_file)
        self.record_container.populate(
            self.read_container.fetch_all_reads_versions())
        records = self.record_container.fetch_all_records(format=list)
        # Every record known to the mock database must have been fetched
        self.assertEqual(len(db_access.records), len(records))

    def testReturnsNoneForNonexistentRecord(self):
        '''Looking up an unknown version must yield None.'''
        record = self.record_container.fetch_existing_record("XXX")
        self.assertIsNone(record, "No record with version XXX should be found")
Exemple #7
0
class RecordContainerTest(unittest.TestCase):
    '''Covers filling a RecordContainer and looking up a missing record.'''

    def setUp(self):
        '''Instantiates the read and record containers used by the tests.'''
        self.read_container = ReadContainer()
        self.record_container = RecordContainer()

    def tearDown(self):
        '''
        No cleanup is required. Renamed from "tearUp": unittest only
        invokes a method called "tearDown" after each test, so the old
        name was never run.
        '''
        pass

    def testFillRecordContainer(self):
        '''Method to test whether record container populating works.
        Uses mock database access to test whether record container
        has correct number of items.'''
        aln_file = './test/solver/read2cds/.test_data/lisa.in'
        cds_fasta = './test/solver/read2cds/.test_data/cds.fa'
        db_access = MockDbQuery(cds_fasta)
        self.record_container.set_db_access(db_access)

        self.read_container.populate_from_aln_file(aln_file)
        self.record_container.populate(
            self.read_container.fetch_all_reads_versions())
        records = self.record_container.fetch_all_records(format=list)
        # The container should hold as many records as the mock database
        self.assertEqual(len(db_access.records), len(records))

    def testReturnsNoneForNonexistentRecord(self):
        '''A version that was never loaded must not be found.'''
        record = self.record_container.fetch_existing_record("XXX")
        self.assertIsNone(record, "No record with version XXX should be found")
Exemple #8
0
def fill_containers(alignment_file, db_access):
    '''
    Builds and fills the three containers needed for one alignment file.

    @param alignment_file handed to ReadContainer.populate_from_aln_file
    @param db_access handed to RecordContainer.set_db_access
    @return tuple(ReadContainer, RecordContainer, CdsAlnContainer)
    '''
    reads = ReadContainer()
    records = RecordContainer()
    records.set_db_access(db_access)
    cds_alns = CdsAlnContainer()

    # Step 1: everything the alignment file itself provides
    reads.populate_from_aln_file(alignment_file)

    # Step 2: database records for the versions named in the alignments
    versions = reads.fetch_all_reads_versions()
    records.populate(versions)

    # Step 3: resolve which coding sequences the reads map to
    reads.populate_cdss(records)

    # Step 4: collect the alignments per CDS
    all_reads = reads.fetch_all_reads()
    cds_alns.populate(all_reads)

    return reads, records, cds_alns
def main():
    '''
    Script to run binner in one of the most common
    usage scenarios.
    * load alignment data
    * load taxonomy data
    * do basic alignment data filtering (remove host reads etc.)
    * load the referenced records and map alignments to genes

    Progress messages are written to stdout with the parenthesised,
    single-argument form of print, which produces the same output on
    Python 2 and is valid syntax on Python 3.
    '''

    #----------------------------------#
    #------ INPUT ARGUMENTS -----------#
    argparser = PickleParser()
    args = argparser.parse_args()

    #----------------------------------#
    #------- STATIC DATA SOURCE -------#
    # CDS - GI2TAXID -- NAMES -- NODES #
    dataAccess = DataAccess(args)
    #raw_input('Data access created')
    #----------------------------------#

    #-------- TAXONOMY TREE -----------#
    print('1. Loading tax tree...')
    tax_tree = TaxTree()
    # tax_tree.load_taxonomy_data(dataAccess)
    print('done.')

    #----------------------------------#
    #------- ALIGNMENT DATA SOURCE ----#
    print('2. Loading alignment file...')
    read_container = ReadContainer()
    read_container.load_alignment_data(args.input)
    #---SET TAXIDS FOR ALL ALIGNMENTS--#
    read_container.set_taxids(dataAccess)
    print('done')

    #------- FILTER HOST READS -------#
    print('3. Filtering host reads & alignments...')
    new_reads = host_filter.filter_potential_host_reads(
        read_container.fetch_all_reads(format=list),
        tax_tree.tax2relevantTax,
        tax_tree.potential_hosts,
        #delete_host_alignments =
        True,
        #filter_unassigned =
        True,
        #unassigned_taxid=
        -1,
        host_filter.perc_of_host_alignments_larger_than)
    dataAccess.clear_cache()    # deletes gi2taxid cache
    reads_with_no_host_alignments = host_filter.filter_potential_hosts_alignments(
        new_reads,
        tax_tree.tax2relevantTax,
        tax_tree.potential_hosts,
        True,   # delete host alignments
        True,   # filter unassigned
        -1)     # unassigned taxid
    # Host reads = reads present before filtering minus the reads kept.
    # Must be computed before set_new_reads replaces the container content.
    host_read_count = len(read_container.fetch_all_reads(format=list)) - len(reads_with_no_host_alignments)
    read_container.set_new_reads(reads_with_no_host_alignments)
    print('done')

    #----------------------------------#
    #------- LOAD ALL RECORDS   -------#
    print('4. Loading referenced records...')
    record_container = RecordContainer()
    record_container.set_db_access(dataAccess)
    # The same read versions are looked up in both the 'cds' and 'rrna' tables
    record_container.populate(read_container.fetch_all_reads_versions(), table='cds')
    record_container.populate(read_container.fetch_all_reads_versions(), table='rrna')
    print('done')
    #----------------------------------#
    #-- MAP ALIGNMENTS TO GENES   -----#
    print('5. Mapping alignments to genes...')
    read_container.populate_cdss(record_container)
    #----------------------------------#
    #- RECORD ALL ALIGNEMENTS TO GENE -#
    cds_aln_container = CdsAlnContainer()
    cds_aln_container.populate(read_container.fetch_all_reads(format=list))
    print('done')
Exemple #10
0
class GreedySolverTest(unittest.TestCase):
    '''
    Runs GreedySolver on a mock data set and compares the active
    alignments with the expected results stored in a file.
    '''

    # setUp is executed before each test method
    def setUp(self):
        '''
        Builds the containers from the test data and runs the solver.

        Attributes set on the test case:
        mock_db_fpath (str) path to synthetically created CDSs which serve
        to fill up the mock database of records
        input_aln_fpath (str) path to input alignment file
        results_fpath (str) path to file with generated correct results
        greedy solver should generate
        '''
        self.mock_db_fpath = './test/solver/read2cds/.test_data/cds.fa'
        self.input_aln_fpath = './test/solver/read2cds/.test_data/lisa.in'
        self.results_fpath = './test/solver/read2cds/.test_data/cds_ordering.txt'
        #       Initialize read container
        self.read_cont = ReadContainer()
        self.read_cont.populate_from_aln_file(self.input_aln_fpath)
        #       Initialize and fill record container
        self.db_query = MockDbQuery(self.mock_db_fpath)
        self.record_cont = RecordContainer()
        self.record_cont.set_db_access(self.db_query)
        self.record_cont.populate(self.read_cont.fetch_all_reads_versions())
        self.read_cont.populate_cdss(self.record_cont)
        #       Initialize and fill up cds aln container
        self.cds_aln_cont = CdsAlnContainer()
        self.cds_aln_cont.populate(self.read_cont.fetch_all_reads())

        self.greedy_solver = GreedySolver()
        self.greedy_solver.map_reads_2_cdss(self.cds_aln_cont)

    def testAlignmentsCorrectlyInactivated(self):
        '''
        Loads correct results from results file and checks whether
        all the reads for a CDS listed in the file are active and
        whether all the other reads are inactive.
        '''
        cds2read = self._load_active_reads()

        for (cds, cds_aln) in self.cds_aln_cont.cds_repository.items():
            accession = cds.record_id
            mapped_reads = cds2read[accession]
            for cds_aln_subloc in cds_aln.aligned_regions.values():
                # assertIn/assertNotIn report the offending read id and
                # are not stripped under "python -O" like bare asserts.
                if cds_aln_subloc.active:
                    self.assertIn(cds_aln_subloc.read_id, mapped_reads)
                else:
                    self.assertNotIn(cds_aln_subloc.read_id, mapped_reads)

    def testCdsAlignmentContainerConsistency(self):
        '''The container must pass the solver's own consistency check.'''
        self.assertTrue(
            Read2CDSSolver.test_cds_alignment_container_consistency(
                self.cds_aln_cont))

    def _load_active_reads(self):
        '''
        Reads the expected results file, which consists of line pairs:
        a CDS id line followed by a line of ';'-separated read ids.

        @return dict mapping CDS id (str) to list of read ids (str)
        @raise ValueError if the last CDS id has no read id line
        '''
        cds2read_map = {}
        # "with" closes the file even if parsing raises
        with open(self.results_fpath) as results_fhandle:
            lines = iter(results_fhandle.readlines())
            for cds_id in lines:
                read_ids = next(lines, None)
                if read_ids is None:
                    raise ValueError(
                        "Missing read id line for CDS '%s' in %s"
                        % (cds_id.strip(), self.results_fpath))
                cds2read_map[cds_id.strip()] = read_ids.strip().split(';')
        return cds2read_map
Exemple #11
0
from ncbi.db.access import DbQuery
from solver.read2cds.GreedySolver import GreedySolver
from solver.read2cds.BestScoreSolver import BestScoreSolver
from solver.read2cds.Read2CDSSolver import Read2CDSSolver
from data.containers.record import RecordContainer
from data.containers.read import ReadContainer
from data.containers.cdsaln import CdsAlnContainer

from utils.logger import Logger

# Set up the project logger (presumably "log" names the log target --
# confirm against utils.logger.Logger)
Logger("log")

db_query = DbQuery()

# create containers
record_container = RecordContainer()
record_container.set_db_access(db_query)

read_container = ReadContainer()
read_container.populate_from_aln_file("example_data/2reads.in")

# NOTE(review): other call sites in this file pass
# read_container.fetch_all_reads_versions() to RecordContainer.populate and
# read_container.fetch_all_reads() to CdsAlnContainer.populate, while this
# script passes the container itself -- confirm which API version it targets.
record_container.populate(read_container)
read_container.populate_cdss(record_container)

cds_aln_container = CdsAlnContainer()
cds_aln_container.populate(read_container)

# Parenthesised form: same output on Python 2, valid syntax on Python 3
print(cds_aln_container)


r2c_solver = BestScoreSolver()
Exemple #12
0
class GreedySolverTest(unittest.TestCase):
    '''
    Checks GreedySolver against precomputed results: which read
    alignments stay active per CDS, and container consistency.
    '''

    # setUp is executed before each test method
    def setUp(self):
        '''
        Loads the test data into the containers and runs GreedySolver.

        Attributes set on the test case:
        mock_db_fpath (str) path to synthetically created CDSs which serve
        to fill up the mock database of records
        input_aln_fpath (str) path to input alignment file
        results_fpath (str) path to file with generated correct results
        greedy solver should generate
        '''
        self.mock_db_fpath = './test/solver/read2cds/.test_data/cds.fa'
        self.input_aln_fpath = './test/solver/read2cds/.test_data/lisa.in'
        self.results_fpath = './test/solver/read2cds/.test_data/cds_ordering.txt'
        #       Initialize read container
        self.read_cont = ReadContainer()
        self.read_cont.populate_from_aln_file(self.input_aln_fpath)
        #       Initialize and fill record container
        self.db_query = MockDbQuery(self.mock_db_fpath)
        self.record_cont = RecordContainer()
        self.record_cont.set_db_access(self.db_query)
        self.record_cont.populate(self.read_cont.fetch_all_reads_versions())
        self.read_cont.populate_cdss(self.record_cont)
        #       Initialize and fill up cds aln container
        self.cds_aln_cont = CdsAlnContainer()
        self.cds_aln_cont.populate(self.read_cont.fetch_all_reads())

        self.greedy_solver = GreedySolver()
        self.greedy_solver.map_reads_2_cdss(self.cds_aln_cont)

    def testAlignmentsCorrectlyInactivated(self):
        '''
        Loads correct results from results file and checks whether
        all the reads for a CDS listed in the file are active and
        whether all the other reads are inactive.
        '''
        cds2read = self._load_active_reads()

        for (cds, cds_aln) in self.cds_aln_cont.cds_repository.items():
            accession = cds.record_id
            mapped_reads = cds2read[accession]
            for cds_aln_subloc in cds_aln.aligned_regions.values():
                # unittest assertions keep working under "python -O" and
                # name the read id that violated the expectation.
                if cds_aln_subloc.active:
                    self.assertIn(cds_aln_subloc.read_id, mapped_reads)
                else:
                    self.assertNotIn(cds_aln_subloc.read_id, mapped_reads)

    def testCdsAlignmentContainerConsistency(self):
        '''Delegates to the consistency check provided by Read2CDSSolver.'''
        self.assertTrue(
            Read2CDSSolver.test_cds_alignment_container_consistency(
                self.cds_aln_cont))

    def _load_active_reads(self):
        '''
        Parses the expected results file. The file alternates between a
        CDS id line and a line holding that CDS's read ids separated
        by ';'.

        @return dict: CDS id (str) -> list of read ids (str)
        @raise ValueError if a CDS id line is not followed by a read id line
        '''
        cds2read_map = {}
        # The context manager guarantees the handle is closed on error too
        with open(self.results_fpath) as results_fhandle:
            lines = iter(results_fhandle.readlines())
            for cds_id in lines:
                read_ids = next(lines, None)
                if read_ids is None:
                    raise ValueError(
                        "Missing read id line for CDS '%s' in %s"
                        % (cds_id.strip(), self.results_fpath))
                cds2read_map[cds_id.strip()] = read_ids.strip().split(';')
        return cds2read_map
Exemple #13
0
    # Populate read container
    # The read container type can now be determined from the input parameters
    # and injected into the Solver
    start = timing.start()
    read_container = ReadContainer()
    read_container.populate_from_aln_file(read_alignment_file=args.input)
    elapsed_time = timing.end(start)
    log.info("Populate read container - elapsed time: %s", 
             timing.humanize(elapsed_time))    
    
    # Populate record container
    # The record container type can now be determined from the input parameters
    # and injected into the Solver
    start = timing.start()
    record_container = RecordContainer()
    record_container.set_db_access(db_access)
    # Extract all records from database
    record_container.populate(read_container.fetch_all_reads_versions())
    elapsed_time = timing.end(start)
    log.info("Populate record container - elapsed time: %s", 
             timing.humanize(elapsed_time)) 
   
    solver.generateSolutionXML(read_container=read_container,
                               record_container=record_container,
                               dataset_xml_file=args.descr,
                               output_solution_filename=args.output,
                               stats_dir=args.stats_dir,
                               solution_file=args.solution_file)
    
    processing_delta = timing.end(processing_start)
Exemple #14
0
 def setUp(self):
     '''Instantiates a new ReadContainer and RecordContainer and stores
     them on the test instance.'''
     self.read_container = ReadContainer()
     self.record_container = RecordContainer()
Exemple #15
0
    # Populate read container
    # The read container type can now be determined from the input parameters
    # and injected into the Solver
    start = timing.start()
    read_container = ReadContainer()
    read_container.populate_from_aln_file(read_alignment_file=args.input)
    elapsed_time = timing.end(start)
    log.info("Populate read container - elapsed time: %s",
             timing.humanize(elapsed_time))

    # Populate record container
    # The record container type can now be determined from the input parameters
    # and injected into the Solver
    start = timing.start()
    record_container = RecordContainer()
    record_container.set_db_access(db_access)
    # Extract all records from database
    record_container.populate(read_container.fetch_all_reads_versions())
    elapsed_time = timing.end(start)
    log.info("Populate record container - elapsed time: %s",
             timing.humanize(elapsed_time))

    solver.generateSolutionXML(read_container=read_container,
                               record_container=record_container,
                               dataset_xml_file=args.descr,
                               output_solution_filename=args.output,
                               stats_dir=args.stats_dir,
                               solution_file=args.solution_file)

    processing_delta = timing.end(processing_start)
Exemple #16
0
def main():
    '''
    Script to run binner in one of the most common
    usage scenarios.
    * load alignment data
    * load taxonomy data
    * do basic alignment data filtering (remove host reads etc.)
    * map alignments to genes, annotate and bin the reads
    * write the result as XML

    Progress messages are written to stdout with the parenthesised,
    single-argument form of print, which produces the same output on
    Python 2 and is valid syntax on Python 3.
    '''

    #----------------------------------#
    #------ INPUT ARGUMENTS -----------#
    argparser = TestRunArgParser()
    args = argparser.parse_args()

    #----------------------------------#
    #------- STATIC DATA SOURCE -------#
    # CDS - GI2TAXID -- NAMES -- NODES #
    dataAccess = DataAccess(args)
    #raw_input('Data access created')
    #----------------------------------#

    #-------- TAXONOMY TREE -----------#
    print('1. Loading tax tree...')
    tax_tree = TaxTree()
    # tax_tree.load_taxonomy_data(dataAccess)
    print('done.')

    #----------------------------------#
    #------- ALIGNMENT DATA SOURCE ----#
    print('2. Loading alignment file...')
    read_container = ReadContainer()
    read_container.load_alignment_data(args.input)
    #---SET TAXIDS FOR ALL ALIGNMENTS--#
    read_container.set_taxids(dataAccess)
    # Remember total number of reads
    total_read_num = read_container.get_read_count()
    print('done')

    #------- FILTER HOST READS -------#
    print('3. Filtering host reads & alignments...')
    new_reads = host_filter.filter_potential_host_reads(
        read_container.fetch_all_reads(format=list),
        tax_tree.tax2relevantTax,
        tax_tree.potential_hosts,
        #delete_host_alignments =
        True,
        #filter_unassigned =
        True,
        #unassigned_taxid=
        -1,
        host_filter.perc_of_host_alignments_larger_than)

    dataAccess.clear_cache()    # deletes gi2taxid cache

    reads_with_no_host_alignments = host_filter.filter_potential_hosts_alignments(
        new_reads,
        tax_tree.tax2relevantTax,
        tax_tree.potential_hosts,
        True,   # delete host alignments
        True,   # filter unassigned
        -1)     # unassigned taxid

    # Host reads = reads present before filtering minus the reads kept.
    # Must be computed before set_new_reads replaces the container content.
    host_read_count = len(read_container.fetch_all_reads(format=list)) - len(reads_with_no_host_alignments)
    read_container.set_new_reads(reads_with_no_host_alignments)
    print('done')

    #----------------------------------#
    #------- LOAD ALL RECORDS   -------#
    print('4. Loading referenced records...')
    record_container = RecordContainer()
    record_container.set_db_access(dataAccess)
    record_container.populate(read_container.fetch_all_reads_versions(), table='cds')
    print('done')
    #----------------------------------#
    #-- MAP ALIGNMENTS TO GENES   -----#
    print('5. Mapping alignments to genes...')
    read_container.populate_cdss(record_container)
    #----------------------------------#
    #- RECORD ALL ALIGNEMENTS TO GENE -#
    cds_aln_container = CdsAlnContainer()
    cds_aln_container.populate(read_container.fetch_all_reads(format=list))
    print('done')

    print('6. Estimating organisms present in sample...')
    # NOTE(review): hard-coded list of taxids; its origin is not documented
    # here (the original author's comment was "What is this part?").
    target_organisms = [633, 632, 263, 543, 86661, 1392, 55080, 1386]
    print('done.')

    print('7. Annotating reads...')
    annotated_reads = rstate.annotate_reads(
        read_container.fetch_all_reads(format=list),
        cds_aln_container.read2cds,
        tax_tree,
        target_organisms)
    read_container.set_new_reads(annotated_reads)
    print('done')

    print('8. Binning reads...')
    orgs = bin_reads(
        read_container.fetch_all_reads(format=list),
        cds_aln_container.cds_repository,
        cds_aln_container.read2cds,
        tax_tree,
        target_organisms,
        None,
        None,
        False)

    # Disabled debug output (previously kept in a no-op string literal):
    # for org in orgs.values():
    #     print(org.name)
    #     print(len(set(org.get_reads())))
    #     print(len(org.identified_coding_regions))
    # print('done.')

    print("total_read_num: " + str(total_read_num))

    print('9. Generating XML...')
    dataset = Dataset(args.xml_description_file)
    xml_organisms = []
    # An input without reads would otherwise raise ZeroDivisionError here
    if total_read_num:
        host_fraction = host_read_count / float(total_read_num)
    else:
        host_fraction = 0.0
    host = Organism(host_read_count, host_fraction, None, None, "Host",
                    None, None, [], [], [], is_host=True)
    xml_organisms.append(host)
    for org in orgs.values():
        xml_organisms.append(org.to_xml_organism(tax_tree, total_read_num))
    # Largest organisms (by amount_count) come first in the output
    xml_organisms.sort(key=operator.attrgetter("amount_count"), reverse=True)
    xml = XMLOutput(dataset, xml_organisms, args.output)
    xml.xml_output()
 def setUp(self):
     '''Instantiates a new ReadContainer and RecordContainer and stores
     them on the test instance.'''
     self.read_container = ReadContainer()
     self.record_container = RecordContainer()
def main():
    '''
    Script to analyse genes expressed for given tax id.

    Loads the alignment file, fetches the referenced CDS records, maps
    alignments to genes and then selects the CDS alignments that belong
    to args.tax_id, pass the length/coverage thresholds and carry both
    a gene and a product annotation. Only counts are printed.

    Progress messages are written to stdout with the parenthesised,
    single-argument form of print, which produces the same output on
    Python 2 and is valid syntax on Python 3.
    '''

    # Input arguments
    argparser = ArgParser()
    args = argparser.parse_args()

    # Access database
    dataAccess = DataAccess(args)

    print('1. Loading tax tree...')
    tax_tree = TaxTree()
    print('done.')

    print('2. Loading alignment file...')
    read_container = ReadContainer()
    read_container.load_alignment_data(args.alignment_file)
    #---SET TAXIDS FOR ALL ALIGNMENTS--#
    read_container.set_taxids(dataAccess)
    print('done')

    # TODO: Here i should recognize host reads!
    # (step 3 is intentionally missing from the numbered output below)

    # ------------------------------------- #

    #----------------------------------#
    #------- LOAD ALL RECORDS   -------#
    print('4. Loading referenced records...')
    record_container = RecordContainer()
    record_container.set_db_access(dataAccess)
    record_container.populate(read_container.fetch_all_reads_versions(), table='cds')
    print('done')
    #----------------------------------#
    #-- MAP ALIGNMENTS TO GENES   -----#
    print('5. Mapping alignments to genes...')
    read_container.populate_cdss(record_container)
    #----------------------------------#
    #- RECORD ALL ALIGNEMENTS TO GENE -#
    print('6. Populating CDS container...')
    cds_aln_container = CdsAlnContainer()
    cds_aln_container.populate(read_container.fetch_all_reads(format=list))
    print('done')

    print("Loaded CDS container")

    # Take only CDSs of given tax_id
    # Remove CDSs with too low mean coverage value
    min_mean_coverage = 10
    min_length = 20

    cds_alns = cds_aln_container.fetch_all_cds_alns(format=list)
    print("All CDSs (all organisms): " + str(len(cds_alns)))

    cds_alns_targeted = [cds_aln for cds_aln in cds_alns
                         if cds_aln.get_tax_id() == args.tax_id
                         # Filters
                         and cds_aln.get_cds_length() > min_length
                         and cds_aln.get_mean_coverage() > min_mean_coverage]

    # Remove CDSs with no gene/product
    # ("is not None" is the identity test; "!= None" can be hijacked by a
    # custom __ne__ implementation)
    cds_alns_targeted = [cds_aln for cds_aln in cds_alns_targeted
                         if cds_aln.cds.gene is not None
                         and cds_aln.cds.product is not None]

    # ------------------- CDSs filtered and ready to be analyzed ------------------- #

    # Number of targeted CDSs
    print("Targeted CDSs: " + str(len(cds_alns_targeted)))

    # TODO: three disabled code blocks (kept as no-op string literals) were
    # removed from here; restore them from version control when needed:
    #   * sorting targeted CDSs by stddev/mean and writing
    #     "gene protein_id" lines to args.export_path
    #   * computing and printing the mean CDS length
    #   * exporting coverage of the 50 best CDSs to cds_<i>.txt files
    # Open question from the removed code: how to decide which CDSs are
    # "expressed" and which are not.

    # -------------------- #

    # Analyse stuff
    # print("Analysing stuff!")

    '''
def main():
    '''
    Run the ribosomal CDS analysis pipeline from the command line arguments.

    Steps, in order:
      1. Parse arguments, create the database access object and the tax tree.
      2. Load the alignment file and set taxids on the loaded alignments.
      3. If args.remove_host is set, filter host reads and write the
         total / host / non-host read counts to the summary file.
      4. Load the referenced records (table 'cds'), map alignments to CDSs
         and populate the CDS alignment container.
      5. Keep CDSs with length > 0, mean coverage > 0 and a non-None product.
      6. Keep the CDSs accepted by is_ribosomal(product), then only those
         whose mean coverage is above the mean over all ribosomal CDSs.
      7. Count the selected CDSs per species-level tax id and run
         assignment_analysis on all reads.

    Files are written into args.export_folder (created if missing):
    "CDSs_summary.txt", "1_all_valid_CDSs.txt", "2_ribosomal_CDSs.txt" and
    "3_ribosomal_CDSs_filtered.txt". If args.export_charts is set, chart
    coverage data is exported to that location as well.

    Progress messages are printed to standard output. All print calls use the
    single-argument parenthesised form so the output is the same under the
    Python 2 print statement and the Python 3 print function.
    '''

    def report_elapsed(start):
        # Print the wall-clock time elapsed since `start` (a time.time() value)
        print("done: {0:.2f} sec".format(time.time() - start))

    def fraction(part, whole):
        # part / whole as a float; 0.0 when there are no reads at all, so an
        # empty alignment file does not abort the run with ZeroDivisionError
        return part / float(whole) if whole else 0.0

    # Input arguments
    args = ArgParser().parse_args()

    # Access database
    dataAccess = DataAccess(args)

    # ------------------ #

    print('1. Loading tax tree...')
    start = time.time()

    tax_tree = TaxTree()

    report_elapsed(start)

    # ------------------ #

    print('2. Loading alignment file...')
    start = time.time()

    read_container = ReadContainer()
    read_container.load_alignment_data(args.alignment_file)
    #---SET TAXIDS FOR ALL ALIGNMENTS--#
    read_container.set_taxids(dataAccess)

    report_elapsed(start)

    # ------------------ #

    # Create folder if does not exist
    if not os.path.exists(args.export_folder):
        os.makedirs(args.export_folder)
    # File for data analysis summary
    summary_path = os.path.join(args.export_folder, "CDSs_summary.txt")

    # The context manager guarantees the summary file is flushed and closed
    # even if one of the pipeline steps below raises.
    with open(summary_path, 'w') as cds_summary:

        if args.remove_host:
            print("Removing host...")
            start = time.time()

            #------- FILTER HOST READS -------#
            new_reads = host_filter.filter_potential_host_reads(
                read_container.fetch_all_reads(format=list),
                tax_tree.tax2relevantTax,
                tax_tree.potential_hosts,
                True,   # delete host alignments
                True,   # filter unassigned
                -1,     # unassigned taxid
                host_filter.perc_of_host_alignments_larger_than)

            dataAccess.clear_cache()    # deletes gi2taxid cache

            reads_with_no_host_alignments = host_filter.filter_potential_hosts_alignments(
                new_reads,
                tax_tree.tax2relevantTax,
                tax_tree.potential_hosts,
                True,   # delete host alignments
                True,   # filter unassigned
                -1)     # unassigned taxid

            read_count          = len(read_container.fetch_all_reads(format=list))
            host_read_count     = read_count - len(reads_with_no_host_alignments)
            non_host_read_count = read_count - host_read_count

            cds_summary.write("total   : {0:8d}\n".format(read_count))
            cds_summary.write("host    : {0:8d} {1:.2f}\n".format(
                host_read_count, fraction(host_read_count, read_count)))
            cds_summary.write("non-host: {0:8d} {1:.2f}\n".format(
                non_host_read_count, fraction(non_host_read_count, read_count)))
            # Set host-free reads
            read_container.set_new_reads(reads_with_no_host_alignments)

            report_elapsed(start)

        #------- LOAD ALL RECORDS   -------#

        print('4. Loading referenced records...')
        start = time.time()

        record_container = RecordContainer()
        record_container.set_db_access(dataAccess)
        record_container.populate(read_container.fetch_all_reads_versions(), table='cds')

        report_elapsed(start)

        #-- MAP ALIGNMENTS TO GENES   -----#

        print('5. Mapping alignments to genes...')
        start = time.time()

        read_container.populate_cdss(record_container)

        report_elapsed(start)

        #- RECORD ALL ALIGNMENTS TO GENE -#

        print('6. Populating CDS container...')
        start = time.time()

        cds_aln_container = CdsAlnContainer()
        cds_aln_container.populate(read_container.fetch_all_reads(format=list))

        report_elapsed(start)

        # ------------------------------- #

        # NOTE: sorting by get_std_over_mean() is disabled; the CDSs are only
        # fetched here, in whatever order the container returns them.
        print('Sorting CDSs ...DISABLED')
        start = time.time()

        cds_alns = cds_aln_container.fetch_all_cds_alns(format=list)

        report_elapsed(start)

        # ------------------------------- #

        # Count Nones in cds_alns
        nones = count_nones(cds_alns)
        cds_summary.write("\n")
        cds_summary.write("gene None       : {0}\n".format(nones['gene']))
        cds_summary.write("protein_id  None: {0}\n".format(nones['protein_id']))
        cds_summary.write("product  None   : {0}\n".format(nones['product']))
        cds_summary.write("\n")

        cds_summary.write("CDSs all: {0}\n".format(len(cds_alns)))

        print('Filtering valid CDSs...')
        start = time.time()

        # Remove CDSs with too low mean coverage value or length, and CDSs
        # with no product (a missing gene name is tolerated).
        min_mean_coverage   = 0
        min_length          = 0
        cds_alns_targeted = [cds_aln for cds_aln in cds_alns
                             if cds_aln.get_cds_length() > min_length
                             and cds_aln.get_mean_coverage() > min_mean_coverage
                             and cds_aln.cds.product is not None]

        report_elapsed(start)

        # All valid CDSs - Output coverage/length histogram data
        print("Exporting phase 1 - all CDSs...")
        start = time.time()

        export_CDS_stats_data(cds_alns_targeted, args.export_folder, "1_all_valid_CDSs.txt")

        report_elapsed(start)

        # ----------- CDSs filtered and ready to be analyzed ----------- #

        print('Extracting ribosomal CDSs...')
        # Number of targeted CDSs
        cds_summary.write("CDSs valid: {0}\n".format(len(cds_alns_targeted)))

        cds_alns_ribosomal = [cds_aln for cds_aln in cds_alns_targeted
                              if is_ribosomal(cds_aln.cds.product)]

        print('done')
        # ------------------- Ribosomal CDSs acquired! ----------------- #

        print('Analysing ribosomals...')

        # Mean and maximum of the per-CDS mean coverage over ribosomal CDSs
        mm_cov  = 0
        max_cov = 0
        for cds_aln in cds_alns_ribosomal:
            mean_cov = cds_aln.get_mean_coverage()
            mm_cov += mean_cov
            max_cov = max(max_cov, mean_cov)
        # A positive sum implies a non-empty list, so the division is safe
        if mm_cov > 0:
            mm_cov /= len(cds_alns_ribosomal)

        cds_summary.write("ribosomals all {0}\n".format(len(cds_alns_ribosomal)))
        cds_summary.write("mean coverage: {0}\n".format(mm_cov))
        cds_summary.write("max coverage : {0}\n".format(max_cov))
        print('done')

        # Ribosomal CDSs only - Output coverage/length histogram
        print("Exporting phase 2 - ribosomal CDSs only...")
        export_CDS_stats_data(cds_alns_ribosomal, args.export_folder, "2_ribosomal_CDSs.txt")
        print("done")

        # ---------- Making biological sense - choosing CDSs ----------- #

        print('Filtering under-average ribosomals...')
        # NOTE: take length into consideration?
        cds_alns_ribosomal = [cds_aln for cds_aln in cds_alns_ribosomal
                              if cds_aln.get_mean_coverage() > mm_cov]
        print('done')
        cds_summary.write("ribosomals over-mean: {0}\n".format(len(cds_alns_ribosomal)))

    print('Phase 3 - filtered ribosomal CDSs...')
    export_CDS_stats_data(cds_alns_ribosomal, args.export_folder, "3_ribosomal_CDSs_filtered.txt")
    print('done')

    # Store charts cov data - if selected so
    if args.export_charts:
        print("Exporting chart coverage data...")
        export_CDS_graph_data(cds_alns_ribosomal, args.export_charts)
        print("done.")

    # ------ I have chosen CDSs - determine species and analyse ------ #

    # Species level resolution
    # See which species are present - dump ones with not enough CDSs
    # NOTE: So far done in determine_species_by_ribosomals.py

    CDS_count = {}    # Count CDSs of each species
    for cds_aln in cds_alns_ribosomal:
        # Put each tax_id up to the "species" level
        tax_id_species = tax_tree.get_parent_with_rank(cds_aln.cds.taxon, 'species')
        CDS_count[tax_id_species] = CDS_count.get(tax_id_species, 0) + 1

    # Estimated tax_ids are exactly the species that received a CDS
    species_set = set(CDS_count)

    # ------------ Read assignment analysis -------------- #

    print("Read assignment analysis...")

    reads = read_container.fetch_all_reads(format=list)
    assignment_analysis(species_set, reads, tax_tree, args.export_folder, CDS_count)
def main():
    '''
    Script to identify analyse ribosomal genes expressed.
    '''

    # Input arguments
    argparser   = ArgParser()
    args        = argparser.parse_args()

    # Access database
    dataAccess = DataAccess(args)

    #print '1. Loading tax tree...'
    tax_tree = TaxTree()
    #print 'done.'

    #print '2. Loading alignment file...'
    read_container = ReadContainer()
    read_container.load_alignment_data(args.alignment_file)
    #---SET TAXIDS FOR ALL ALIGNMENTS--#
    read_container.set_taxids(dataAccess)
    #print 'done'

    '''
    # TODO: Here i should recognize host reads!
    #------- FILTER HOST READS -------#
    #print '3. Filtering host reads & alignments...'
    new_reads = host_filter.filter_potential_host_reads(
        read_container.fetch_all_reads(format=list),
        tax_tree.tax2relevantTax,
        tax_tree.potential_hosts,
        #delete_host_alignments =
        True,
        #filter_unassigned =
        True,
        #unassigned_taxid=
        -1,
        host_filter.perc_of_host_alignments_larger_than)

    dataAccess.clear_cache()    # deletes gi2taxid cache

    reads_with_no_host_alignments = host_filter.filter_potential_hosts_alignments(
        new_reads,
        tax_tree.tax2relevantTax,
        tax_tree.potential_hosts,
        True,   # delete host alignments
        True,   # filter unassigned
        -1)     # unassigned taxid

    read_count          = len(read_container.fetch_all_reads(format=list))
    host_read_count     = read_count - len(reads_with_no_host_alignments)
    non_host_read_count = read_count - host_read_count
    
    print ("total   : {0:8d}".format(read_count))
    print ("host    : {0:8d} {1:.2f}".format(host_read_count, 
                                  host_read_count / float(read_count)
                                  ))
    print ("non-host: {0:8d} {1:.2f}".format(non_host_read_count, 
                                  non_host_read_count / float(read_count)
                                  ))
    print
    read_container.set_new_reads(reads_with_no_host_alignments)
    '''

    # ------------------------------------- #

    #----------------------------------#
    #------- LOAD ALL RECORDS   -------#
    #print '4. Loading referenced records...'
    record_container = RecordContainer()
    record_container.set_db_access(dataAccess)
    record_container.populate(read_container.fetch_all_reads_versions(), table='cds')
    #print 'done'
    #----------------------------------#
    #-- MAP ALIGNMENTS TO GENES   -----#
    #print '5. Mapping alignments to genes...'
    read_container.populate_cdss(record_container)
    #----------------------------------#
    #- RECORD ALL ALIGNEMENTS TO GENE -#
    #print '6. Populating CDS container...'
    cds_aln_container = CdsAlnContainer()
    cds_aln_container.populate(read_container.fetch_all_reads(format=list))
    #print 'done'

    #print("Loaded CDS container")
    
    # Take only CDSs of given tax_id
    # Remove CDSs with too low mean coverage value
    min_mean_coverage   = 0
    min_length          = 0

    cds_alns = cds_aln_container.fetch_all_cds_alns(format=list)
    print ( "CDSs all  : " + str(len(cds_alns)) )

    cds_alns_targeted = [cds_aln for cds_aln in cds_alns 
                         # Filters
                         if cds_aln.get_cds_length() > min_length
                         and cds_aln.get_mean_coverage() > min_mean_coverage]

    # Remove CDSs with no gene/product
    cds_alns_targeted = [cds_aln for cds_aln in cds_alns_targeted
                         if  cds_aln.cds.gene != None
                         and cds_aln.cds.product != None]

    # ------------------- CDSs filtered and ready to be analyzed ------------------- #

    # Number of targeted CDSs
    print ( "CDSs valid: " + str(len(cds_alns_targeted)) )
    print

    cds_alns_ribosomal = []
    for cds_aln in cds_alns_targeted:

        # If has word "ribosomal" in name, store coverage data for graph
        gene        = cds_aln.cds.gene
        product     = cds_aln.cds.product
        protein_id  = cds_aln.cds.protein_id

        if "ribosomal" in product:
            #print("{0} {1} {2}\n".format(gene, protein_id, product))
            cds_alns_ribosomal.append(cds_aln)

    # ------------------- Ribosomal CDSs acquired! --------------------- #
    # Sort it!
    cds_alns_ribosomal = sorted(cds_alns_ribosomal, 
                             key=lambda cds_aln: cds_aln.get_std_over_mean(),
                             reverse=False)

    # Extract interesting data
    # Mean coverage, max coverage
    mm_cov  = 0
    max_cov = 0
    for cds_aln in cds_alns_ribosomal:
        mean_cov = cds_aln.get_mean_coverage()
        mm_cov += mean_cov
        max_cov = max(max_cov, mean_cov)
    if mm_cov > 0:
        mm_cov /= len(cds_alns_ribosomal)

    # Print
    print("ribosomals: " + str(len(cds_alns_ribosomal)))
    print("mean coverage: " + str(mm_cov))
    print("max coverage : {0}".format(max_cov))
    print
    for cds_aln in cds_alns_ribosomal:
        gene        = cds_aln.cds.gene
        product     = cds_aln.cds.product
        protein_id  = cds_aln.cds.protein_id

        taxon       = cds_aln.cds.taxon
        name        = tax_tree.nodes[taxon].organism_name
        print("{0:4} {1:10} {2:50} {3:10d} {4:60}".format(gene, protein_id, product, taxon, name))

    # Store graph data
    export_CDS_graph_data(cds_alns_ribosomal, args.export_path)

    '''