コード例 #1
0
def test_span():
    """Check span() on an empty tree and on the 'ivs1' fixture tree."""
    empty = IntervalTree()
    assert empty.span() == 0

    tree = trees['ivs1']()
    # span must agree with the tree's own bounds ...
    assert tree.span() == tree.end() - tree.begin()
    # ... and with the value expected for this fixture.
    assert tree.span() == 14
コード例 #2
0
ファイル: query_test.py プロジェクト: chaimleib/intervaltree
def test_span():
    """Check span() on an empty tree and on a tree built from data.ivs1."""
    empty = IntervalTree()
    assert empty.span() == 0

    tree = IntervalTree.from_tuples(data.ivs1.data)
    # span must agree with the tree's own bounds ...
    assert tree.span() == tree.end() - tree.begin()
    # ... and with the value expected for this fixture.
    assert tree.span() == 14
コード例 #3
0
ファイル: alignStats.py プロジェクト: lazappi/binf-scripts
def load_gtf(gtf_path):
    """
    Load a GTF annotation and create an index using IntervalTrees.

    Lines starting with "#" and blank (whitespace-only) lines are skipped.
    Records with an empty feature column or with identical start and end
    columns are not indexed.

    Args:
        gtf_path: Path to the GTF file to load.

    Returns:
        Dictionary keyed by the first column of each record. Each value is
        an IntervalTree whose intervals run from int(start) to int(end) and
        carry the list [feature, start, end, strand, gene_id] (start and end
        kept as the strings read from the file).
    """

    # No default factory is given, so this behaves like a plain dict;
    # the type is kept so callers receive the same object type as before.
    gtf_index = defaultdict()
    with open(gtf_path) as gtf_file:
        for line in gtf_file:
            # Blank lines have no columns and would fail on entry[8] below.
            if line.startswith("#") or not line.strip():
                continue

            entry = line.split("\t")
            # gene_id is the second space-separated token of the first
            # ";"-separated attribute in column 9.
            gene_id = entry[8].split(";")[0].split(" ")[1]

            feature = entry[2]
            # TYPE (gene, exon etc.), START, END, STRAND, gene_ID
            info = [feature, entry[3], entry[4], entry[6], gene_id]

            # Build GTF index
            if feature != "" and entry[3] != entry[4]:
                if entry[0] not in gtf_index:
                    gtf_index[entry[0]] = IntervalTree()
                gtf_index[entry[0]].addi(int(info[1]), int(info[2]), info)

    return gtf_index
コード例 #4
0
def test_merge_equals_with_dupes():
    """merge_equals() without a reducer collapses equal-range intervals.

    After merging, the range is still present but none of the original
    data values are attached to it.
    """
    def assert_merged(tree, begin, end, *old_labels):
        assert tree.containsi(begin, end)
        for label in old_labels:
            assert not tree.containsi(begin, end, label)

    tree = IntervalTree.from_tuples(data.ivs1.data)
    reference = IntervalTree.from_tuples(data.ivs1.data)
    assert reference == tree

    # one dupe
    assert tree.containsi(4, 7, '[4,7)')
    tree.addi(4, 7, 'foo')
    assert len(tree) == len(reference) + 1
    assert reference != tree

    tree.merge_equals()
    tree.verify()
    assert tree != reference
    assert_merged(tree, 4, 7, 'foo', '[4,7)')

    # two dupes
    tree = IntervalTree.from_tuples(data.ivs1.data)
    tree.addi(4, 7, 'foo')
    assert tree.containsi(10, 12, '[10,12)')
    tree.addi(10, 12, 'bar')
    assert len(tree) == len(reference) + 2
    assert tree != reference

    tree.merge_equals()
    tree.verify()
    assert tree != reference
    assert_merged(tree, 4, 7, 'foo', '[4,7)')
    assert_merged(tree, 10, 12, 'bar', '[10,12)')
コード例 #5
0
ファイル: interval.py プロジェクト: rebeling/anchorman
def to_intervaltree(data, t=None):
    """Create an intervaltree of all elements (elements, units, ...).

    Every item is first stored in the tree; afterwards everything enveloped
    by an item whose type starts with 'restricted_area' is removed again.

    :param data: iterable of (token, (begin, end), type) triples
    :param t: tree to fill; a new IntervalTree is created when None
    :return: tuple (tree, existing_values, existing_a_tags) where the two
        lists are collected from restricted items whose token is ('a', value)
    """
    tree = IntervalTree() if t is None else t

    restricted_spans = []
    values = []
    a_tag_spans = []
    for token, (start, stop), kind in data:
        tree[start:stop] = (token, kind)

        if kind[0] != 'restricted_area':
            continue

        restricted_spans.append((start, stop))
        tag, value = token
        if tag == 'a':
            values.append(value)
            a_tag_spans.append((start, stop))

    # remove all elements in restricted_areas
    for start, stop in restricted_spans:
        tree.remove_envelop(start, stop)

    return tree, values, a_tag_spans
コード例 #6
0
ファイル: init_test.py プロジェクト: ProgVal/intervaltree
def test_empty_init():
    """A tree constructed without arguments verifies and reports empty."""
    empty = IntervalTree()
    empty.verify()

    assert empty.is_empty()
    assert not empty
    assert len(empty) == 0
    assert list(empty) == []
コード例 #7
0
def test_merge_equals_reducer_wo_initializer():
    """merge_equals() with a data_reducer but no initializer."""
    def join_data(old, new):
        return "%s, %s" % (old, new)

    # empty tree
    empty = IntervalTree()
    empty.merge_equals(data_reducer=join_data)
    empty.verify()
    assert not empty

    # One Interval in tree, no change
    single = IntervalTree.from_tuples([(1, 2, 'hello')])
    single.merge_equals(data_reducer=join_data)
    single.verify()
    assert len(single) == 1
    assert sorted(single) == [Interval(1, 2, 'hello')]

    # many Intervals in tree, no change
    tree = IntervalTree.from_tuples(data.ivs1.data)
    reference = IntervalTree.from_tuples(data.ivs1.data)
    tree.merge_equals(data_reducer=join_data)
    tree.verify()
    assert len(tree) == len(reference)
    assert tree == reference

    # many Intervals in tree, with change
    tree = IntervalTree.from_tuples(data.ivs1.data)
    reference = IntervalTree.from_tuples(data.ivs1.data)
    tree.addi(4, 7, 'foo')
    tree.merge_equals(data_reducer=join_data)
    tree.verify()
    assert len(tree) == len(reference)
    assert tree != reference
    # Both original data values are gone; only the reduced value remains.
    assert not tree.containsi(4, 7, 'foo')
    assert not tree.containsi(4, 7, '[4,7)')
    assert tree.containsi(4, 7, '[4,7), foo')
コード例 #8
0
def load_GTF(gtf_file):
    """
    Load a GTF annotation and index its records with IntervalTrees.

    Lines starting with "#" and blank (whitespace-only) lines are skipped.
    Records with an empty feature column or with identical start and end
    columns are not indexed.

    Args:
        gtf_file: Path to the GTF file to load.

    Returns:
        Dictionary keyed by the first column of each record. Each value is
        an IntervalTree whose intervals run from int(start) to int(end) and
        carry the list [feature, start, end, strand, gene_id] (start and end
        kept as the strings read from the file).
    """
    # No default factory is given, so this behaves like a plain dict;
    # the type is kept so callers receive the same object type as before.
    gtf_index = defaultdict()
    with open(gtf_file) as f:
        for line in f:
            # Blank lines have no columns and would fail on entry[8] below.
            if line.startswith("#") or not line.strip():
                continue

            entry = line.split("\t")
            # gene_id is the second space-separated token of the first
            # ";"-separated attribute in column 9.
            gene_id = entry[8].split(";")[0].split(" ")[1]

            feature = entry[2]
            # TYPE (gene, exon etc.), START, END, STRAND, gene_ID
            info = [feature, entry[3], entry[4], entry[6], gene_id]

            # Build GTF index; a tree is only created the first time a
            # key is seen instead of once per record.
            if feature != "" and entry[3] != entry[4]:
                if entry[0] not in gtf_index:
                    gtf_index[entry[0]] = IntervalTree()
                gtf_index[entry[0]].addi(int(info[1]), int(info[2]), info)

    return gtf_index
コード例 #9
0
def test_invalid_update():
    """update() must raise ValueError for reversed and zero-length intervals."""
    tree = IntervalTree()

    for begin, end in ((1, 0), (1, 1)):
        with pytest.raises(ValueError):
            tree.update([Interval(begin, end)])
コード例 #10
0
 def test_add_descending(self, ivs):
     """Build and return a tree by adding `ivs` in descending sorted order.

     When self.verbose is set, a ProgressBar sized to len(ivs) is advanced
     once per added interval.
     """
     tree = IntervalTree()
     if self.verbose:
         pbar = ProgressBar(len(ivs))
     for interval in sorted(ivs, reverse=True):
         tree.add(interval)
         if self.verbose:
             pbar()
     return tree
コード例 #11
0
class RepeatDb(object):
    """Repeat annotations for one range of a contig, loaded via `hgsql`."""

    def __init__(self, assembly, contig, start, end):
        """
        Given a range on a contig, load the repeats inside that range.

        The query selects rows of `<assembly>.rmsk` on `contig` with
        genoStart > start and genoEnd < end.

        Keeps an IntervalTree of element names, and a Counter from element
        name to number of that element in the range.

        No protection against SQL injection: the arguments are formatted
        straight into the query text, so they must come from a trusted
        source.

        """

        # Make the interval tree
        self.tree = IntervalTree()

        # Make a counter for repeats with a certain name
        self.counts = collections.Counter()

        command = ["hgsql", "-e", "select repName, genoName, genoStart, genoEnd "
            "from {}.rmsk where genoName = '{}' and genoStart > '{}' "
            "and genoEnd < '{}';".format(assembly, contig, start, end)]
        process = subprocess.Popen(command, stdout=subprocess.PIPE)

        try:
            for parts in itertools.islice(tsv.TsvReader(process.stdout), 1, None):
                # For each line except the first, broken into fields

                # Add the item to the tree covering its range. Store the
                # repeat type name as the interval's data.
                self.tree.addi(int(parts[2]), int(parts[3]), parts[0])

                # Count it
                self.counts[parts[0]] += 1
        finally:
            # Close the pipe and reap the child so it does not linger as a
            # zombie and the file descriptor is released.
            process.stdout.close()
            process.wait()

    def get_copies(self, contig, pos):
        """
        Given a contig name and a position, estimate the copy number of that
        position in the genome.

        `contig` is currently ignored; only `pos` is looked up.

        Return the number of instances expected (1 for non-repetitive sequence).
        """

        # TODO: use contig

        # Keep track of the number of copies of the most numerous repeat
        # observed among everything overlapping the position.
        max_copies = 1

        for interval in self.tree[pos]:
            max_copies = max(max_copies, self.counts[interval.data])

        return max_copies
コード例 #12
0
def test_merge_equals_wo_dupes():
    """merge_equals() leaves a tree built from data.ivs1 unchanged."""
    tree = IntervalTree.from_tuples(data.ivs1.data)
    reference = IntervalTree.from_tuples(data.ivs1.data)
    assert reference == tree

    tree.merge_equals()
    tree.verify()

    assert reference == tree
コード例 #13
0
ファイル: distributor.py プロジェクト: GjjvdBurg/labella.py
    def countIdealOverlaps(self, nodes):
        """Annotate each node with the nodes whose ideal extent overlaps it.

        Sets `node.overlaps` (list of nodes returned by the tree search for
        the node's own idealLeft()..idealRight() range) and
        `node.overlapCount` (size of that search result).
        """
        tree = IntervalTree()
        for node in nodes:
            tree.addi(node.idealLeft(), node.idealRight(), data=node)

        for node in nodes:
            hits = tree.search(node.idealLeft(), node.idealRight())
            node.overlaps = [hit.data for hit in hits]
            node.overlapCount = len(hits)
コード例 #14
0
ファイル: flash_reader.py プロジェクト: flit/pyOCD
class FlashReaderContext(DebugContext):
    """! @brief Reads flash memory regions from an ELF file instead of the target."""

    def __init__(self, parent, elf):
        super(FlashReaderContext, self).__init__(parent)
        self._elf = elf

        self._build_regions()

    def _build_regions(self):
        """Index every ELF section that lies in a flash region by address."""
        self._tree = IntervalTree()
        flash_sections = [s for s in self._elf.sections if (s.region and s.region.is_flash)]
        for sect in flash_sections:
            begin = sect.start
            end = begin + sect.length
            sect.data # Go ahead and read the data from the file.
            self._tree.addi(begin, end, sect)
            LOG.debug("created flash section [%x:%x] for section %s", begin, end, sect.name)

    def read_memory(self, addr, transfer_size=32, now=True):
        """Read one 8/16/32-bit value, from the ELF data when possible.

        Falls back to the parent context unless the accessed range overlaps
        exactly one indexed section. Returns the value when `now` is true,
        otherwise a callable producing it.
        """
        length = transfer_size // 8
        matches = self._tree.overlap(addr, addr + length)
        # Must match only one interval (ELF section).
        if len(matches) != 1:
            return self._parent.read_memory(addr, transfer_size, now)
        section = matches.pop().data
        # Offset of the access within the section's data.
        offset = addr - section.start

        def read_memory_cb():
            LOG.debug("read flash data [%x:%x] from section %s", section.start + offset, section.start + offset  + length, section.name)
            data = section.data[offset:offset + length]
            if transfer_size == 8:
                return data[0]
            if transfer_size == 16:
                return conversion.byte_list_to_u16le_list(data)[0]
            if transfer_size == 32:
                return conversion.byte_list_to_u32le_list(data)[0]
            raise ValueError("invalid transfer_size (%d)" % transfer_size)

        return read_memory_cb() if now else read_memory_cb

    def read_memory_block8(self, addr, size):
        """Read `size` bytes as a list, from the ELF data when possible.

        Falls back to the parent context unless the accessed range overlaps
        exactly one indexed section.
        """
        matches = self._tree.overlap(addr, addr + size)
        # Must match only one interval (ELF section).
        if len(matches) != 1:
            return self._parent.read_memory_block8(addr, size)
        section = matches.pop().data
        offset = addr - section.start
        data = section.data[offset:offset + size]
        LOG.debug("read flash data [%x:%x]", section.start + offset, section.start + offset  + size)
        return list(data)

    def read_memory_block32(self, addr, size):
        """Read via read_memory_block8 and convert to little-endian u32 words."""
        raw = self.read_memory_block8(addr, size)
        return conversion.byte_list_to_u32le_list(raw)
コード例 #15
0
class BorderModel(QObject):
    """Tracks bordered index ranges and emits `rangeChanged` when they change."""

    rangeChanged = pyqtSignal([BorderedRange])

    def __init__(self, parent, color_theme=SolarizedColorTheme):
        super(BorderModel, self).__init__(parent)

        # data structure description:
        # _db is an interval tree that indexes on the start and end of bordered ranges
        # the values are BorderedRange instances.
        # A bordered range [begin, end] (end inclusive) is stored as the
        # half-open interval [begin, end + 1).
        # given an index, determining its border is:
        #   intervaltree lookup index in _db (which is O(log <num ranges>) )
        #   iterate containing ranges (worst case, O(<num ranges>), but typically small)
        #     hash lookup on index to fetch border state (which is O(1))
        self._db = IntervalTree()
        self._theme = color_theme

    def border_region(self, begin, end, color=None):
        """Add a border around [begin, end]; pick a theme accent if no color given."""
        if color is None:
            color = self._theme.get_accent(len(self._db))
        bordered = BorderedRange(begin, end, BorderTheme(color), compute_region_border(begin, end))
        # note we use (end + 1) to ensure the entire selection gets captured
        self._db.addi(bordered.begin, bordered.end + 1, bordered)
        self.rangeChanged.emit(bordered)

    def _same_span_intervals(self, begin, end):
        """Return stored intervals overlapping [begin, end] with the same span.

        Stored intervals are one longer than the region they represent, so
        the stored length minus one is compared against `end - begin`, and
        the query uses `end + 1` so single-cell regions (begin == end) are
        found as well.
        """
        span = end - begin
        return [iv for iv in self._db[begin:end + 1]
                if iv.end - iv.begin - 1 == span]

    def clear_region(self, begin, end):
        """Remove every bordered range matching [begin, end] and signal each."""
        for iv in self._same_span_intervals(begin, end):
            self._db.removei(iv.begin, iv.end, iv.data)
            self.rangeChanged.emit(iv.data)

    def get_border(self, index):
        """Return BorderData for `index` from the shortest containing range, or None."""
        # candidates is a (potentially empty) list of intervaltree.Interval instances
        # we sort them here from shortest length to longest, because we want
        #    the most specific border
        candidates = sorted(self._db[index], key=lambda iv: iv.end - iv.begin)
        if not candidates:
            return None
        bordered = candidates[0].data
        cell = bordered.cells.get(index, None)
        if cell is None:
            return None
        return BorderData(cell.top, cell.bottom, cell.left, cell.right, bordered.theme)

    def is_index_bordered(self, index):
        """Return True when at least one bordered range contains `index`."""
        return len(self._db[index]) > 0

    def is_region_bordered(self, begin, end):
        """Return True when a bordered range matching [begin, end] exists.

        Uses the same matching rule as clear_region, so a region created by
        border_region(begin, end) is reported as bordered.
        """
        return bool(self._same_span_intervals(begin, end))
コード例 #16
0
def test_emptying_partial():
    """remove_overlap() empties exactly the queried half of the tree."""
    # upper part: everything from 7 onwards
    upper = IntervalTree.from_tuples(data.ivs1.data)
    assert upper[7:]
    upper.remove_overlap(7, upper.end())
    assert not upper[7:]

    # lower part: everything before 7
    lower = IntervalTree.from_tuples(data.ivs1.data)
    assert lower[:7]
    lower.remove_overlap(lower.begin(), 7)
    assert not lower[:7]
コード例 #17
0
ファイル: clusterPlots.py プロジェクト: dalexander/PRmm
def plot2C2AScatterTimeSeries(zmwFixture, frameInterval=4096):
    """
    Plot a 2C2A scatter plot for every `frameInterval` frames; overlay
    information about HQRegion and alignment(s), if found in the dataset.

    Returns the flattened array of axes, one pane per frame window.
    """
    trace = zmwFixture.cameraTrace
    df = pd.DataFrame(np.transpose(trace), columns=["C1", "C2"])

    # what is the extent of the data?  force a square perspective so
    # we don't distort the spectral angle
    lo = min(df.min())
    hi = max(df.max())
    xmin = ymin = lo
    xmax = ymax = hi

    def fracX(frac):
        return xmin + (xmax - xmin) * frac

    def fracY(frac):
        return ymin + (ymax - ymin) * frac

    # Grid layout: six columns, as many rows as needed for all panes.
    numCols = 6
    numPanes = int(math.ceil(float(zmwFixture.numFrames) / frameInterval))
    numRows = int(math.ceil(float(numPanes) / numCols))
    figsize = np.array([numCols, numRows]) * np.array([3, 3])

    fig, ax = plt.subplots(numRows, numCols, sharex=True, sharey=True,
                           figsize=figsize)
    axr = ax.ravel()

    details = "" # TODO
    fig.suptitle("%s\n%s" % (zmwFixture.zmwName, details), fontsize=20)

    # Index the alignment regions by frame so panes can be tested for overlap.
    alnIntervals = IntervalTree()
    for region in zmwFixture.regions:
        if region.regionType == Region.ALIGNMENT_REGION:
            alnIntervals.addi(region.startFrame, region.endFrame)

    for i in xrange(numPanes):
        startFrame = i*frameInterval
        endFrame = (i+1)*frameInterval
        pane = axr[i]
        pane.set_xlim(xmin, xmax)
        pane.set_ylim(ymin, ymax)
        pane.plot(df.C1[startFrame:endFrame], df.C2[startFrame:endFrame], ".")

        baseSpan = zmwFixture.baseIntervalFromFrames(startFrame, endFrame)
        pane.text(fracX(0.6), fracY(0.9), "/%d_%d" %  baseSpan)

        # Red bar across the top of panes that overlap an alignment region.
        if alnIntervals.search(startFrame, endFrame):
            pane.hlines(fracY(1.0), xmin, xmax, colors=["red"], linewidth=4)

    return axr
コード例 #18
0
ファイル: issue4.py プロジェクト: chaimleib/intervaltree
def test_build_tree():
    """Replay `items` against a tree that starts as one interval [0, MAX).

    For an item with a truthy third field, the single interval covering
    [b, e) is removed and its uncovered remainders are re-inserted.
    Otherwise [b, e) must be absent from the tree and is inserted after
    being merged with any interval touching it on either side.
    Returns the verified tree.
    """
    progress = ProgressBar(len(items))

    tree = IntervalTree()
    tree[0:MAX] = None
    for begin, end, alloc in items:
        if alloc:
            covering = tree[begin:end]
            assert len(covering)==1
            block = covering.pop()
            assert block.begin<=begin and end<=block.end
            tree.remove(block)
            # Put back whatever part of the block lies outside [begin, end).
            if block.begin<begin:
                tree[block.begin:begin] = None
            if end<block.end:
                tree[end:block.end] = None
        else:
            assert not tree[begin:end]
            # Absorb an interval touching on the left, if any.
            left = tree[begin-1:begin]
            assert len(left) in (0, 1)
            if left:
                left_iv = left.pop()
                begin = left_iv.begin
                tree.remove(left_iv)
            # Absorb an interval touching on the right, if any.
            right = tree[end:end+1]
            assert len(right) in (0, 1)
            if right:
                right_iv = right.pop()
                end = right_iv.end
                tree.remove(right_iv)
            tree[begin:end] = None
        progress()
    tree.verify()
    return tree
コード例 #19
0
def test_add_invalid_interval():
    """
    Ensure that begin < end.

    Every way of inserting a reversed or zero-length interval must raise
    ValueError: addi(), slice assignment and extend().
    """
    itree = IntervalTree()

    for begin, end in ((1, 0), (1, 1)):
        with pytest.raises(ValueError):
            itree.addi(begin, end)

    for begin, end in ((1, 0), (1, 1), (1.1, 1.05), (1.1, 1.1)):
        with pytest.raises(ValueError):
            itree[begin:end] = "value"

    for begin, end in ((1, 0), (1, 1)):
        with pytest.raises(ValueError):
            itree.extend([Interval(begin, end)])
コード例 #20
0
ファイル: query_test.py プロジェクト: chaimleib/intervaltree
def test_tree_bounds():
    """begin()/end() must equal the smallest begin and largest end stored."""
    def assert_tree_bounds(tree):
        # Seed with an arbitrary member, then widen over all members.
        lowest, highest, _ = set(tree).pop()
        for iv in tree:
            lowest = min(lowest, iv.begin)
            highest = max(highest, iv.end)
        assert tree.begin() == lowest
        assert tree.end() == highest

    for fixture in (data.ivs1.data, data.ivs2.data):
        assert_tree_bounds(IntervalTree.from_tuples(fixture))
コード例 #21
0
class ColorModel(QObject):
    """Tracks colored index ranges and emits `rangeChanged` when they change."""

    rangeChanged = pyqtSignal([ColoredRange])

    def __init__(self, parent, color_theme=SolarizedColorTheme):
        super(ColorModel, self).__init__(parent)
        self._db = IntervalTree()
        self._theme = color_theme

    def color_region(self, begin, end, color=None):
        """Color [begin, end) and return the new ColoredRange.

        Without an explicit color, a theme accent is chosen from the number
        of ranges currently stored.
        """
        if color is None:
            color = self._theme.get_accent(len(self._db))
        colored = ColoredRange(begin, end, color)
        self.color_range(colored)
        return colored

    def clear_region(self, begin, end):
        """Clear every stored range overlapping the region with the same length."""
        span = end - begin
        # Collect first, then remove, so the tree is not changed mid-scan.
        doomed = [iv for iv in self._db[begin:end] if iv.end - iv.begin == span]
        for iv in doomed:
            self.clear_range(iv.data)

    def color_range(self, range_):
        """Store `range_` and signal the change."""
        self._db.addi(range_.begin, range_.end, range_)
        self.rangeChanged.emit(range_)

    def clear_range(self, range_):
        """Remove `range_` and signal the change."""
        self._db.removei(range_.begin, range_.end, range_)
        self.rangeChanged.emit(range_)

    def get_color(self, index):
        """Return the color of the shortest range containing `index`, or None."""
        # candidates is a (potentially empty) list of intervaltree.Interval instances
        # we sort them here from shortest length to longest, because we want
        #    the most specific color
        candidates = sorted(self._db[index], key=lambda iv: iv.end - iv.begin)
        if not candidates:
            return None
        return candidates[0].data.color

    def get_region_colors(self, begin, end):
        """Return the `data` of each range hit by the region (point query if empty)."""
        hits = self._db[begin] if begin == end else self._db[begin:end]
        return funcy.pluck_attr("data", hits)

    def is_index_colored(self, index):
        """Return True when at least one range contains `index`."""
        return len(self._db[index]) > 0

    def is_region_colored(self, begin, end):
        """Return True when at least one range overlaps [begin, end)."""
        return len(self._db[begin:end]) > 0
コード例 #22
0
ファイル: query_test.py プロジェクト: chaimleib/intervaltree
def test_partial_get_query():
    """Open-ended slices must match a brute-force filter over the tree."""
    def assert_get(tree, limit):
        # t[:] returns everything
        assert tree[:] == set(tree)

        # t[:limit] returns intervals beginning before the limit
        below = set(iv for iv in tree if iv.begin < limit)
        assert tree[:limit] == below

        # t[limit:] returns intervals ending after the limit
        above = set(iv for iv in tree if iv.end > limit)
        assert tree[limit:] == above

    for fixture, limit in ((data.ivs1.data, 7), (data.ivs2.data, -3)):
        assert_get(IntervalTree.from_tuples(fixture), limit)
コード例 #23
0
def test_merge_overlaps_gapless():
    """merge_overlaps() on data.ivs2 with the default and with strict=False."""
    # default strict=True: the tree content is unchanged
    strict_tree = IntervalTree.from_tuples(data.ivs2.data)
    strict_tree.merge_overlaps()
    strict_tree.verify()
    as_tuples = [(iv.begin, iv.end, iv.data) for iv in sorted(strict_tree)]
    assert as_tuples == data.ivs2.data

    # strict=False: everything collapses into the tree's overall range
    loose_tree = IntervalTree.from_tuples(data.ivs2.data)
    whole = loose_tree.range()
    loose_tree.merge_overlaps(strict=False)
    loose_tree.verify()
    assert len(loose_tree) == 1
    assert loose_tree.pop() == whole
コード例 #24
0
def test_brackets_vs_overlap():
    """Slice lookup and overlap() must return the same result."""
    tree = IntervalTree()
    for begin, end, label in ((1, 3, "dude"), (2, 4, "sweet"), (6, 9, "rad")):
        tree.addi(begin, end, label)

    for iv in tree:
        assert tree[iv.begin:iv.end] == tree.overlap(iv.begin, iv.end)
コード例 #25
0
ファイル: util.py プロジェクト: jrderuiter/im-fusion
    def from_gtf(
            cls,
            gtf_path,  # type: pathlib.Path
            chromosomes=None,  # type: List[str]
            record_filter=None  # type: Callable[[Any], bool]
    ):  # type: (...) -> TranscriptReference
        """Builds a Reference instance from the given GTF file.

        Args:
            gtf_path: Path of the GTF file, opened with pysam.TabixFile.
            chromosomes: Contigs to load; defaults to all contigs of the file.
            record_filter: Optional predicate; only records for which it
                returns a true value are used.

        Returns:
            An instance of `cls` built from a dict of per-chromosome
            transcript trees and a dict of per-transcript-id exon trees.
        """

        def _transcript_id(rec):
            return rec.transcript_id

        # Build the trees.
        transcript_trees = {}
        exon_trees = {}

        # Open gtf file; it is closed again once all records are converted.
        gtf = pysam.TabixFile(native_str(gtf_path), parser=pysam.asGTF())
        try:
            if chromosomes is None:
                chromosomes = gtf.contigs

            for chrom in chromosomes:
                # Collect exons and transcripts.
                transcripts = []
                exons = []

                records = gtf.fetch(reference=chrom)

                if record_filter is not None:
                    records = (rec for rec in records if record_filter(rec))

                for record in records:
                    if record.feature == 'transcript':
                        transcripts.append(cls._record_to_transcript(record))
                    elif record.feature == 'exon':
                        exons.append(cls._record_to_exon(record))

                # Build transcript lookup tree.
                transcript_trees[chrom] = IntervalTree.from_tuples(
                    (tr.start, tr.end, tr) for tr in transcripts)

                # Build exon lookup trees; groupby needs its input sorted
                # by the same key.
                exons = sorted(exons, key=_transcript_id)
                grouped = itertools.groupby(exons, key=_transcript_id)

                for tr_id, grp in grouped:
                    exon_trees[tr_id] = IntervalTree.from_tuples(
                        (exon.start, exon.end, exon) for exon in grp)
        finally:
            gtf.close()

        return cls(transcript_trees, exon_trees)
コード例 #26
0
def site_intervaltree(seq, enzyme):
    """
    Build an intervaltree of an enzyme's cutsites across a given sequence.

    Each hit `si` from enzyme.search(seq) becomes the interval starting at
    si + enzyme.fst3 - 1 with length enzyme.size. Whether a position
    involves a cutsite can be queried with tree[x] or tree[x1:x2].
    :param seq: the sequence to digest
    :param enzyme: the restriction enzyme used in digestion
    :return: an intervaltree of cutsites.
    """
    tree = IntervalTree()
    site_len = enzyme.size
    cut_offset = enzyme.fst3
    for hit in enzyme.search(seq):
        begin = hit + cut_offset - 1
        tree.addi(begin, begin + site_len)
    return tree
コード例 #27
0
ファイル: copy_test.py プロジェクト: chaimleib/intervaltree
def test_copy_cast():
    """Copy-construction and list/set conversion preserve the tree's members."""
    original = IntervalTree.from_tuples(data.ivs1.data)

    # IntervalTree(tree) yields a valid, equal tree.
    duplicate = IntervalTree(original)
    duplicate.verify()
    assert original == duplicate

    # list(tree) and the tree contain exactly the same intervals.
    as_list = list(original)
    for iv in as_list:
        assert iv in original
    for iv in original:
        assert iv in as_list

    # set(tree) equals items().
    as_set = set(original)
    assert as_set == original.items()
コード例 #28
0
ファイル: copy_test.py プロジェクト: ProgVal/intervaltree
def test_copy_cast():
    """Copy-construction and list/set conversion preserve the tree's members."""
    original = trees['ivs1']()

    # IntervalTree(tree) yields a valid, equal tree.
    duplicate = IntervalTree(original)
    duplicate.verify()
    assert original == duplicate

    # list(tree) and the tree contain exactly the same intervals.
    as_list = list(original)
    for iv in as_list:
        assert iv in original
    for iv in original:
        assert iv in as_list

    # set(tree) equals items().
    as_set = set(original)
    assert as_set == original.items()
コード例 #29
0
def test_update():
    """update() accepts a set and a list and grows the tree in place."""
    tree = IntervalTree()

    # update from a set
    first = Interval(0, 1)
    tree.update(set([first]))
    assert isinstance(tree, IntervalTree)
    assert len(tree) == 1
    assert set(tree).pop() == first

    # update from a list
    second = Interval(2, 3)
    tree.update([second])
    assert isinstance(tree, IntervalTree)
    assert len(tree) == 2
    assert sorted(tree)[1] == second
コード例 #30
0
ファイル: cgdensity.py プロジェクト: DMU-lilab/DNAMethylation
def _cgi_overlap(cgis, regions):
	"""Split region values into CGI-overlapping, non-overlapping and valley lists.

	A tree is built from (cg[0], cg[1]) of every entry in `cgis`. For each
	region, region[4] goes to the valley list when region[5] == "VALLEY";
	otherwise to the first list when region[1]..region[2] overlaps the tree
	and to the second list when it does not.

	Returns the tuple (vcgi, vnoncgi, vvalley).
	"""
	cgi_tree = IntervalTree(Interval(cg[0], cg[1]) for cg in cgis)

	in_cgi, outside_cgi, valley = [], [], []
	for region in regions:
		if region[5] == "VALLEY":
			valley.append(region[4])
		elif cgi_tree.overlaps(region[1], region[2]):
			in_cgi.append(region[4])
		else:
			outside_cgi.append(region[4])

	return (in_cgi, outside_cgi, valley)
コード例 #31
0
class GenomeAnnotation(object):
    """
    represents a genbank file
    and allows to efficiently annotate
    positions of interest
    """
    COLUMNS = [
        "type", "name", "locus", "product", 'protein_id', "strand", "start",
        "end"
    ]

    def __init__(self, genbank_file):
        """
        initializes the GenomeAnnotation object

        Parses the genbank file into the interval tree and the lookup
        dictionaries, then builds a position-sorted index of the start and
        end coordinates of all stored "CDS" and "gene" entries, which
        find_closest_gene() bisects.

        :param genbank_file: a path to a genbank file
        """
        self.genome_tree = IntervalTree()
        self.gene_dic = {}
        self.locus_dic = {}
        self.type_dic = {}
        self.genome_id = None
        self.length = None
        self.__read_genbank(genbank_file)

        # internal data structure for quick internal nearest gene search if position is not annotated
        # type_dic only holds a key for feature types that were actually
        # stored ("gene" is only stored for genes without another feature),
        # so a missing key must be treated as an empty list, not a KeyError.
        boundaries = []
        for v in self.type_dic.get("CDS", []) + self.type_dic.get("gene", []):
            boundaries.extend([(v.start, v), (v.end, v)])

        # Stable sort on the coordinate only; entries themselves are never compared.
        boundaries.sort(key=lambda x: x[0])
        self.__index_list = [pos for pos, _ in boundaries]
        self.__cds_list = [entry for _, entry in boundaries]

    def __read_genbank(self, genbank_file):
        """
        reads the genbank file and stores its content in a interval tree
        and other searchable containers for efficient querying

        The file is parsed line by line with a small state machine:

        * the LOCUS line provides ``genome_id`` (2nd token) and ``length``
          (3rd token, converted to int);
        * inside the COMMENT block the names of the annotated feature types
          are collected into ``annotated_features``;
        * inside the FEATURES block every line whose first token is one of
          those names and whose second token contains ".." starts a new
          entry; the qualifier lines that follow fill in locus, product,
          gene name and protein id;
        * parsing stops at the ORIGIN line.

        Entries of type "gene" are held back and only stored at the end if
        no other entry with the same locus was stored.

        Fills ``genome_tree``, ``locus_dic``, ``type_dic`` and ``gene_dic``
        and prints some diagnostics to stdout.

        :param genbank_file: a path to a genbank file
        """
        # NOTE(review): diagnostic output, printed on every call.
        print("old implementation")
        # "gene" entries whose storage is deferred until all other features are known
        pseudogenes = []
        with open(genbank_file, "r") as f:
            # fields of the entry that is currently being gathered
            # (note: ``type`` shadows the builtin inside this method)
            type, name, locus, product, product_id, strand, start, end = None, None, None, None, None, None, None, None

            # feature type names announced in the COMMENT block
            annotated_features = set()

            # states
            gathering = False
            comment_block = False
            annotation_block = False
            # number of entries whose start lies behind their end
            c = 0
            for l in f:
                # skip empty lines
                if l.strip() == "":
                    continue

                splits = l.split()

                # header line: 2nd token is used as genome id, 3rd as sequence length
                if splits[0].startswith("LOCUS"):
                    print(splits)
                    self.genome_id = splits[1].strip()
                    self.length = int(splits[2].strip())

                # are we at the end of the annotation block?
                if splits[0].startswith("ORIGIN"):
                    break

                # check for parsing stage
                if splits[0].startswith("COMMENT"):
                    comment_block = True

                if splits[0].startswith("FEATURES"):
                    print(annotated_features)
                    annotation_block = True
                    comment_block = False

                # COMMENT block feature annotation
                # (any line whose first token starts with "Fe"; presumably the
                # "Features Annotated" line of the file -- confirm against real input)
                if comment_block and splits[0].startswith("Fe"):

                    gathering = True
                    # tokens from the 4th on name the feature types; tokens starting
                    # with "Gene" are recorded as "gene", all others are cut at the first ';'
                    for an in splits[3:]:
                        if not an.startswith("Gene"):
                            annotated_features.add(an.split(";")[0])
                        else:
                            annotated_features.add("gene")

                # FEATURES Block here we found an entry that we want to gather
                if annotation_block and splits[
                        0] in annotated_features and ".." in splits[1]:

                    # first add already gathered entry into data structures
                    # (entries for which no locus qualifier was seen are dropped here)
                    if locus is not None:
                        entry = GenomeEntry(type, name, locus, product,
                                            product_id, strand, start, end)

                        #if type == "PROMOTER":
                        #    print(entry)
                        # if it is a gene annotation then first store it in temp for later processing
                        if type == "gene":
                            pseudogenes.append(entry)
                        else:
                            # start behind end: the entry is stored as two intervals,
                            # [start, length) and [0, end) -- presumably a feature that
                            # wraps around the origin of a circular genome (confirm)
                            if start > end:
                                print(entry)
                                c += 1
                                self.genome_tree.addi(start, self.length,
                                                      entry)
                                self.genome_tree.addi(0, end, entry)
                            else:
                                self.genome_tree.addi(start, end, entry)

                            self.locus_dic[locus] = entry
                            self.type_dic.setdefault(type, []).append(entry)

                            if name is not None:
                                self.gene_dic[name] = entry

                        # reset the gathered fields for the next entry
                        type, name, locus, product, product_id, strand, start, end = None, None, None, None, None, None, None, None

                    gathering = True
                    type = splits[0]
                    # determine strand, start and end

                    if splits[1].startswith('comp'):
                        # str.strip() takes a set of characters, not a prefix: every
                        # leading/trailing character out of "complement()" is removed
                        interval = splits[1].strip('complement()')
                        strand = '-'
                    else:
                        interval = splits[1]
                        strand = '+'
                    # both coordinates are shifted down by one; int() raises ValueError
                    # for anything that is not plain "<int>..<int>"
                    start, end = map(lambda x: int(x) - 1,
                                     interval.split('..'))
                    # TODO: this has to be fixed in the genbank file
                    # (widen entries of equal start and end by one position)
                    if start == end:
                        end += 1

                # gather annotated elements
                if gathering:

                    # if we are in the comment block then we are gathering annotated features
                    if comment_block:
                        # a line containing the token "::" ends the gathering;
                        # otherwise every token of the line is added as a feature name
                        if "::" in splits:
                            gathering = False
                        else:
                            for s in splits:
                                annotated_features.add(s.split(";")[0])

                    # if we are in the annotation block then we gather infos distributed over multiple lines
                    if annotation_block:
                        # qualifier value = text after the last '=', quotes removed;
                        # for the locus, underscores are removed as well
                        if splits[0].startswith("/locus"):
                            locus = l.split("=")[-1].replace('"', '').replace(
                                "_", "").strip()
                        elif splits[0].startswith("/product"):
                            product = l.split("=")[-1].replace('"', '').strip()
                        elif splits[0].startswith("/gene"):
                            name = l.split("=")[-1].replace('"', '').strip()
                        elif splits[0].startswith("/protein_id"):
                            product_id = l.split("=")[-1].replace('"',
                                                                  '').strip()
                        else:
                            # last statement of the loop body, so this has no extra effect
                            continue

            # end of file
            # store the entry that was still being gathered when parsing stopped
            if locus is not None:
                entry = GenomeEntry(type, name, locus, product, product_id,
                                    strand, start, end)
                # if it is a gene annotation then first store it in temp for later processing
                #if type == "PROMOTER":
                #    print(entry)
                if type == "gene":
                    pseudogenes.append(entry)
                else:
                    start = entry.start
                    end = entry.end
                    # same two-interval handling as inside the loop
                    if start > end:
                        print(entry)
                        c += 1
                        self.genome_tree.addi(start, self.length, entry)
                        self.genome_tree.addi(0, end, entry)
                    else:
                        self.genome_tree.addi(entry.start, entry.end, entry)

                    self.locus_dic[locus] = entry
                    self.type_dic.setdefault(type, []).append(entry)

                    if name is not None:
                        self.gene_dic[name] = entry
            print("Wrongly start end", c)
            # store the deferred "gene" entries; unlike above, start > end is
            # not split into two intervals here
            for p in pseudogenes:
                # if this is true gene did not have another entry
                if p.locus not in self.locus_dic:
                    self.locus_dic[p.locus] = p
                    self.type_dic.setdefault(p.type, []).append(p)
                    self.genome_tree.addi(p.start, p.end, p)
                    if p.name is not None:
                        self.gene_dic[p.name] = p

    def _read_genbank2(self, genbank_file):
        """
        reads the genbank file with Biopython's SeqIO and stores its content
        in the interval tree and the lookup dictionaries

        Every feature except "source" becomes a GenomeEntry. Entries of type
        "gene" are held back and only stored at the end if no other entry
        with the same locus was stored.

        :param genbank_file: a path to a genbank file holding a single record
        """
        gene_tmp = []
        nop = [None]  # default for missing qualifiers, so that [0] yields None

        # SeqIO.read() returns the fully parsed record, so the file handle
        # is only needed for this one call.
        with open(genbank_file, "r") as gbk:
            anno = SeqIO.read(gbk, "genbank")
        self.genome_id = anno.id
        self.length = len(anno)

        for rec in anno.features:
            if rec.type == "source":
                continue

            # The strand is 1, -1 or None. The previous truthiness test
            # labelled reverse-strand features (-1) as "+"; only an explicit
            # forward strand is "+" now, everything else stays "-".
            strand = "+" if rec.location.strand == 1 else "-"

            # NOTE(review): Biopython locations are documented as 0-based
            # with an exclusive end, so the "- 1" on the start may put these
            # coordinates one position before the ones produced by the
            # text-based reader -- confirm before switching readers.
            entry = GenomeEntry(
                rec.type,
                rec.qualifiers.get("gene", nop)[0],
                rec.qualifiers.get("locus_tag", nop)[0],
                rec.qualifiers.get("product", nop)[0],
                rec.qualifiers.get("protein_id", nop)[0],
                strand,
                int(rec.location.start) - 1,
                int(rec.location.end) - 1)

            if entry.type == "gene":
                gene_tmp.append(entry)
                continue

            if entry.start > entry.end:
                # start behind end: store as [start, length) plus [0, end)
                self.genome_tree.addi(entry.start, self.length, entry)
                self.genome_tree.addi(0, entry.end, entry)
            else:
                self.genome_tree.addi(entry.start, entry.end, entry)

            self.locus_dic[entry.locus] = entry
            self.type_dic.setdefault(entry.type, []).append(entry)
            if entry.name is not None:
                self.gene_dic[entry.name] = entry

        for p in gene_tmp:
            # if this is true gene did not have another entry
            if p.locus not in self.locus_dic:
                self.locus_dic[p.locus] = p
                self.type_dic.setdefault(p.type, []).append(p)
                self.genome_tree.addi(p.start, p.end, p)
                if p.name is not None:
                    self.gene_dic[p.name] = p

    def __str__(self):
        return pd.DataFrame.from_records(list(self.locus_dic.values()),
                                         columns=self.COLUMNS).to_string()

    def annotate_positions(self, idx, aggregate=False):
        """
        annotates a list of positions with their associated genomic entries
        and returns a pandas dataframe with columns:

        pos, type, locus, name, product, protein_id, strand, closest,
        distance, start, end

        A position covered by several entries yields one row per entry. A
        position covered by none yields a single row of type "?" whose
        ``closest`` / ``distance`` columns describe the nearest entry found
        by find_closest_gene().

        :param idx: a single integer position or an iterable of positions
                    (duplicates are removed)
        :param aggregate: if True, rows sharing a position are collapsed into
                          one row whose cells are the ';'-joined string values
        :return: pandas dataframe
        """
        # Imported locally so the method does not depend on a module-level import.
        import numbers

        # test if parameter is an iterable or int
        # (numbers.Integral also accepts NumPy integer scalars, which a plain
        # ``int`` check sends into the iterable branch where set() fails)
        if isinstance(idx, numbers.Integral):
            idx = [idx]
        else:
            idx = list(set(idx))

        unknown = GenomeEntry("?", None, None, None, None, None, None, None)
        entries = []
        closest = []
        distance = []
        index = []
        for i in idx:
            data = self.genome_tree.search(i, strict=True)
            if data:
                # possible overlap of gene entries?
                for p in data:
                    index.append(i)
                    entries.append(p.data)
                    closest.append(None)
                    distance.append(None)
            else:
                # position is not annotated in GenomeAnnotation
                # find closest annotated CDS
                index.append(i)
                entries.append(unknown)
                i_clos = self.find_closest_gene(i)
                closest.append(i_clos.locus)
                distance.append(min(abs(i - i_clos.start),
                                    abs(i - i_clos.end)))

        anno_df = pd.DataFrame.from_records(entries, columns=self.COLUMNS)

        anno_df["pos"] = index
        anno_df["closest"] = closest
        anno_df["distance"] = distance

        if aggregate:
            # one row per position; every cell becomes a ';'-joined string
            anno_df = anno_df.groupby("pos").agg(
                lambda col: ';'.join(map(str, col)))
            anno_df.reset_index(inplace=True)
        return anno_df[[
            "pos", "type", "locus", "name", "product", "protein_id", "strand",
            "closest", "distance", "start", "end"
        ]]

    def find_closest_gene(self, pos):
        """
        Returns closest value to pos.
        If two numbers are equally close, return the smallest number.

        :param pos: the genome position
        :return: GenomeEntry
        """
        idx = bisect_left(self.__index_list, pos)
        if idx == 0:
            return self.__cds_list[0]
        if idx == len(self.__index_list):
            return self.__cds_list[-1]
        before = self.__index_list[idx - 1]
        after = self.__index_list[idx]
        if after - pos < pos - before:
            return self.__cds_list[idx]
        else:
            return self.__cds_list[idx - 1]

    def annotate_genes(self, genes):
        """
        annotates a list of gene and returns a pandas dataframe
        with the following columns:

        type name locus product strand start end

        :param genes: list of genes names
        :return: pandas dataframe
        """
        if isinstance(genes, str):
            genes = [genes]

        entries = [self.gene_dic[g] for g in genes if g in self.gene_dic]
        return pd.DataFrame.from_records(entries, columns=self.COLUMNS)

    def annotate_loci(self, loci):
        """
        annotates a list of loci tags and returns a pandas dataframe
        with the following columns:

        type name locus product strand start end

        :param loci: list of locus names
        :return: pandas dataframe
        """
        if isinstance(loci, str):
            loci = [loci]

        entries = [self.locus_dic[g] for g in loci if g in self.locus_dic]
        return pd.DataFrame.from_records(entries, columns=self.COLUMNS)

    def annotate_type(self, types):
        """
        annotates a list of types  and returns a pandas dataframe
        with the following columns:

        type name locus product strand start end

        :param types: list of types
        :return: pandas dataframe
        """
        if isinstance(types, str):
            types = [types]

        entries = []
        for g in types:
            if g in self.type_dic:
                for e in self.type_dic[g]:
                    entries.append(e)
        return pd.DataFrame.from_records(entries, columns=self.COLUMNS)

    def annotate_dataframe(self,
                           df,
                           column,
                           suffix=("_x", "_y"),
                           aggregate=False):
        """
        annotate an existing dataframe

        The distinct values of ``column`` are annotated with
        annotate_positions() and inner-joined back onto ``df``; the helper
        column "pos" is removed from the result.

        :param df: data frame to which annotation is added
        :param column: specifies the genome position column
        :param suffix: tuple of suffix that is added overlapping column names (default: (_x, _y))
        :param aggregate: determines whether duplicated entry are aggregated as a semicolon separated string
        :return: pandas dataframe
        """
        positions = set(df[column])
        annotation = self.annotate_positions(positions, aggregate=aggregate)

        merged = df.merge(annotation,
                          left_on=column,
                          right_on="pos",
                          how="inner",
                          suffixes=suffix)
        merged.drop("pos", axis=1, inplace=True)
        return merged
コード例 #32
0
    def add_slides_with_annotations(self):
        """
        Add all slides of shapes.svg, including their annotations, to a new
        'Slides' layer.

        Every <image> of shapes.svg (except deskshare.png placeholders) is
        shown from its 'in' to its 'out' time. When a matching "canvas"
        group exists, the visibility windows of its shapes split that time
        span into intervals; for each interval one PNG with the slide plus
        all shapes visible in it is rendered and added as a clip. Without a
        canvas a single PNG of the bare slide is used.

        Rendered PNGs are written into opts.basedir and reused if a file of
        that name already exists.
        """
        def data_reducer(a, b):
            # merge_overlaps() combines the shape lists of merged intervals
            return a + b

        def resolve_href(element):
            # Rewrite the element's xlink:href relative to basedir and
            # return the value it had before.
            href = '{http://www.w3.org/1999/xlink}href'
            original = element.get(href)
            element.set(href, os.path.join(self.opts.basedir, original))
            return original

        def add_frame(children, png_name, out_width, out_height, begin,
                      duration):
            # Wrap the given elements in an <svg>, render it to png_name
            # (unless that file exists) and place the result on the layer.
            svg = ET.XML(
                '<svg version="1.1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {} {}"></svg>'
                .format(out_width, out_height))
            for child in children:
                svg.append(child)

            pngpath = os.path.join(self.opts.basedir, png_name)
            if not os.path.exists(pngpath):
                cairosvg.svg2png(bytestring=ET.tostring(svg),
                                 write_to=pngpath,
                                 output_width=out_width,
                                 output_height=out_height)

            asset = self._get_asset(pngpath)
            width, height = self._constrain(
                self._get_dimensions(asset),
                (self.slides_width, self.opts.height))
            self._add_clip(layer, asset, begin, 0, duration, 0, 0, width,
                           height)

        layer = self._add_layer('Slides')

        doc = ET.parse(os.path.join(self.opts.basedir, 'shapes.svg'))

        for img in doc.iterfind('./{http://www.w3.org/2000/svg}image'):
            path = resolve_href(img)
            if path.endswith('/deskshare.png'):
                continue

            img_width = int(img.get('width'))
            img_height = int(img.get('height'))

            canvas = doc.find(
                './{{http://www.w3.org/2000/svg}}g[@class="canvas"][@image="{}"]'
                .format(img.get('id')))

            img_start = round(float(img.get('in')) * Gst.SECOND)
            img_end = round(float(img.get('out')) * Gst.SECOND)

            # The slide's own time span, without any shapes, seeds the tree.
            t = IntervalTree()
            t.add(Interval(begin=img_start, end=img_end, data=[]))

            if canvas is None:
                add_frame([img], '{}.png'.format(img.get('id')), img_width,
                          img_height, img_start, img_end - img_start)
                continue

            # shape id -> list of {'start', 'end', 'shape'} versions
            shapes = {}
            for shape in canvas.iterfind(
                    './{http://www.w3.org/2000/svg}g[@class="shape"]'):

                shape.set('style',
                          shape.get('style').replace('visibility:hidden;',
                                                     ''))

                for shape_img in shape.iterfind(
                        './{http://www.w3.org/2000/svg}image'):
                    resolve_href(shape_img)

                # A shape is never shown before its slide.
                shape_start = round(
                    float(shape.get('timestamp')) * Gst.SECOND)
                start = max(img_start, shape_start)

                # 'undo' of -1 (or a resulting time of 0) means the shape
                # stays until the slide ends.
                end = img_end
                undo = shape.get('undo')
                shape_end = round(float(undo) * Gst.SECOND)
                if undo != '-1' and shape_end != 0 and shape_end < end:
                    end = shape_end

                if end < start:
                    continue

                shapes.setdefault(shape.get('shape'), []).append({
                    'start': start,
                    'end': end,
                    'shape': shape
                })

            for shape_id, shapes_list in shapes.items():
                versions = sorted(shapes_list, key=lambda k: k['start'])
                for position, version in enumerate(versions):
                    # A version is replaced as soon as the next one starts.
                    if position + 1 < len(versions):
                        version['end'] = versions[position + 1]['start']
                    # IntervalTree.add() raises ValueError for null
                    # intervals (begin >= end); such a version is never
                    # visible, so it is skipped.
                    if version['end'] <= version['start']:
                        continue
                    t.add(
                        Interval(begin=version['start'],
                                 end=version['end'],
                                 data=[(shape_id, version['shape'])]))

            # Cut at every boundary, then fold identical ranges into one
            # interval carrying all shapes visible in it.
            t.split_overlaps()
            t.merge_overlaps(data_reducer=data_reducer)
            for index, interval in enumerate(sorted(t)):
                visible = [
                    shape for _, shape in sorted(interval.data,
                                                 key=lambda k: k[0])
                ]
                add_frame([img] + visible,
                          '{}-{}.png'.format(img.get('id'), index),
                          img_width, img_height, interval.begin,
                          interval.end - interval.begin)
コード例 #33
0
ファイル: calc_wer_yandex.py プロジェクト: tyommik/calc_wer
 def __init__(self, path: pathlib.Path):
     """Store the path, create an empty data list and interval tree, then run the private loader."""
     self.path = path
     self.tree = IntervalTree()
     self.data = []
     self.__load()
コード例 #34
0
from datetime import datetime, date
from intervaltree import IntervalTree

class ScheduleItem:
    """A course number together with its start and end time strings."""

    def __init__(self, course_number, start_time, end_time):
        self.course_number = course_number
        self.start_time = start_time
        self.end_time = end_time

    def get_begin(self):
        """Return the start time as minutes since midnight."""
        return minutes_from_midnight(self.start_time)

    def get_end(self):
        """Return the end time as minutes since midnight."""
        return minutes_from_midnight(self.end_time)

    def __repr__(self):
        fields = (self.course_number, self.start_time, self.end_time)
        return "{ScheduleItem: " + str(fields) + "}"

def minutes_from_midnight(time):
    """
    Convert a 12-hour clock string such as "9:00AM" into minutes since midnight.

    The value is computed from the parsed hour and minute alone, so it does
    not depend on the current date or time.

    :param time: time of day in the format '%I:%M%p' (e.g. "12:30PM")
    :return: number of whole minutes since 00:00 as int
    :raises ValueError: if the string does not match the format
    """
    parsed = datetime.strptime(time, '%I:%M%p')
    return parsed.hour * 60 + parsed.minute

# An IntervalTree holds Interval objects, so every ScheduleItem is turned
# into a (begin, end, data) tuple, with the item itself as the data payload.
T = IntervalTree.from_tuples(
    (item.get_begin(), item.get_end(), item)
    for item in (
        ScheduleItem(28374, "9:00AM", "10:00AM"),
        ScheduleItem(43564, "8:00AM", "12:00PM"),
        ScheduleItem(53453, "1:00AM", "2:00AM"),
    )
)
# print() with parentheses is valid in Python 2 and Python 3.
# NOTE(review): search() is kept as in the original; newer intervaltree
# releases name this query overlap() -- confirm the installed version.
print(T.search(minutes_from_midnight("9:00PM"), minutes_from_midnight("10:00PM")))
コード例 #35
0
    def createGraphFromSubmaps(self, submaps):
        """
        Build an adjacency list linking every submap to the submaps that
        border one of its four edges.

        Each submap is indexed by its x range and its y range (both taken
        as inclusive). For a submap ``a`` the candidates are the submaps
        whose range covers the coordinate just outside an edge of ``a``; a
        candidate is recorded when its range on the other axis meets the
        range of ``a``.

        :param submaps: sequence of objects with min_x, max_x, min_y, max_y
        :return: list with, per submap index, the indices of its neighbours
        """
        def ranges_meet(a_lo, a_hi, b_lo, b_hi):
            # Same decision chain for every edge: a starts inside b (cases
            # A/C), or a starts before b and reaches b's start (case B) or
            # b's end (case D).
            if a_lo >= b_lo:
                return a_lo <= b_hi
            return a_hi >= b_lo or a_hi >= b_hi

        x_intervals = IntervalTree()
        y_intervals = IntervalTree()

        # Slice ends are exclusive, hence the + 1 on the inclusive maximum.
        for num, submap in enumerate(submaps):
            x_intervals[submap.min_x:submap.max_x + 1] = num
            y_intervals[submap.min_y:submap.max_y + 1] = num

        adj_list = [[] for _ in range(len(submaps))]

        for num, a in enumerate(submaps):
            left_edge = x_intervals[a.min_x - 1]
            right_edge = x_intervals[a.max_x + 1]
            top_edge = y_intervals[a.min_y - 1]
            bottom_edge = y_intervals[a.max_y + 1]

            # Left and right edge: candidates must meet a's y range.
            for candidates in (left_edge, right_edge):
                for interval in candidates:
                    b = submaps[interval.data]
                    if ranges_meet(a.min_y, a.max_y, b.min_y, b.max_y):
                        adj_list[num].append(interval.data)

            # Top and bottom edge: candidates must meet a's x range.
            for candidates in (top_edge, bottom_edge):
                for interval in candidates:
                    b = submaps[interval.data]
                    if ranges_meet(a.min_x, a.max_x, b.min_x, b.max_x):
                        adj_list[num].append(interval.data)

        return adj_list
コード例 #36
0
 def empty(cls):
     """Build an instance from placeholder values and a fresh, empty IntervalTree."""
     no_intervals = IntervalTree()
     return cls("", 0, -1, "", no_intervals)
コード例 #37
0
    def process_link_file(self):
        """
        Read the links file named in ``self.properties['file']`` into one
        IntervalTree per chromosome.

        The file format expected is similar to file format of links in
        circos:

            chr1 100 200 chr1 250 300 0.5

        (tab separated) where the last value is an optional score. Lines
        starting with 'browser', 'track' or '#' are skipped, and so are
        links between different chromosomes. Each stored interval runs from
        the start of the first region to the end of the second region
        (regions are ordered by start) and carries
        ``[start1, end1, start2, end2, score]`` as data.

        :return: tuple (dict chromosome -> IntervalTree, min_score,
                 max_score, has_score); has_score turns False as soon as a
                 line lacks a score or has one that is not a number
        :raises InputError: if a line has fewer than six fields or a
                 coordinate is not an integer
        :raises AssertionError: if a region starts after it ends
        """
        valid_intervals = 0
        interval_tree = {}
        has_score = True
        max_score = float('-inf')
        min_score = float('inf')
        with open(self.properties['file'], 'r') as file_h:
            # Iterate the handle directly instead of loading all lines at once.
            for line_number, line in enumerate(file_h, start=1):
                if line.startswith(('browser', 'track', '#')):
                    continue

                fields = line.strip().split('\t')
                try:
                    chrom1, start1, end1, chrom2, start2, end2 = fields[:6]
                except Exception as detail:
                    raise InputError(
                        'File not valid. The format is chrom1 start1, end1, '
                        'chrom2, start2, end2\nError: {}\n in line\n {}'.
                        format(detail, line))
                try:
                    score = fields[6]
                except IndexError:
                    has_score = False
                    score = np.nan

                try:
                    start1 = int(start1)
                    end1 = int(end1)
                    start2 = int(start2)
                    end2 = int(end2)
                except ValueError as detail:
                    raise InputError(
                        "Error reading line: {}. One of the fields is not "
                        "an integer.\nError message: {}".format(
                            line_number, detail))

                # Kept as assertions so callers see the same exception type;
                # the messages now state the condition that actually failed.
                assert start1 <= end1, "Error in line #{}, start1 larger than end1 in {}".format(
                    line_number, line)
                assert start2 <= end2, "Error in line #{}, start2 larger than end2 in {}".format(
                    line_number, line)

                if has_score:
                    try:
                        score = float(score)
                    except ValueError as detail:
                        self.log.warning(
                            "Warning: reading line: {}. The score is not valid {} will not be used. "
                            "\nError message: {}".format(
                                line_number, score, detail))
                        score = np.nan
                        has_score = False
                    else:
                        if score < min_score:
                            min_score = score
                        if score > max_score:
                            max_score = score

                if chrom1 != chrom2:
                    self.log.warning(
                        "Only links in same chromosome are used. Skipping line\n{}\n"
                        .format(line))
                    continue

                if chrom1 not in interval_tree:
                    interval_tree[chrom1] = IntervalTree()

                # order the two regions by their start
                if start2 < start1:
                    start1, start2 = start2, start1
                    end1, end2 = end2, end1

                # each interval spans from the smallest start to the end of
                # the region that starts last
                interval_tree[chrom1].add(
                    Interval(start1, end2,
                             [start1, end1, start2, end2, score]))
                valid_intervals += 1

        if valid_intervals == 0:
            self.log.warning("No valid intervals were found in file {}".format(
                self.properties['file']))

        # The with-statement has already closed the file at this point.
        return (interval_tree, min_score, max_score, has_score)
コード例 #38
0
                result_gt = [from_model[x]
                             for x in np.greater(results[:, 0], results[:, 1])]
            dp_s = dp_set[0][s]
            dp_s = dp_s[dp_s > 0]
            results = ((a, b, c, d, e)
                       for (a, b, c), d, e in zip(sample_gt, result_gt, dp_s))

            results = list(results)

            dp_avg = ""
            chrom_sets = [list(g) for k, g in groupby(results, itemgetter(0))]
            tree[sample] = {}
            for chrom_set in chrom_sets:
                end = 0
                chrom = chrom_set[0][0]
                tree[sample][chrom] = IntervalTree()
                interval_set = [list(g) for k, g in groupby(chrom_set, itemgetter(3))]
                interval_set_len = len(interval_set)
                for interval_n, interval in enumerate(interval_set):
                    orig = [x[2] for x in interval]
                    pred = [x[3] for x in interval]
                    gt = pred[0]
                    orig_RLE, switches = generate_RLE(orig)
                    supporting_sites = len([x for x in interval if x[2] == x[3]])
                    dp_avg = 0
                    if supporting_sites > 0:
                        dp_avg = sum([x[4] for x in interval if x[2] == x[3]])*1.0/supporting_sites
                    gt = interval[0][3]
                    if supporting_sites > 0:
                        if args["--infill"]:
                            start = end + 1
コード例 #39
0
    "20",
    "21",
    "22",
    "X",
    "Y",
    "MT",
)

# Maps each chromosome name to its 1-based position in CHROMOSOMES
CHROMOSOME_INTEGERS = {
    chrom: number for number, chrom in enumerate(CHROMOSOMES, start=1)
}

PAR_COORDINATES = {
    "37": {
        "X":
        IntervalTree([
            Interval(60001, 2699521, "par1"),
            Interval(154931044, 155260561, "par2")
        ]),
        "Y":
        IntervalTree([
            Interval(10001, 2649521, "par1"),
            Interval(59034050, 59363567, "par2")
        ]),
    },
    "38": {
        "X":
        IntervalTree([
            Interval(10001, 2781480, "par1"),
            Interval(155701383, 156030896, "par2")
        ]),
        "Y":
        IntervalTree([
コード例 #40
0
class MemoryCache(object):
    """! @brief Memory cache.
    
    Maintains a cache of target memory. The constructor is passed a backing DebugContext object that
    will be used to fill the cache.
    
    The cache is invalidated whenever the target has run since the last cache operation (based on run
    tokens). If the target is currently running, all accesses cause the cache to be invalidated.
    
    The target's memory map is referenced. All memory accesses must be fully contained within a single
    memory region, or a TransferFaultError will be raised. However, if an access is outside of all regions,
    the access is passed to the underlying context unmodified. When an access is within a region, that
    region's cacheability flag is honoured.
    """
    def __init__(self, context, core):
        self._context = context
        self._core = core
        self._run_token = -1
        self._reset_cache()

    def _reset_cache(self):
        """! @brief Replaces the cache with an empty interval tree and starts a fresh metrics object."""
        self._cache = IntervalTree()
        self._metrics = CacheMetrics()

    def _check_cache(self):
        """! @brief Invalidates the cache if appropriate."""
        if self._core.is_running():
            LOG.debug("core is running; invalidating cache")
            self._reset_cache()
        elif self._run_token != self._core.run_token:
            self._dump_metrics()
            LOG.debug("out of date run token; invalidating cache")
            self._reset_cache()
            self._run_token = self._core.run_token

    def _get_ranges(self, addr, count):
        """! @brief Splits a memory address range into cached and uncached subranges.

        Starts from a single uncached interval covering the whole request and, for every cached
        interval overlapping the request, carves that cached span out of the uncached pieces.

        @param self
        @param addr Start address of the request.
        @param count Number of bytes requested.
        @return Returns a 2-tuple with the first element being a set of Interval objects for each
          of the cached subranges. The second element is a set of Interval objects for each of the
          non-cached subranges.
        """
        cached = self._cache.overlap(addr, addr + count)
        uncached = {Interval(addr, addr + count)}

        for hit in cached:
            remaining = set()
            for gap in uncached:
                # The cached interval does not touch this gap: keep the gap unchanged.
                disjoint = hit.end < gap.begin or hit.begin > gap.end
                if disjoint:
                    remaining.add(gap)
                    continue

                # Part of the gap lying before the cached interval.
                if hit.begin > gap.begin:
                    remaining.add(Interval(gap.begin, hit.begin))

                # Part of the gap lying after the cached interval.
                if gap.end > hit.end:
                    remaining.add(Interval(hit.end, gap.end))
            uncached = remaining

        return cached, uncached

    def _read_uncached(self, uncached):
        """! @brief Reads uncached memory ranges and updates the cache.
        @param self
        @param uncached Iterable of Interval objects describing the ranges to fetch.
        @return A list of Interval objects is returned. Each Interval has its @a data attribute set
          to a bytearray of the data read from target memory.
        """
        fetched = []
        for gap in uncached:
            length = gap.end - gap.begin
            raw = self._context.read_memory_block8(gap.begin, length)
            filled = Interval(gap.begin, gap.end, bytearray(raw))
            self._cache.add(filled)  # TODO merge contiguous cached intervals
            fetched.append(filled)
        return fetched

    def _update_metrics(self, cached, uncached, addr, size):
        cachedSize = 0
        for iv in cached:
            begin = iv.begin
            end = iv.end
            if iv.begin < addr:
                begin = addr
            if iv.end > addr + size:
                end = addr + size
            cachedSize += end - begin

        uncachedSize = sum((iv.end - iv.begin) for iv in uncached)

        self._metrics.reads += 1
        self._metrics.hits += cachedSize
        self._metrics.misses += uncachedSize

    def _dump_metrics(self):
        """! @brief Logs a one-line summary of the collected metrics at debug level."""
        stats = self._metrics
        if not stats.total > 0:
            LOG.debug("no reads")
            return

        LOG.debug(
            "%d reads, %d bytes [%d%% hits, %d bytes]; %d bytes written",
            stats.reads, stats.total,
            stats.percent_hit, stats.hits,
            stats.writes)

    def _read(self, addr, size):
        """! @brief Performs a cached read operation of an address range.
        @return A list of Interval objects sorted by address.
        """
        # Get the cached and uncached subranges of the requested read.
        cached, uncached = self._get_ranges(addr, size)
        self._update_metrics(cached, uncached, addr, size)

        # Read any uncached ranges.
        uncachedData = self._read_uncached(uncached)

        # Merged cached with data we just read
        combined = list(cached) + uncachedData
        combined.sort(key=lambda x: x.begin)
        return combined

    def _merge_data(self, combined, addr, size):
        """! @brief Extracts data from the intersection of an address range across a list of interval objects.
        
        The range represented by @a addr and @a size are assumed to overlap the intervals. The first
        and last interval in the list may have ragged edges not fully contained in the address range, in
        which case the correct slice of those intervals is extracted.
        
        @param self
        @param combined List of Interval objects forming a contiguous range. The @a data attribute of
          each interval must be a bytearray.
        @param addr Start address. Must be within the range of the first interval.
        @param size Number of bytes. (@a addr + @a size) must be within the range of the last interval.
        @return A single bytearray object with all data from the intervals that intersects the address
          range.
        """
        result = bytearray()
        resultAppend = bytearray()

        # Check for fully contained subrange.
        if len(combined) and combined[0].begin < addr and combined[
                0].end > addr + size:
            offset = addr - combined[0].begin
            endOffset = offset + size
            result = combined[0].data[offset:endOffset]
            return result

        # Take slice of leading ragged edge.
        if len(combined) and combined[0].begin < addr:
            offset = addr - combined[0].begin
            result += combined[0].data[offset:]
            combined = combined[1:]
        # Take slice of trailing ragged edge.
        if len(combined) and combined[-1].end > addr + size:
            offset = addr + size - combined[-1].begin
            resultAppend = combined[-1].data[:offset]
            combined = combined[:-1]

        # Merge.
        for iv in combined:
            result += iv.data
        result += resultAppend

        return result

    def _update_contiguous(self, cached, addr, value):
        size = len(value)
        end = addr + size
        leadBegin = addr
        leadData = bytearray()
        trailData = bytearray()
        trailEnd = end

        if cached[0].begin < addr and cached[0].end > addr:
            offset = addr - cached[0].begin
            leadData = cached[0].data[:offset]
            leadBegin = cached[0].begin
        if cached[-1].begin < end and cached[-1].end > end:
            offset = end - cached[-1].begin
            trailData = cached[-1].data[offset:]
            trailEnd = cached[-1].end

        self._cache.remove_overlap(addr, end)

        data = leadData + value + trailData
        self._cache.addi(leadBegin, trailEnd, data)

    def _check_regions(self, addr, count):
        """! @return A bool indicating whether the given address range is fully contained within
              one known memory region, and that region is cacheable.
        @exception TransferFaultError Raised if the access is not entirely contained within a single region.
        """
        regions = self._core.memory_map.get_intersecting_regions(addr,
                                                                 length=count)

        # If no regions matched, then allow an uncached operation.
        if len(regions) == 0:
            return False

        # Raise if not fully contained within one region.
        if len(regions) > 1 or not regions[0].contains_range(addr,
                                                             length=count):
            raise TransferFaultError(
                "individual memory accesses must not cross memory region boundaries"
            )

        # Otherwise return whether the region is cacheable.
        return regions[0].is_cacheable

    def read_memory(self, addr, transfer_size=32, now=True):
        # TODO use more optimal underlying read_memory calls
        if transfer_size == 8:
            data = self.read_memory_block8(addr, 1)[0]
        else:
            data = conversion.byte_list_to_nbit_le_list(
                self.read_memory_block8(addr, transfer_size // 8),
                transfer_size)[0]

        if now:
            return data
        else:

            def read_cb():
                return data

            return read_cb

    def read_memory_block8(self, addr, size):
        if size <= 0:
            return []

        self._check_cache()

        # Validate memory regions.
        if not self._check_regions(addr, size):
            LOG.debug("range [%x:%x] is not cacheable", addr, addr + size)
            return self._context.read_memory_block8(addr, size)

        # Get the cached and uncached subranges of the requested read.
        combined = self._read(addr, size)

        # Extract data out of combined intervals.
        result = list(self._merge_data(combined, addr, size))
        assert len(
            result) == size, "result size ({}) != requested size ({})".format(
                len(result), size)
        return result

    def read_memory_block32(self, addr, size):
        """! @brief Reads @a size 32-bit words starting at @a addr via the byte-level cached read.
        @return The result of converting the byte list to a little-endian u32 list.
        """
        raw_bytes = self.read_memory_block8(addr, size * 4)
        return conversion.byte_list_to_u32le_list(raw_bytes)

    def write_memory(self, addr, value, transfer_size=32):
        if transfer_size == 8:
            return self.write_memory_block8(addr, [value])
        else:
            return self.write_memory_block8(
                addr,
                conversion.nbit_le_list_to_byte_list([value], transfer_size))

    def write_memory_block8(self, addr, value):
        if len(value) <= 0:
            return

        self._check_cache()

        # Validate memory regions.
        cacheable = self._check_regions(addr, len(value))

        # Write to the target first, so if it fails we don't update the cache.
        result = self._context.write_memory_block8(addr, value)

        if cacheable:
            size = len(value)
            end = addr + size
            cached = sorted(self._cache.overlap(addr, end),
                            key=lambda x: x.begin)
            self._metrics.writes += size

            if len(cached):
                # Write data is entirely within a single cached interval.
                if addr >= cached[0].begin and end <= cached[0].end:
                    beginOffset = addr - cached[0].begin
                    endOffset = beginOffset + size
                    cached[0].data[beginOffset:endOffset] = value

                else:
                    self._update_contiguous(cached, addr, bytearray(value))
            else:
                # No cached data in this range, so just add the entire interval.
                self._cache.addi(addr, end, bytearray(value))

        return result

    def write_memory_block32(self, addr, data):
        """! @brief Writes a list of 32-bit words via the byte-level cached write.
        @return Whatever write_memory_block8() returns.
        """
        as_bytes = conversion.u32le_list_to_byte_list(data)
        return self.write_memory_block8(addr, as_bytes)

    def invalidate(self):
        """! @brief Unconditionally discards the cache contents by resetting the cache."""
        self._reset_cache()
コード例 #41
0
 def _reset_cache(self):
     """! @brief Replaces the cache with an empty interval tree and starts a fresh metrics object."""
     self._cache = IntervalTree()
     self._metrics = CacheMetrics()
コード例 #42
0
class TemporalNodeCollection(NodeCollection):
    """A collection of temporal nodes.

    In addition to the base collection, an interval tree of events is kept:
    each entry spans ``[start, end)`` and carries the uid of the node it
    belongs to.
    """
    def __init__(self, *args, **kwargs) -> None:
        """Initialize the TemporalNodeCollection object."""

        # initialize the base class
        super().__init__(*args, **kwargs)

        # initialize an intervaltree to save events
        self._events = IntervalTree()

        # class of objects
        self._default_class: Any = TemporalNode

    @singledispatchmethod
    def __getitem__(self, key: Any) -> Any:
        """Look up an item by key; non-temporal keys are handled by the base class."""
        return super().__getitem__(key)

    @__getitem__.register(slice)  # type: ignore
    @__getitem__.register(int)  # type: ignore
    @__getitem__.register(float)  # type: ignore
    def _(self, key: Union[int, float, slice]) -> Any:
        """Yield objects of all nodes that have an event in the given time or time slice."""
        # pylint: disable=arguments-differ
        start, end, _ = _get_start_end(key)
        # NOTE(review): the loop rebinds ``start``/``end`` to each event's own
        # bounds, so the node is sliced by the event interval rather than by
        # the requested range - confirm this is intended.
        for start, end, uid in sorted(self._events[start:end]):
            for obj in self[uid][start:end]:
                yield obj

    @property
    def start(self):
        """start of the object"""
        return self._events.begin()

    @property
    def end(self):
        """end of the object"""
        return self._events.end()

    @property
    def events(self):
        """Temporal events"""
        return self._events

    @singledispatchmethod
    def add(self, *args, **kwargs: Any) -> None:
        """Add multiple nodes. """
        super().add(*args, **kwargs)

    def _add(self, obj: Any, **kwargs: Any) -> None:
        """Add a node to the set of nodes and record its last event interval."""
        super()._add(obj, **kwargs)
        start, end, _ = obj.last()
        self._events[start:end] = obj.uid

    def _if_exist(self, obj: Any, **kwargs: Any) -> None:
        """Helper function if node already exists."""
        # ``count`` is not used below, but popping it keeps it out of the
        # kwargs forwarded to ``element.event``.
        count: int = kwargs.pop('count', 1)
        element = self[obj.relations]
        element.event(**kwargs)
        start, end, _ = obj.last()
        self._events[start:end] = element.uid

    def _remove(self, obj) -> None:
        """Remove a node and all events recorded under its uid."""
        # ``sorted`` produces a list copy, so removing from the tree inside
        # the loop does not disturb the iteration.
        for interval in sorted(self._events):
            if interval.data == obj.uid:
                self._events.remove(interval)
        super()._remove(obj)
コード例 #43
0
    def intersect_cn_trees(self):
        """Find copy-number segments shared by all samples and register merged CN events.

        For every chromosome (1-22, X, Y) the per-sample copy-number profiles
        are pooled into one interval tree, split at every breakpoint, and only
        the pieces present in every sample are kept. Each kept piece is mapped
        to cytobands and, per allele, turned into an event segment unless the
        allele is copy-neutral (cn == 1) in all samples. Compatible event
        segments are merged and passed to ``self._add_cn_event_to_samples``.

        Side effects: sets ``self.concordant_cn_tree`` to a dict mapping the
        chromosome name to the tree of segments present in all samples.
        """
        def get_bands(chrom,
                      start,
                      end,
                      cytoband=os.path.dirname(__file__) +
                      '/supplement_data/cytoBand.txt'):
            """Return Cytoband objects for all bands of ``chrom`` overlapping [start, end].

            The cytoband file is read as tab-separated rows of chromosome,
            band start, band end and band name; rows of one chromosome are
            expected to be contiguous and ordered by start.
            """
            bands = []
            on_c = False
            with open(cytoband, 'r') as f:
                for line in f:
                    row = line.strip('\n').split('\t')
                    if row[0].strip('chr') != str(chrom):
                        # Past the block of the requested chromosome: done.
                        if on_c:
                            return bands
                        continue
                    if int(row[1]) <= end and int(row[2]) >= start:
                        bands.append(Cytoband(chrom, row[3]))
                        on_c = True
                    if int(row[1]) > end:
                        return bands
            # The file ended while still inside the requested chromosome (or
            # without ever matching it). Previously this fell through and
            # returned None, which broke ``tuple(bands)`` in the caller.
            return bands

        def merge_cn_events(event_segs,
                            neighbors,
                            R=frozenset(),
                            X=frozenset()):
            """Recursively yield merged events built from mutually compatible, adjacent segments.

            ``R`` is the group built so far, ``event_segs`` the remaining
            candidates and ``X`` the segments already processed. A group is
            emitted once no candidate or processed segment is adjacent to it.
            NOTE(review): the recursion is shaped like a maximal-clique
            enumeration over ``neighbors`` - confirm before relying on that.

            Yields tuples (bands, cns, ccf_hat, ccf_high, ccf_low) where the
            three ccf arrays are averaged over the segments of the group.
            """
            is_max = True
            for s in itertools.chain(event_segs, X):
                if isadjacent(s, R):
                    is_max = False
                    break
            if is_max:
                bands = set.union(*(set(b[0]) for b in R))
                cns = next(iter(R))[1]
                ccf_hat = np.zeros(len(self.sample_list))
                ccf_high = np.zeros(len(self.sample_list))
                ccf_low = np.zeros(len(self.sample_list))
                for seg in R:
                    ccf_hat += np.array(seg[2])
                    ccf_high += np.array(seg[3])
                    ccf_low += np.array(seg[4])
                yield (bands, cns, ccf_hat / len(R), ccf_high / len(R),
                       ccf_low / len(R))
            else:
                for s in min(
                    (event_segs - neighbors[p] for p in event_segs.union(X)),
                        key=len):
                    if isadjacent(s, R):
                        for region in merge_cn_events(
                                event_segs.intersection(neighbors[s]),
                                neighbors,
                                R=R.union({s}),
                                X=X.intersection(neighbors[s])):
                            yield region
                        event_segs = event_segs.difference({s})
                        X = X.union({s})

        def isadjacent(s, R):
            """Return True if segment ``s`` touches the band range covered by group ``R``.

            An empty group is adjacent to everything. Otherwise the bands of
            ``s`` must lie entirely on one side of the group, at most one band
            apart, and the facing bands must share the first character of
            their ``band`` attribute.
            """
            if not R:
                return True
            Rchain = list(itertools.chain(*(b[0] for b in R)))
            minR = min(Rchain)
            maxR = max(Rchain)
            mins = min(s[0])
            maxs = max(s[0])
            if mins >= maxR:
                return mins - maxR <= 1 and mins.band[0] == maxR.band[0]
            elif maxs <= minR:
                return minR - maxs <= 1 and maxs.band[0] == minR.band[0]
            else:
                return False

        c_trees = {}
        n_samples = len(self.sample_list)
        for chrom in list(map(str, range(1, 23))) + ['X', 'Y']:
            # Pool all samples' segments, cut them at every breakpoint and
            # collect the data of identical pieces into one list per piece.
            tree = IntervalTree()
            for sample in self.sample_list:
                if sample.CnProfile:
                    tree.update(sample.CnProfile[chrom])
            tree.split_overlaps()
            tree.merge_equals(data_initializer=[],
                              data_reducer=lambda a, c: a + [c])
            # Keep only pieces that carry data from every sample.
            c_tree = IntervalTree(
                filter(lambda s: len(s.data) == n_samples, tree))
            c_trees[chrom] = c_tree
            event_segs = set()
            for seg in c_tree:
                start = seg.begin
                end = seg.end
                bands = get_bands(chrom, start, end)
                cns_a1 = []
                cns_a2 = []
                ccf_hat_a1 = []
                ccf_hat_a2 = []
                ccf_high_a1 = []
                ccf_high_a2 = []
                ccf_low_a1 = []
                ccf_low_a2 = []
                for i in range(n_samples):
                    vals = seg.data[i][1]
                    # CCF values are only taken for alleles that deviate from
                    # copy number 1; copy-neutral alleles contribute 0.
                    a1_event = vals['cn_a1'] != 1
                    a2_event = vals['cn_a2'] != 1
                    cns_a1.append(vals['cn_a1'])
                    cns_a2.append(vals['cn_a2'])
                    ccf_hat_a1.append(vals['ccf_hat_a1'] if a1_event else 0.)
                    ccf_hat_a2.append(vals['ccf_hat_a2'] if a2_event else 0.)
                    ccf_high_a1.append(vals['ccf_high_a1'] if a1_event else 0.)
                    ccf_high_a2.append(vals['ccf_high_a2'] if a2_event else 0.)
                    ccf_low_a1.append(vals['ccf_low_a1'] if a1_event else 0.)
                    ccf_low_a2.append(vals['ccf_low_a2'] if a2_event else 0.)
                cns_a1 = np.array(cns_a1)
                cns_a2 = np.array(cns_a2)
                # An allele yields an event only if all samples agree on the
                # direction (all >= 1 or all <= 1) and at least one differs from 1.
                if np.all(cns_a1 == 1):
                    pass
                elif np.all(cns_a1 >= 1) or np.all(cns_a1 <= 1):
                    event_segs.add(
                        (tuple(bands), tuple(cns_a1), tuple(ccf_hat_a1),
                         tuple(ccf_high_a1), tuple(ccf_low_a1), 'a1'))
                else:
                    logging.warning(
                        'Seg with inconsistent event: {}:{}:{}'.format(
                            chrom, seg.begin, seg.end))
                if np.all(cns_a2 == 1):
                    pass
                elif np.all(cns_a2 >= 1) or np.all(cns_a2 <= 1):
                    event_segs.add(
                        (tuple(bands), tuple(cns_a2), tuple(ccf_hat_a2),
                         tuple(ccf_high_a2), tuple(ccf_low_a2), 'a2'))
                else:
                    logging.warning(
                        'Seg with inconsistent event: {}:{}:{}'.format(
                            chrom, seg.begin, seg.end))
            # Two event segments are neighbors when their copy numbers match
            # and each one's ccf_hat lies within the other's [low, high] range.
            neighbors = {s: set() for s in event_segs}
            for seg1, seg2 in itertools.combinations(event_segs, 2):
                s1_hat = np.array(seg1[2])
                s2_hat = np.array(seg2[2])
                if seg1[1] == seg2[1] and np.all(s1_hat >= np.array(seg2[4])) and np.all(s1_hat <= np.array(seg2[3]))\
                and np.all(s2_hat >= np.array(seg1[4])) and np.all(s2_hat <= np.array(seg1[3])):
                    neighbors[seg1].add(seg2)
                    neighbors[seg2].add(seg1)

            event_cache = []
            if event_segs:
                for bands, cns, ccf_hat, ccf_high, ccf_low in merge_cn_events(
                        event_segs, neighbors):
                    mut_category = 'gain' if sum(cns) > len(
                        self.sample_list) else 'loss'
                    # a1 is True the first time this (category, bands) pair is
                    # seen on the chromosome; repeats are flagged as dupes.
                    a1 = (mut_category, bands) not in event_cache
                    if a1:
                        event_cache.append((mut_category, bands))
                    self._add_cn_event_to_samples(chrom,
                                                  min(bands),
                                                  max(bands),
                                                  cns,
                                                  mut_category,
                                                  ccf_hat,
                                                  ccf_high,
                                                  ccf_low,
                                                  a1,
                                                  dupe=not a1)
        self.concordant_cn_tree = c_trees
コード例 #44
0
    def __init__(self,
                 indiv_name='Indiv1',
                 sample_map={},
                 ccf_grid_size=101,
                 driver_genes_file=os.path.join(
                     os.path.dirname(__file__),
                     'supplement_data/Driver_genes_v1.0.txt'),
                 impute_missing=False,
                 artifact_blacklist=os.path.join(
                     os.path.dirname(__file__),
                     'supplement_data/Blacklist_SNVs.txt'),
                 artifact_whitelist='',
                 use_indels=False,
                 min_coverage=8,
                 delete_auto_bl=False,
                 PoN_file=False):
        """Set up an empty patient-level container.

        Stores the configuration given by the arguments, parses the driver
        gene file, and initialises all result holders to empty values.
        ``sample_map`` and ``delete_auto_bl`` are accepted but not used in
        this constructor.
        NOTE(review): ``sample_map={}`` is a shared mutable default; it is
        harmless here only because it is never touched.
        """
        # --- identity and samples -------------------------------------
        self.indiv_name = indiv_name
        # list of TumorSample objects
        self.sample_list = []
        self.samples_synchronized = False

        # known driver genes, as parsed from the driver gene file
        # TODO add enum
        self.driver_genes = self._parse_driver_g_file(driver_genes_file)

        self.ccf_grid_size = ccf_grid_size

        # --- annotation -------------------------------------------------
        self.PatientLevel_MutBlacklist = artifact_blacklist
        self.PatientLevel_MutWhitelist = artifact_whitelist

        # --- patient configuration settings -----------------------------
        # whether missing variants are imputed as ccf 0
        self.impute_missing = impute_missing
        # min cov is specified both here and passed to tumor sample
        self.min_coverage = min_coverage
        self.use_indels = use_indels
        self.PoN_file = PoN_file

        self._validate_sample_names()

        # --- data objects filled in later -------------------------------
        self.ND_mutations = []

        # clustering results
        self.ClusteringResults = self.MutClusters = None
        self.TruncalMutEvents = self.MCMC_trace = None
        self.k_trace = self.alpha_trace = None

        self.unclustered_muts = []

        # one empty interval tree per chromosome (1-22, X, Y)
        chromosomes = [str(number) for number in range(1, 23)] + ['X', 'Y']
        self.concordant_cn_tree = {
            chrom: IntervalTree()
            for chrom in chromosomes
        }

        # tree building results
        self.TopTree = None
        self.TreeEnsemble = []
コード例 #45
0
 def __init__(self, chromosome, strand, breakpoint, gene_name, exons: IntervalTree):
     """Store the gene's location attributes and a new IntervalTree built from ``exons``."""
     self.chromosome = chromosome
     self.strand = strand
     self.breakpoint = breakpoint
     self.gene_name = gene_name
     # Build a separate tree so the stored exons are not the caller's tree object.
     self.exons = IntervalTree(exons)
コード例 #46
0
 def insert(self, chrom, start, end, val):
     """Store ``val`` on the interval ``start:end`` of ``chrom``.

     An interval tree for the chromosome is created on first use.
     """
     trees = self.chroms
     if chrom not in trees:
         trees[chrom] = IntervalTree()
     trees[chrom][start:end] = val
コード例 #47
0
class ExonCoords:
    """Exon coordinates of one gene together with its chromosome, strand and breakpoint.

    All five attributes are exposed as plain read/write properties backed by
    underscore-prefixed instance attributes.
    """

    def __init__(self, chromosome, strand, breakpoint, gene_name, exons: IntervalTree):
        """Store the attributes; ``exons`` is copied into a new IntervalTree."""
        self.chromosome = chromosome
        self.strand = strand
        self.breakpoint = breakpoint
        self.gene_name = gene_name
        self.exons = IntervalTree(exons)

    # ------------------------------------------------------------------
    # alternate constructors
    # ------------------------------------------------------------------
    @classmethod
    def fromTuple(cls, a_tuple):
        """Build an instance from the first five items of ``a_tuple``."""
        return cls(*(a_tuple[position] for position in range(5)))

    @classmethod
    def copy_without_exons(cls, exc):
        """Return a copy of ``exc`` that has an empty exon tree."""
        return cls(exc.chromosome, exc.strand, exc.breakpoint, exc.gene_name, IntervalTree())

    @classmethod
    def empty(cls):
        """Return a placeholder instance with no gene information and no exons."""
        return cls("", 0, -1, "", IntervalTree())

    # ------------------------------------------------------------------
    # output helpers
    # ------------------------------------------------------------------
    def print_properties(self):
        """Print a human-readable summary of the object to stdout."""
        banner = "#########################################"
        print(banner)
        span = self.chromosome + ":" + str(self.exons.begin()) + "-" + str(self.exons.end())
        print("coordinates :", span)
        print("gene        :", self.gene_name)
        print("strand      :", self._strand)
        print("breakpoint  :", self._breakpoint)
        print("exons       :", self._exons)
        print(banner)

    def print_as_bed(self):
        """Print one tab-separated ``chromosome, begin, end`` line per exon, sorted."""
        chromosome = self.chromosome
        for exon in sorted(self.exons):
            print("\t".join((chromosome, str(exon.begin), str(exon.end))))

    # ------------------------------------------------------------------
    # properties
    # ------------------------------------------------------------------
    @property
    def gene_name(self):
        return self._gene_name

    @gene_name.setter
    def gene_name(self, value):
        self._gene_name = value

    @property
    def chromosome(self):
        return self._chromosome

    @chromosome.setter
    def chromosome(self, value):
        self._chromosome = value

    @property
    def strand(self):
        return self._strand

    @strand.setter
    def strand(self, value):
        self._strand = value

    @property
    def breakpoint(self):  # int
        return self._breakpoint

    @breakpoint.setter
    def breakpoint(self, value):
        self._breakpoint = value

    @property
    def exons(self):  # IntervalTree()
        return self._exons

    @exons.setter
    def exons(self, exons):
        self._exons = exons

    def begin(self):
        """Return the lowest coordinate covered by the exon tree."""
        return self.exons.begin()
コード例 #48
0
def main():
    """Assign each query scaffold to the component that covers most of it.

    For every input alignment file (one per component) the aligned query
    intervals are merged and summed, giving the fraction of each query
    covered by that component. A query goes to its best-covering component,
    or to ``"unplaced"`` when the best coverage is below
    ``args.min_coverage``. Components that received fewer than
    ``args.min_scaffolds`` queries are dropped and their queries are placed
    again among the surviving components. One tab-separated line
    ``component, query, coverage`` is written to ``args.outfile`` per query.
    """
    args = cli(sys.argv[0], sys.argv[1:])
    qlens = dict()
    qintervals = defaultdict(list)

    for infile in args.infiles:
        # The component name is the input file's base name without extension.
        component = psplit(splitext(infile)[0])[-1]
        with open(infile, "r") as handle:
            for query, qlen, qstart, qend in parse_gaf(handle):
                qlens[query] = qlen
                qintervals[(component, query)].append(Interval(qstart, qend))

    coverages = defaultdict(list)

    for (component, query), intervals in qintervals.items():
        # Merge overlapping alignments so no base is counted twice.
        itree = IntervalTree(intervals)
        itree.merge_overlaps()

        count = sum((i.end - i.begin for i in itree))
        coverages[query].append((component, count / qlens[query]))

    best_components = defaultdict(list)

    def place(query, components, excluded=frozenset()):
        """Record ``query`` under its best non-excluded component, or as unplaced."""
        candidates = [(cp, cv) for cp, cv in components if cp not in excluded]
        max_cov = max((cv for _, cv in candidates), default=0)

        if not candidates or max_cov < args.min_coverage:
            best_components["unplaced"].append((query, max_cov))
            return

        # First candidate reaching the maximum wins. Only surviving
        # candidates are considered, so a query can never be handed back to
        # a component that was just dropped.
        max_comp = next(cp for cp, cv in candidates if cv == max_cov)
        best_components[max_comp].append((query, max_cov))

    for query, components in coverages.items():
        place(query, components)

    to_reassign = []
    to_drop = set()
    for component, assigned in best_components.items():
        if component == "unplaced":
            continue

        if len(assigned) < args.min_scaffolds:
            to_reassign.extend([s for s, c in assigned])
            to_drop.add(component)

    for component in to_drop:
        del best_components[component]

    for query in to_reassign:
        place(query, coverages[query], excluded=to_drop)

    for component, assigned in best_components.items():
        for query, max_cov in assigned:
            print(f"{component}\t{query}\t{max_cov}", file=args.outfile)

    return
コード例 #49
0
 def copy_without_exons(cls, exc):
     """Return a new ``cls`` instance with ``exc``'s attributes and an empty exon tree."""
     return cls(exc.chromosome, exc.strand, exc.breakpoint, exc.gene_name, IntervalTree())
コード例 #50
0
        intervals.append((start, end, chrom))
    return IntervalTree.from_tuples(intervals)


# Load the loci and the mask regions named by the Snakemake rule inputs.
with open(snakemake.input.loci_info) as instream:
    ivtree = load_loci_info(instream)

logging.info(f"Loaded {len(ivtree)} loci")

with open(snakemake.input.mask) as instream:
    mask = load_mask(instream)

logging.info(f"Loaded {len(mask)} mask regions")

# Work on a copy so the unmasked tree stays available for comparison.
masked_tree = IntervalTree(ivtree)

# Drop every locus that overlaps a mask region.
for iv in mask:
    masked_tree.remove_overlap(iv.begin, iv.end)

# Total length covered before and after masking.
full_len = sum(iv.length() for iv in ivtree)
masked_len = sum(iv.length() for iv in masked_tree)

n_loci = len(ivtree)
n_removed = n_loci - len(masked_tree)
# With no loci loaded the ratio is undefined; report 0% instead of raising
# ZeroDivisionError inside the log call.
removed_fraction = 1 - (len(masked_tree) / n_loci) if n_loci else 0.0

logging.info(
    f"{n_removed} ({removed_fraction:.2%}) loci removed"
)
コード例 #51
0
ファイル: make-xges.py プロジェクト: plugorgau/bbb-render
    def add_slides(self, with_annotations):
        """Add the presentation slides, and optionally cursor and annotations, to the timeline.

        Slide images are read from ``shapes.svg`` in the base directory and
        added as clips on a 'Slides' layer. When ``with_annotations`` is
        true, a 'Cursor' layer is filled from ``cursor.xml`` and an
        'Annotations' layer is filled with one generated SVG file per
        distinct set of visible shapes on each slide.

        Args:
            with_annotations: If false, only the slide clips are added.
        """
        layer = self._add_layer('Slides')
        doc = ET.parse(os.path.join(self.opts.basedir, 'shapes.svg'))
        slides = {}
        # Maps time ranges to the slide shown, so the cursor can be scaled
        # to the slide visible at a given moment.
        slide_time = IntervalTree()
        for img in doc.iterfind(
                './{http://www.w3.org/2000/svg}image[@class="slide"]'):
            info = SlideInfo(
                id=img.get('id'),
                width=int(img.get('width')),
                height=int(img.get('height')),
                start=round(float(img.get('in')) * Gst.SECOND),
                end=round(float(img.get('out')) * Gst.SECOND),
            )
            slides[info.id] = info
            slide_time.addi(info.start, info.end, info)

            # Don't bother creating an asset for out of range slides
            if info.end < self.start_time or info.start > self.end_time:
                continue

            path = img.get('{http://www.w3.org/1999/xlink}href')
            # If this is a "deskshare" slide, don't show anything
            if path.endswith('/deskshare.png'):
                continue

            asset = self._get_asset(os.path.join(self.opts.basedir, path))
            width, height = self._constrain(
                self._get_dimensions(asset),
                (self.slides_width, self.opts.height))
            self._add_clip(layer, asset, info.start, 0, info.end - info.start,
                           0, 0, width, height)

        # If we're not processing annotations, then we're done.
        if not with_annotations:
            return

        cursor_layer = self._add_layer('Cursor')
        # Move above the slides layer
        self.timeline.move_layer(cursor_layer, cursor_layer.get_priority() - 1)
        dot = self._get_asset('dot.png')
        dot_width, dot_height = self._get_dimensions(dot)
        cursor_doc = ET.parse(os.path.join(self.opts.basedir, 'cursor.xml'))
        events = []
        for event in cursor_doc.iterfind('./event'):
            x, y = event.find('./cursor').text.split()
            start = round(float(event.attrib['timestamp']) * Gst.SECOND)
            events.append(CursorEvent(float(x), float(y), start))

        for i, pos in enumerate(events):
            # negative positions are used to indicate that no cursor
            # should be displayed.
            if pos.x < 0 and pos.y < 0:
                continue

            # Show cursor until next event or if it is the last event,
            # the end of recording.
            if i + 1 < len(events):
                end = events[i + 1].start
            else:
                end = self.end_time

            # Find the width/height of the slide corresponding to this
            # point in time
            info = [i.data for i in slide_time.at(pos.start)][0]
            width, height = self._constrain(
                (info.width, info.height),
                (self.slides_width, self.opts.height))

            self._add_clip(cursor_layer, dot, pos.start, 0, end - pos.start,
                           round(width * pos.x - dot_width / 2),
                           round(height * pos.y - dot_height / 2), dot_width,
                           dot_height)

        layer = self._add_layer('Annotations')
        # Move above the slides layer
        self.timeline.move_layer(layer, layer.get_priority() - 1)
        for canvas in doc.iterfind(
                './{http://www.w3.org/2000/svg}g[@class="canvas"]'):
            info = slides[canvas.get('image')]
            t = IntervalTree()
            for index, shape in enumerate(
                    canvas.iterfind(
                        './{http://www.w3.org/2000/svg}g[@class="shape"]')):
                shape.set('style',
                          shape.get('style').replace('visibility:hidden;', ''))
                timestamp = round(float(shape.get('timestamp')) * Gst.SECOND)
                undo = round(float(shape.get('undo')) * Gst.SECOND)
                # A negative undo time means the shape stays until the slide ends.
                if undo < 0:
                    undo = info.end

                # Clip timestamps to slide visibility
                start = min(max(timestamp, info.start), info.end)
                end = min(max(undo, info.start), info.end)

                # Don't bother creating annotations for out of range times
                if end < self.start_time or start > self.end_time:
                    continue

                t.addi(start, end, [(index, shape)])

            # Cut at every shape boundary, then merge identical ranges so each
            # interval carries the list of all shapes visible during it.
            t.split_overlaps()
            t.merge_overlaps(strict=True, data_reducer=operator.add)
            for index, interval in enumerate(sorted(t)):
                svg = ET.Element('{http://www.w3.org/2000/svg}svg')
                svg.set('version', '1.1')
                svg.set('width', '{}px'.format(info.width))
                svg.set('height', '{}px'.format(info.height))
                svg.set('viewBox', '0 0 {} {}'.format(info.width, info.height))

                # We want to discard all but the last version of each
                # shape ID, which requires two passes.
                # The loop variables below must not reuse the name ``index``:
                # it numbers the interval and is part of the file name.
                shapes = sorted(interval.data)
                shape_index = {}
                for shape_pos, shape in shapes:
                    shape_index[shape.get('shape')] = shape_pos
                for shape_pos, shape in shapes:
                    if shape_index[shape.get('shape')] != shape_pos:
                        continue
                    svg.append(shape)

                path = os.path.join(
                    self.opts.basedir,
                    'annotations-{}-{}.svg'.format(info.id, index))
                with open(path, 'wb') as fp:
                    fp.write(ET.tostring(svg, xml_declaration=True))

                asset = self._get_asset(path)
                width, height = self._constrain(
                    (info.width, info.height),
                    (self.slides_width, self.opts.height))
                self._add_clip(layer, asset, interval.begin, 0,
                               interval.end - interval.begin, 0, 0, width,
                               height)
コード例 #52
0
def test_overlaps_empty():
    """An empty tree must report no overlap for any point, range or Interval."""
    tree = IntervalTree()

    # Point queries.
    for point in (-1, 0):
        assert not tree.overlaps(point)

    # Range queries, including null and reversed ranges.
    bounds = [(-1, 1), (-1, 0), (0, 0), (0, 1), (1, 0), (1, -1), (0, -1)]
    for begin, end in bounds:
        assert not tree.overlaps(begin, end)

    # The same ranges again, wrapped in Interval objects.
    for begin, end in bounds:
        assert not tree.overlaps(Interval(begin, end))
コード例 #53
0
ファイル: scheduler.py プロジェクト: a1d4r/event-scheduler
 def __init__(self, intervals: list[Interval]):
     """Store every interval from ``intervals`` in a fresh interval tree."""
     self._tree = IntervalTree()
     for entry in intervals:
         self._tree.add(entry)
コード例 #54
0
class SegmentProducer(object):
    """Hand out segments of a download as work and track their completion.

    Two interval trees describe the state: ``work_pool`` holds the ranges
    that still have to be scheduled and ``completed`` holds the ranges that
    were reported back on ``q_complete``.  ``completed`` is periodically
    pickled to ``download.state_path`` so a later run can resume.
    """

    # Amount of newly completed data (sum of interval lengths) to accumulate
    # between two writes of the state file.
    save_interval = SAVE_INTERVAL

    def __init__(self, download, n_procs):
        """Load any saved state and immediately schedule all pending work.

        Args:
            download: Download object; its ``size`` must already be known.
            n_procs: Number of consumers; used to size the work segments and
                to decide how many stop markers ``finish_download`` sends.
        """

        assert download.size is not None,\
            'Segment producer passed uninitizalied Download!'

        self.download = download
        self.n_procs = n_procs

        # Initialize producer
        self.load_state()
        self._setup_pbar()
        self._setup_queues()
        self._setup_work()
        self.schedule()

    def _setup_pbar(self):
        """Create the progress bar for this download."""
        self.pbar = None
        self.pbar = get_pbar(self.download.ID, self.download.size)

    def _setup_work(self):
        """Choose the segment size from the remaining work and ``n_procs``."""
        if self.is_complete():
            log.info('File already complete.')
            return

        work_size = self.integrate(self.work_pool)
        # Floor division keeps segment boundaries integral (they are used as
        # slice offsets in validate_segment_md5sums).  The lower bound of 1
        # guarantees that _get_next_interval always removes something from
        # the work pool, so schedule() cannot spin on zero-length intervals
        # when there is less work than there are processes.
        self.block_size = max(1, work_size // self.n_procs)

    def _setup_queues(self):
        """Create the work queue and the completion queue."""
        if WINDOWS:
            self.q_work = Queue()
            self.q_complete = Queue()
        else:
            manager = Manager()
            self.q_work = manager.Queue()
            self.q_complete = manager.Queue()

    def integrate(self, itree):
        """Return the summed length (``end - begin``) of all intervals."""
        return sum(i.end - i.begin for i in itree.items())

    def validate_segment_md5sums(self):
        """Drop completed segments whose on-disk bytes fail their md5sum.

        Returns:
            ``True`` when checking is disabled on the download; ``None`` in
            every other case (including when checking had to be abandoned
            because a segment carries no recorded md5sum).
        """
        if not self.download.check_segment_md5sums:
            return True
        corrupt_segments = 0
        # Iterate over a sorted copy so that removing from self.completed
        # below does not disturb the iteration.
        intervals = sorted(self.completed.items())
        pbar = ProgressBar(widgets=[
            'Checksumming {}: '.format(self.download.ID), Percentage(), ' ',
            Bar(marker='#', left='[', right=']'), ' ', ETA()])
        with mmap_open(self.download.path) as data:
            for interval in pbar(intervals):
                log.debug('Checking segment md5: {}'.format(interval))
                if not interval.data or 'md5sum' not in interval.data:
                    log.error(STRIP(
                        """User opted to check segment md5sums on restart.
                        Previous download did not record segment
                        md5sums (--no-segment-md5sums)."""))
                    return
                chunk = data[interval.begin:interval.end]
                checksum = md5sum(chunk)
                if checksum != interval.data.get('md5sum'):
                    log.debug('Redownloading corrupt segment {}, {}.'.format(
                        interval, checksum))
                    corrupt_segments += 1
                    self.completed.remove(interval)
        if corrupt_segments:
            log.warn('Redownloading {} currupt segments.'.format(
                corrupt_segments))

    def load_state(self):
        """Initialise ``work_pool``/``completed``, resuming from the state file.

        Without a usable state file *and* data file the whole download range
        stays in ``work_pool``.  Otherwise the pickled ``completed`` tree is
        loaded, optionally checksummed, and its ranges are removed from
        ``work_pool``.
        """
        # Establish default intervals
        self.work_pool = IntervalTree([Interval(0, self.download.size)])
        self.completed = IntervalTree()
        self.size_complete = 0
        if not os.path.isfile(self.download.state_path)\
           and os.path.isfile(self.download.path):
            log.warn(STRIP(
                """A file named '{} was found but no state file was found at at
                '{}'. Either this file was downloaded to a different
                location, the state file was moved, or the state file
                was deleted.  Parcel refuses to claim the file has
                been successfully downloaded and will restart the
                download.\n""").format(
                    self.download.path, self.download.state_path))
            return

        if not os.path.isfile(self.download.state_path):
            self.download.setup_file()
            return

        # If there is a file at load_path, attempt to remove
        # downloaded sections from work_pool
        log.info('Found state file {}, attempting to resume download'.format(
            self.download.state_path))

        if not os.path.isfile(self.download.path):
            log.warn(STRIP(
                """State file found at '{}' but no file for {}.
                Restarting entire download.""".format(
                    self.download.state_path, self.download.ID)))
            return
        try:
            # NOTE(review): unpickling executes arbitrary code if the state
            # file has been tampered with; only safe while the state
            # directory is trusted.
            with open(self.download.state_path, "rb") as f:
                self.completed = pickle.load(f)
            assert isinstance(self.completed, IntervalTree), \
                "Bad save state: {}".format(self.download.state_path)
        except Exception as e:
            self.completed = IntervalTree()
            log.error('Unable to resume file state: {}'.format(str(e)))
        else:
            self.validate_segment_md5sums()
            self.size_complete = self.integrate(self.completed)
            for interval in self.completed:
                self.work_pool.chop(interval.begin, interval.end)

    def save_state(self):
        """Pickle ``completed`` to the state file via a temp file and rename.

        A KeyboardInterrupt during the save is swallowed after the temp file
        has been cleaned up; any other exception is logged and re-raised.
        """
        # Bound before the try block so the KeyboardInterrupt handler can
        # tell whether a temp file exists yet.
        temp = None
        try:
            # Grab a temp file in the same directory (hopefully avoid
            # cross device links) in order to atomically write our save file
            temp = tempfile.NamedTemporaryFile(
                prefix='.parcel_',
                dir=os.path.abspath(self.download.state_directory),
                delete=False)
            # Write completed state
            pickle.dump(self.completed, temp)
            # Make sure all data is written to disk
            temp.flush()
            os.fsync(temp.fileno())
            temp.close()

            # Rename temp file as our save file, this could fail if
            # the state file and the temp directory are on different devices
            if OS_WINDOWS and os.path.exists(self.download.state_path):
                # If we're on windows, there's not much we can do here
                # except stash the old state file, rename the new one,
                # and back up if there is a problem.
                old_path = os.path.join(tempfile.gettempdir(), ''.join(
                    random.choice(string.ascii_lowercase + string.digits)
                    for _ in range(10)))
                try:
                    # stash the old state file
                    os.rename(self.download.state_path, old_path)
                    # move the new state file into place
                    os.rename(temp.name, self.download.state_path)
                    # if no exception, then delete the old stash
                    os.remove(old_path)
                except Exception as msg:
                    log.error('Unable to write state file: {}'.format(msg))
                    try:
                        os.rename(old_path, self.download.state_path)
                    except Exception as rollback_err:
                        # Best effort only: the stash may never have been
                        # created if the very first rename failed.
                        log.debug(
                            'Unable to restore previous state file: '
                            '{}'.format(rollback_err))
                    raise
            else:
                # If we're not on windows, then we'll just try to
                # atomically rename the file
                os.rename(temp.name, self.download.state_path)

        except KeyboardInterrupt:
            if temp is None:
                # Interrupted before the temp file existed: nothing to clean.
                log.warn('Keyboard interrupt before temp save file was '
                         'created')
                return
            log.warn('Keyboard interrupt. removing temp save file {}'.format(
                temp.name))
            temp.close()
            # The temp file is gone already if the rename had completed.
            if os.path.exists(temp.name):
                os.remove(temp.name)
        except Exception as e:
            log.error('Unable to save state: {}'.format(str(e)))
            raise

    def schedule(self):
        """Move every remaining segment from ``work_pool`` onto ``q_work``."""
        while True:
            interval = self._get_next_interval()
            log.debug('Returning interval: {}'.format(interval))
            if not interval:
                return
            self.q_work.put(interval)

    def _get_next_interval(self):
        """Cut the next segment (at most ``block_size`` long) off the pool.

        Returns:
            The segment as an ``Interval``, or ``None`` when the pool is
            empty.
        """
        intervals = sorted(self.work_pool.items())
        if not intervals:
            return None
        interval = intervals[0]
        start = interval.begin
        end = min(interval.end, start + self.block_size)
        self.work_pool.chop(start, end)
        return Interval(start, end)

    def print_progress(self):
        """Push ``size_complete`` to the progress bar, ignoring failures."""
        if not self.pbar:
            return
        try:
            self.pbar.update(self.size_complete)
        except Exception as e:
            log.debug('Unable to update pbar: {}'.format(str(e)))

    def check_file_exists_and_size(self):
        """Check the target exists and, for regular files, has the full size."""
        if self.download.is_regular_file:
            return (os.path.isfile(self.download.path)
                    and os.path.getsize(
                        self.download.path) == self.download.size)
        else:
            log.debug('File is not a regular file, refusing to check size.')
            return (os.path.exists(self.download.path))

    def is_complete(self):
        """Return whether ``completed`` covers the size and the file checks out."""
        return (self.integrate(self.completed) == self.download.size and
                self.check_file_exists_and_size())

    def finish_download(self):
        """Send one stop marker per process and wait until all are taken."""
        # Tell the children there is no more work, each child should
        # pull one NoneType from the queue and exit
        for _ in range(self.n_procs):
            self.q_work.put(None)

        # Wait for all the children to exit by checking to make sure
        # that everyone has taken their NoneType from the queue.
        # Otherwise, the segment producer will exit before the
        # children return, causing them to read from a closed queue
        log.debug('Waiting for children to report')
        while not self.q_work.empty():
            time.sleep(0.1)

        # Finish the progressbar
        if self.pbar:
            self.pbar.finish()

    def wait_for_completion(self):
        """Collect finished segments until the download is complete.

        State is saved each time ``save_interval`` worth of data has been
        completed and once more at the end; ``finish_download`` always runs,
        even when an exception interrupts the wait.
        """
        try:
            since_save = 0
            while not self.is_complete():
                while since_save < self.save_interval:
                    interval = self.q_complete.get()
                    self.completed.add(interval)
                    # Account for the segment before testing for
                    # completion, so the final segment is also reflected in
                    # size_complete and on the progress bar.
                    this_size = interval.end - interval.begin
                    self.size_complete += this_size
                    since_save += this_size
                    self.print_progress()
                    if self.is_complete():
                        break
                since_save = 0
                self.save_state()
        finally:
            self.finish_download()
コード例 #55
0
"ChME":
	{"Dix":"Dixon-Blood_all_HindIII_40k.hm.IC_domains_40KB.jucebox_domains.annotation",
	"Arm":"Arm-Blood-all-HindIII-40k.hm.gzipped_matrix.jucebox_domains.annotation"},#TADtree""},
"ChIE":
	{"Dix":"Dixon-ChIE_all_HindIII_40k.hm.IC_domains_40KB.jucebox_domains.annotation",
	"Arm":"Arm-ChIE-all-HindIII-40k.hm.gzipped_matrix.jucebox_domains.annotation"},#TADtree""},
}

# Map each cell type to the path of its ".hm.eig" file.
E1 = {
"ChEF":"/mnt/storage/home/vsfishman/HiC/tutorial_Fishman/chick/mapped-GalGal5filtered/GalGal5filteredChrmLevel/ChEF-all-HindIII-100k.hm.eig",
"ChME":"/mnt/storage/home/vsfishman/HiC/tutorial_Fishman/chick/mapped-GalGal5filtered/GalGal5filteredChrmLevel/Blood-all-HindIII-100k.hm.eig",
"ChIE":"/mnt/storage/home/vsfishman/HiC/tutorial_Fishman/chick/mapped-GalGal5filtered/GalGal5filteredChrmLevel/ChIE-all-HindIII-100k.hm.eig"
}
# NOTE(review): prepE1values is defined outside this excerpt; presumably it
# loads the files listed above -- confirm. E1 is rebound to its result.
E1=prepE1values(E1)

# One shared tree for all inputs: getDomains is called once per
# (cell type, algorithm) pair with the tree, the annotation file path and the
# string "<cell>_<alg>".
domainsTree = IntervalTree()
for cell in Domains:
	for alg in Domains[cell]:
		getDomains(domainsTree,base_folder+Domains[cell][alg],cell+"_"+alg)

# Report the tree size before and after the chromosome filter.
print len(domainsTree)
#check_tree(domainsTree)
#print len(domainsTree)
print "filter_chrm"
# NOTE(review): presumably restricts the tree in place to the listed
# chromosomes (the return value is not used) -- confirm against filter_chrm.
filter_chrm(domainsTree,["chr1","chr2","chr3","chr4","chr11","chr12","chr13","chr14","chr15"])
print len(domainsTree)

print "filter_only_borders_inside_TAD"
# Cell type whose "<cell>_<alg>" labels are passed to the filter below.
within = "ChEF"
# NOTE(review): unit is presumably base pairs -- confirm.
distance_from_border = 300000
# Unlike filter_chrm, this filter's return value replaces domainsTree.
domainsTree=filter_only_borders_inside_TAD(domainsTree,[within+"_"+i for i in Domains[within].keys()],distance_from_border = distance_from_border)
コード例 #56
0
ファイル: scheduler.py プロジェクト: a1d4r/event-scheduler
class Scheduler:
    """Scheduler for event looking for the most suitable time."""

    def __init__(self, intervals: list[Interval]):
        """Index ``intervals`` in an interval tree.

        The ``data`` of each interval is treated as a participant by
        ``get_most_suitable_time_intervals``.
        """
        self._tree = IntervalTree()
        for interval in intervals:
            self._tree.add(interval)

    def get_most_suitable_time_intervals(
            self,
            duration: timedelta,
            limit: Optional[int] = None) -> list[Interval]:
        """
        Get most suitable time intervals based on the number of active participants.
        Return not more than `limit` time intervals if specified.

        Candidate windows start at each interval boundary and end at the
        first later boundary that is at least ``duration`` away.  The data of
        a returned interval is the sorted list of participants present in
        every elementary slot of its window.  A ``limit`` of ``None`` or ``0``
        leaves the result untruncated.
        """
        # get interval boundaries
        boundaries = list(self._tree.boundary_table.keys())

        # check if tree is not empty
        if len(boundaries) < 2:
            return []
        result = []

        # find all intervals with length greater than duration
        # using two-pointers technique
        left, right = 0, 1
        while left < len(boundaries):
            # A window needs at least one elementary slot.  For a positive
            # duration the inner loop already keeps right ahead of left; this
            # guard only matters for a zero/negative duration, where an empty
            # window would make set.intersection() fail with no arguments.
            right = max(right, left + 1)

            # move right pointer until the interval has enough duration
            while (right < len(boundaries)
                   and boundaries[right] - boundaries[left] < duration):
                right += 1
            # Running off the end means no window starting at this (or any
            # later) boundary is long enough.
            if right == len(boundaries):
                break

            # go through all intervals and intersect data (participants)
            participants = set.intersection(
                *({interval.data
                   for interval in self._tree[start:end]} for start, end in
                  zip(boundaries[left:right], boundaries[left + 1:right + 1])))
            result.append(
                Interval(boundaries[left], boundaries[right],
                         sorted(participants)))
            left += 1

        # sort by number of active participants
        result.sort(key=lambda t: len(t.data), reverse=True)

        if limit:
            result = result[:limit]

        return result

    @classmethod
    async def from_event(cls, event: Event) -> Scheduler:
        """Create an instance of Scheduler from tortoise event model."""
        intervals = []
        await event.fetch_related('timetables__time_intervals')
        async for timetable in event.timetables:
            async for time_interval in timetable.time_intervals:
                intervals.append(
                    Interval(
                        time_interval.start,
                        time_interval.end,
                        timetable.participant_name,
                    ))
        # Instantiate through cls so subclasses get an instance of their own
        # type from this alternate constructor.
        return cls(intervals)
コード例 #57
0
def part2(nearby_tickets, my_ticket, all_trees, field_range):
    """Match range pairs to ticket columns and print the resulting product.

    ``field_range`` is read two entries at a time; each pair forms one
    condition (two inclusive ranges).  For every condition the indices of
    the valid-ticket columns whose values all satisfy it are collected, a
    distinct column is then picked greedily per condition (fewest candidates
    first), and the product of the ``my_ticket`` values at the first six
    picks is printed.  Nothing is returned.
    """
    # One tuple per ticket position, across all valid tickets.
    columns = list(zip(*get_valid_tickets(nearby_tickets, all_trees)))

    condition_match = []
    for i in range(0, len(field_range), 2):
        # Upper bounds are inclusive in the input, exclusive in the tree.
        low_range = (int(field_range[i][0]), int(field_range[i][1]) + 1)
        high_range = (int(field_range[i + 1][0]),
                      int(field_range[i + 1][1]) + 1)
        tree = IntervalTree.from_tuples([low_range, high_range])

        satisfied = [
            all([len(tree[int(value)]) > 0 for value in column])
            for column in columns
        ]
        condition_match.append(list(np.where(satisfied)[0]))

    by_length = sorted(condition_match, key=len)
    original_positions = [condition_match.index(m) for m in by_length]

    print("conditions satisfied by each column: ", condition_match)
    print("conditions satisfied by each column, sorted by length: ",
          by_length)
    print("positions in the original list with respect to the sorted list",
          original_positions)

    # Greedy pick: the first candidate that has not been taken yet.
    final_cond = []
    for candidates in by_length:
        pick = next((c for c in candidates if c not in final_cond), None)
        if pick is not None:
            final_cond.append(pick)

    print(len(field_range) / 2)
    print(len(final_cond))
    print(final_cond)

    column_conditions = list(zip(original_positions, final_cond))
    by_first = sorted(column_conditions, key=lambda pair: pair[0])
    by_second = sorted(column_conditions, key=lambda pair: pair[1])
    conditions_list = [second for _, second in by_first]

    print("columns in the condition order",
          [first for first, _ in by_second])
    print("conditions in the column order",
          [second for _, second in by_first])

    departure_cols = conditions_list[:6]
    print(departure_cols)
    numbers = [int(my_ticket[col]) for col in departure_cols]
    print(numbers)

    print(math.prod(numbers))
コード例 #58
0
def test_sequence():
    """Replay a fixed add/remove sequence, verifying the tree after each step."""
    tree = IntervalTree()

    additions = [
        (860, 917, 1),
        (860, 917, 2),
        (860, 917, 3),
        (860, 917, 4),
        (871, 917, 1),
        (871, 917, 2),
        (871, 917, 3),  # Value inserted here
        (961, 986, 1),
        (1047, 1064, 1),
        (1047, 1064, 2),
    ]
    removals = [
        (961, 986, 1),
        (871, 917, 3),  # Deleted here
    ]

    for begin, end, data in additions:
        tree.addi(begin, end, data)
        tree.verify()

    for begin, end, data in removals:
        tree.removei(begin, end, data)
        tree.verify()
コード例 #59
0
# from alignment file from PBsim to generate overlap dict
from intervaltree import Interval, IntervalTree
from Bio import AlignIO

# Path of the MAF alignment file that is parsed below.
maf_file = "D:/Data/hmmer for pacbio/Pacbio simulate/Ecoli_Pacbio_simulate_20X.maf"
# maf_file = "C:/Users/dunan/Documents/GitHub/CSE836_HW/homework2/P4_30X_0001.maf"
# Interval tree holding one [start, end) interval per alignment, with the
# record id as its data.
tree = IntervalTree()
# Maps each record id to its (start, end) tuple.
seq_dict = {}
for multiple_alignment in AlignIO.parse(maf_file, "maf"):
    multiple_alignment = list(multiple_alignment)
    # The id comes from the second record of the alignment, the coordinates
    # from the annotations of the first one.
    # NOTE(review): `id` shadows the builtin of the same name.
    id = multiple_alignment[1].id
    start = multiple_alignment[0].annotations["start"]
    end = start + multiple_alignment[0].annotations["size"]
    tree[start:end] = id
    seq_dict[id] = (start, end)

# For every sequence, query the tree with the sequence's own range and
# measure its overlap with each other sequence found there.
with open(
        "C:/Users/dunan/Documents/GitHub/CSE836_HW/homework2/Ecoli_20X_overlap.txt",
        "w") as fout:
    seq_list = list(seq_dict.keys())
    overlap_dict = {}
    for seq_id in seq_list:
        #print(seq_id)
        # NOTE(review): IntervalTree.search() is believed to have been
        # replaced by overlap()/at() in intervaltree 3.x -- confirm the
        # installed version.
        overlap_list = list(
            tree.search(seq_dict[seq_id][0], seq_dict[seq_id][1]))
        for overlap_rec in overlap_list:
            # Skip the hit that is the query sequence itself.
            if overlap_rec.data != seq_id:
                target_id = overlap_rec.data
                # Overlap length = number of integer positions shared by the
                # two ranges (both ranges are materialised as sets).
                x = range(seq_dict[seq_id][0], seq_dict[seq_id][1])
                y = range(overlap_rec.begin, overlap_rec.end)
                overlap_len = len(set(x) & set(y))
コード例 #60
0
def test_empty_queries():
    """Every query on an empty tree yields an empty result or a zero bound."""
    tree = IntervalTree()
    nothing = set()

    # Size and emptiness.
    assert len(tree) == 0
    assert tree.is_empty()

    # Point and slice lookups.
    assert tree[3] == nothing
    assert tree[4:6] == nothing

    # Bounds of an empty tree are both zero.
    assert tree.begin() == 0
    assert tree.end() == 0

    # Range queries spanning those bounds.
    assert tree[tree.begin():tree.end()] == nothing
    assert tree.overlap(tree.begin(), tree.end()) == nothing
    assert tree.envelop(tree.begin(), tree.end()) == nothing

    # Whole-tree views and copies.
    assert tree.items() == nothing
    assert set(tree) == nothing
    assert set(tree.copy()) == nothing
    assert tree.find_nested() == {}

    # The covering range is a null interval.
    assert tree.range().is_null()
    assert tree.range().length() == 0
    tree.verify()