Exemplo n.º 1
0
class TestIntervalTrack(unittest.TestCase):
    """Tests for IntervalTrack: intersect/before/after queries (each checked
    both before and after building the interval index), index persistence,
    multiple references and adding raw numpy rows.
    """

    def setUp(self):
        # NOTE(review): tempfile.mktemp only returns an unused name and does
        # not create the file; TrackFactory presumably creates it in 'w' mode.
        self.filename = mktemp(prefix="tmp", suffix=".h5")
        self.refs = (('gene1', 1200), ('gene2', 200))
        self.tf = TrackFactory(self.filename, 'w', refs=self.refs)
        # base interval fields plus an integer 'id' used to identify hits
        self.interval_dtype = get_base_dtype_fields() + [('id', 'u4')]
        self.t = self.tf.create_track('intervals1', IntervalTrack,
                                      dtype=self.interval_dtype)

    def tearDown(self):
        # self.tf may have been replaced by a reopened factory in a test
        self.tf.close()
        if os.path.exists(self.filename):
            os.remove(self.filename)

    def test_save_load_index(self):
        """Create a table, add intervals, save the index, and ensure the
        index persists when reloading the interval table.
        """
        mytrack = self.t
        for i in xrange(100):
            start = np.random.randint(0, 1150)
            end = start + np.random.randint(1, 50)
            mytrack.add(Interval('gene1', start, end, i))
        # create and persist the index
        mytrack.index(persist=True)
        # snapshot the index table and root before closing the file
        tree = mytrack.indexes['gene1'].tree
        tblcopy = tree.tbl.copy()
        tblroot = tree.root_id
        # close and reopen the file read-only
        self.tf.close()
        self.tf = TrackFactory(self.filename, 'r', refs=self.refs)
        reloaded = self.tf.get_track('intervals1').indexes['gene1'].tree
        # the reloaded index must match the saved snapshot
        self.assertTrue(np.all(tblcopy == reloaded.tbl))
        self.assertEqual(tblroot, reloaded.root_id)

    def _nonoverlapping_checks(self, mytrack, intervals):
        """Assert that each interval intersects only itself."""
        for i, interval in enumerate(intervals):
            hits = mytrack.intersect(interval.ref, interval.start, interval.end)
            self.assertEqual([x['id'] for x in hits], [i])

    def testNonOverlapping(self):
        mytrack = self.t
        # insert non-overlapping intervals: length 10, starts 50 apart
        intervals = []
        for i, start in enumerate(xrange(0, 900, 50)):
            interval = Interval('gene1', start, start + 10, i)
            intervals.append(interval)
            mytrack.add(interval)
        # check intersections without and then with the interval tree index
        self._nonoverlapping_checks(mytrack, intervals)
        mytrack.index()
        self._nonoverlapping_checks(mytrack, intervals)

    def _overlapping_checks(self, mytrack, intervals):
        """Assert that each interval intersects itself and its immediate
        neighbours only (the first and last have a single neighbour)."""
        last = len(intervals) - 1
        for i, interval in enumerate(intervals):
            hits = mytrack.intersect(interval.ref, interval.start, interval.end)
            expected = list(range(max(i - 1, 0), min(i + 1, last) + 1))
            self.assertEqual([x['id'] for x in hits], expected)

    def testOverlappingIntervals(self):
        mytrack = self.t
        # insert overlapping intervals: length 60, starts 50 apart, so each
        # overlaps its neighbours by 10 bases
        intervals = []
        for i, start in enumerate(xrange(0, 900, 50)):
            interval = Interval('gene1', start, start + 60, i)
            intervals.append(interval)
            mytrack.add(interval)
        # check intersections without and then with the interval tree index
        self._overlapping_checks(mytrack, intervals)
        mytrack.index()
        self._overlapping_checks(mytrack, intervals)

    def testBeforeAfter(self):
        mytrack = self.t
        # insert non-overlapping intervals: length 10, starts 50 apart
        intervals = []
        for i, start in enumerate(xrange(0, 1000, 50)):
            interval = Interval('gene1', start, start + 10, i)
            intervals.append(interval)
            mytrack.add(interval)
        mytrack.index()
        count = len(intervals)
        for i, interval in enumerate(intervals):
            # all earlier intervals, nearest first
            res = mytrack.before('gene1', interval.start,
                                 num_intervals=count, max_dist=1200)
            self.assertEqual([x['id'] for x in res], list(range(i - 1, -1, -1)))
            # all later intervals, nearest first
            res = mytrack.after('gene1', interval.end,
                                num_intervals=count, max_dist=1200)
            self.assertEqual([x['id'] for x in res], list(range(i + 1, count)))

    def _boundary_checks(self, mytrack):
        """Check before()/after() hit counts around the single interval
        gene1:100-200 added by testBeforeAfterBoundaries.

        Each case is (method name, position, num_intervals, max_dist,
        expected number of hits).
        """
        cases = [
            # left boundary
            ('before', 199, 1, 2000, 0),
            ('before', 200, 1, 2000, 1),
            ('before', 201, 1, 2000, 1),
            ('before', 202, 1, 2000, 1),
            # right boundary
            ('after', 101, 1, 2000, 0),
            ('after', 100, 1, 2000, 0),
            ('after', 99, 1, 2000, 1),
            ('after', 98, 1, 2000, 1),
            # left max dist
            ('before', 200, 1, 1, 1),
            ('before', 201, 1, 1, 0),
            ('before', 202, 1, 1, 0),
            ('before', 300, 1, 102, 1),
            ('before', 300, 1, 101, 1),
            ('before', 300, 1, 100, 0),
            ('before', 300, 1, 99, 0),
            # right max dist
            ('after', 99, 1, 1, 1),
            ('after', 98, 1, 1, 0),
            ('after', 98, 1, 2, 1),
            ('after', 0, 1, 100, 1),
            ('after', 0, 1, 99, 0),
        ]
        for method, pos, num, max_dist, expected in cases:
            hits = getattr(mytrack, method)('gene1', pos, num, max_dist)
            self.assertEqual(len(hits), expected,
                             "%s('gene1', %d, %d, %d)" % (method, pos, num, max_dist))

    def testBeforeAfterBoundaries(self):
        mytrack = self.t
        # a single interval; check boundaries without and with the index
        mytrack.add(Interval('gene1', 100, 200, 0))
        self._boundary_checks(mytrack)
        mytrack.index()
        self._boundary_checks(mytrack)

    def _intersect_checks(self, mytrack):
        """Check intersect() hit counts against the single interval
        gene1:100-200: queries ending at 100 or starting at 200 must miss."""
        cases = [(0, 100, 0), (0, 101, 1), (200, 210, 0), (199, 210, 1)]
        for start, end, expected in cases:
            hits = mytrack.intersect('gene1', start, end)
            self.assertEqual(len(hits), expected,
                             "intersect('gene1', %d, %d)" % (start, end))

    def testIntersectBoundaries(self):
        mytrack = self.t
        # a single interval; check boundaries without and with the index
        mytrack.add(Interval('gene1', 100, 200, 0))
        self._intersect_checks(mytrack)
        mytrack.index()
        self._intersect_checks(mytrack)

    def testMultipleReferences(self):
        mytrack = self.t
        # alternate references: even ids on gene1, odd ids on gene2
        for i in xrange(0, 10):
            ref = ('gene1', 'gene2')[i % 2]
            mytrack.add(Interval(ref, i, i + 10, i))
        ids = [r['id'] for r in mytrack.intersect('gene1', 0, 20)]
        self.assertEqual(set(ids), set(range(0, 10, 2)))
        ids = [r['id'] for r in mytrack.intersect('gene2', 0, 20)]
        self.assertEqual(set(ids), set(range(1, 10, 2)))

    def testAdd(self):
        mytrack = self.t
        # add raw numpy records (rather than Interval objects), reusing one row
        myrow = np.zeros(1, dtype=self.interval_dtype)[0]
        intervals = []
        for i, start in enumerate(xrange(0, 1000, 50)):
            myrow['ref'] = 'gene1'
            myrow['start'] = start
            myrow['end'] = start + 10
            myrow['id'] = i
            interval = Interval('gene1', start, start + 10, i)
            mytrack.add(myrow)
            intervals.append(interval)
        # query by indexing the track with a (ref, start, end) key
        for interval in intervals:
            res = mytrack[(interval.ref, interval.start, interval.end)]
            self.assertEqual(len(res), 1)
            hit = res[0]
            self.assertEqual(
                (interval.ref, interval.start, interval.end, interval.id),
                (hit['ref'], hit['start'], hit['end'], hit['id']))
Exemplo n.º 2
0
class TestSequenceTrack(unittest.TestCase):
    """Tests for SequenceTrack storage at 2, 3 and 4 bits per base (bpb)."""

    def setUp(self):
        # NOTE(review): tempfile.mktemp only returns an unused name and does
        # not create the file; TrackFactory presumably creates it in 'w' mode.
        self.filename = mktemp(prefix="tmp", suffix=".h5")
        self.refs = (('chr1', 100), ('chr2', 200))
        self.tf = TrackFactory(self.filename, 'w', refs=self.refs)

    def tearDown(self):
        self.tf.close()
        if os.path.exists(self.filename):
            os.remove(self.filename)

    def test_bpb(self):
        for bpb in (2, 3, 4):
            t = self.tf.create_track('t%d' % bpb, SequenceTrack, bpb=bpb)
            # write a cycle of bases one position at a time, then read back
            pattern = 'AATTTGCCTGC'
            dna_iter = itertools.cycle(pattern)
            for i in xrange(100):
                t[('chr1', i)] = next(dna_iter)
            dna_iter = itertools.cycle(pattern)
            for i in xrange(100):
                self.assertEqual(t[('chr1', i)], next(dna_iter))
            # write 13-base chunks back to back over chr2 positions 0-194,
            # then read them back
            chunk = 'CACATGTAGAGCT'
            for i in xrange(0, 195, 13):
                t[('chr2', i, i + 13)] = chunk
            for i in xrange(0, 195, 13):
                self.assertEqual(t[('chr2', i, i + 13)], chunk)
            # overwrite a range of bases without affecting their neighbours
            t[('chr2', 99, 102)] = 'TTT'
            self.assertEqual(t[('chr2', 94, 107)], 'ATGTATTTCTCAC')

    def test_bases(self):
        t2 = self.tf.create_track("t2", SequenceTrack, bpb=2)
        t3 = self.tf.create_track("t3", SequenceTrack, bpb=3)
        t4 = self.tf.create_track("t4", SequenceTrack, bpb=4)
        # default (unwritten) bases: 'A' at 2 bpb, 'N' at 3 and 4 bpb
        self.assertEqual(t2[('chr1', 35, 45)], 'AAAAAAAAAA')
        self.assertEqual(t3[('chr1', 35, 45)], 'NNNNNNNNNN')
        self.assertEqual(t4[('chr1', 35, 45)], 'NNNNNNNNNN')
        # store lower case bases and 'N's
        seq = 'NattgcgcNN'
        for track in (t2, t3, t4):
            track[('chr1', 35, 45)] = seq
        # 2 bpb upper-cases and maps 'N' to 'A'; 3 bpb upper-cases;
        # 4 bpb round-trips the sequence exactly
        self.assertEqual(t2[('chr1', 35, 45)], seq.upper().replace('N', 'A'))
        self.assertEqual(t3[('chr1', 35, 45)], seq.upper())
        self.assertEqual(t4[('chr1', 35, 45)], seq)

    def test_fromfasta(self):
        # read/write to chromosome; every sequence line is 10 bases wide
        # (a missing comma previously fused two lines into one 20-base line)
        seqlist = [">chr2",
                   "ATGCAGTGAC",
                   "GTGACGAGAG",
                   "TGTAGAGAGA",
                   "GTGATGTATG",
                   "GGGGGGGGGG",
                   "GGGCCCCCGC",
                   "CCCCCCCCCC",
                   "CCCCCCCCCC",
                   "ATGAAAAAAG",
                   "TTGCCAAAAA",
                   "AAAAAAAAAA",
                   "AAAAAAAAAA",
                   "AAAGTGATTT",
                   "TTTTTTGCCG",
                   "AGCGCGCGCG"]
        fullseq = ''.join(seqlist[1:])
        t2 = self.tf.create_track("t2", SequenceTrack, bpb=2)
        t2.fromfasta(iter(seqlist))
        self.assertEqual(fullseq, t2[('chr2', 0, len(fullseq))])
        # a ">ref:start" header loads the sequence at an offset
        for startpos in (0, 10, 20, 30, 40):
            seqlist[0] = ">chr2:%d" % startpos
            t2.fromfasta(iter(seqlist))
            self.assertEqual(fullseq,
                             t2[('chr2', startpos, startpos + len(fullseq))])