示例#1
0
def load_gtf(gtf_path):
    """
    Load a GTF annotation and create an index using IntervalTrees.

    Comment lines (starting with "#"), blank lines and lines with fewer than
    the nine tab-separated GTF columns are skipped instead of raising.

    Args:
        gtf_path: Path to the GTF file to load.

    Returns:
        Dictionary mapping each sequence name (column 1) to an IntervalTree.
        Every interval spans [start, end) taken from columns 4 and 5 and
        carries ``[feature, start, end, strand, gene_id]`` as data; the first
        four are the raw column strings and ``gene_id`` is the value token of
        the first attribute in column 9, kept verbatim (quotes included).
        Features whose start equals their end are not indexed.
    """
    # Without a default_factory this behaves exactly like a plain dict; the
    # type is kept so callers receive the same object type as before.
    gtf_index = defaultdict()
    with open(gtf_path) as gtf_file:
        for line in gtf_file:
            if line.startswith("#"):
                continue
            # Strip the line terminator so it cannot leak into the last column.
            entry = line.rstrip("\r\n").split("\t")
            if len(entry) < 9:
                # Blank or truncated line: not a GTF record.
                continue

            # The GTF spec puts 'gene_id "..."' first in the attribute column;
            # keep the value token of that first attribute.
            gene_id = entry[8].split(";")[0].split(" ")[1]

            feature = entry[2]
            # TYPE (gene, exon, ...), START, END, STRAND, gene_ID
            info = [feature, entry[3], entry[4], entry[6], gene_id]

            # IntervalTree rejects null intervals, hence the start != end check.
            if feature != "" and entry[3] != entry[4]:
                index = gtf_index.get(entry[0])
                if index is None:
                    index = gtf_index[entry[0]] = IntervalTree()
                index.addi(int(info[1]), int(info[2]), info)

    return gtf_index
示例#2
0
def load_gtf(gtf_path):
    """
    Load a GTF annotation and create an index using IntervalTrees.

    Comment lines (starting with "#"), blank lines and lines with fewer than
    the nine tab-separated GTF columns are skipped instead of raising.

    Args:
        gtf_path: Path to the GTF file to load.

    Returns:
        Dictionary mapping each sequence name (column 1) to an IntervalTree.
        Every interval spans [start, end) taken from columns 4 and 5 and
        carries ``[feature, start, end, strand, gene_id]`` as data; the first
        four are the raw column strings and ``gene_id`` is the value token of
        the first attribute in column 9, kept verbatim (quotes included).
        Features whose start equals their end are not indexed.
    """
    # Without a default_factory this behaves exactly like a plain dict; the
    # type is kept so callers receive the same object type as before.
    gtf_index = defaultdict()
    with open(gtf_path) as gtf_file:
        for line in gtf_file:
            if line.startswith("#"):
                continue
            # Strip the line terminator so it cannot leak into the last column.
            entry = line.rstrip("\r\n").split("\t")
            if len(entry) < 9:
                # Blank or truncated line: not a GTF record.
                continue

            # The GTF spec puts 'gene_id "..."' first in the attribute column;
            # keep the value token of that first attribute.
            gene_id = entry[8].split(";")[0].split(" ")[1]

            feature = entry[2]
            # TYPE (gene, exon, ...), START, END, STRAND, gene_ID
            info = [feature, entry[3], entry[4], entry[6], gene_id]

            # IntervalTree rejects null intervals, hence the start != end check.
            if feature != "" and entry[3] != entry[4]:
                index = gtf_index.get(entry[0])
                if index is None:
                    index = gtf_index[entry[0]] = IntervalTree()
                index.addi(int(info[1]), int(info[2]), info)

    return gtf_index
 def test_dedault_data(self):
     """
     Check getMarketIndex against hand-computed market index values.

     Seven campaigns all begin on 2015-01-01 and end on different dates;
     each interval carries [raised_fx, score]. The expected index for each
     time segment is derived from the mean raised_fx of the campaigns
     counted in that segment.
     """
     begin = datetime.strptime('2015-01-01', '%Y-%m-%d').date()
     end_dates = [
         '2015-01-17', '2015-02-18', '2015-03-19', '2015-03-21',
         '2015-04-22', '2015-04-24', '2015-04-25'
     ]
     raised_fx = [1, 3, 40, 2, 5, 5, 20]
     score = [0.01, 0.004, 0.05, 0.12, 0.5, 0.003, 0.2]

     campaign_tree = IntervalTree()
     for end_str, fx, sc in zip(end_dates, raised_fx, score):
         finish = datetime.strptime(end_str, '%Y-%m-%d').date()
         campaign_tree.addi(begin, finish, [fx, sc])

     time_segments = [
         '2015-01-01', '2015-01-22', '2015-02-22', '2015-03-22',
         '2015-04-22'
     ]
     expected = [
         0,  # no campaign is available in this time segment
         1000,  # default index for the first time point
         mean([1, 3]) / mean([1]) * 1000,
         mean([1, 3, 40, 2]) / mean([1, 3]) * 1000,
         mean([1, 3, 40, 2, 5]) / mean([1, 3, 40, 2]) * 1000,
     ]
     tolerance = 1e-7
     result = functest.getMarketIndex(campaign_tree, time_segments)
     # NOTE(review): only the first four entries are compared; expected[4]
     # is computed but never checked -- confirm whether that is intended.
     for idx in range(4):
         self.assertTrue(abs(expected[idx] - result[idx]) < tolerance)
示例#4
0
def mapmut_and_filter(clusters_tree, mutations_in, cluster_mutations_cutoff):
    """
    Get the mutations within each cluster and drop clusters with fewer than
    ``cluster_mutations_cutoff`` mutations.

    The input tree is left untouched: each retained cluster dict is copied
    before the new keys are written, so the dicts held by ``clusters_tree``
    are not modified.

    Args:
        clusters_tree (IntervalTree): genomic regions are intervals, data are merged clusters (dict of dict)
        mutations_in (list): list of mutations fitting in regions; each item must expose
            ``position`` and ``sample`` attributes
        cluster_mutations_cutoff (int): minimum number of mutations a cluster needs to be kept

    Returns:
        filter_clusters_tree (IntervalTree): the same regions as ``clusters_tree``; data are the
        retained clusters (dict of dict), each with added keys ``mutations`` (list of mutations),
        ``samples`` (set of samples) and ``fra_uniq_samples`` (unique samples / mutations, or 0
        for a cluster with no mutations). A region whose clusters are all filtered out is kept
        with an empty dict.
    """
    filter_clusters_tree = IntervalTree()

    # Iterate through all regions
    for interval in clusters_tree:
        clusters = interval.data.copy()
        for cluster, values in interval.data.items():
            # Inclusive cluster bounds: element [1] of the 'left_m'/'right_m' entries.
            left = values['left_m'][1]
            right = values['right_m'][1]
            cluster_muts = [mut for mut in mutations_in if left <= mut.position <= right]
            if len(cluster_muts) < cluster_mutations_cutoff:
                del clusters[cluster]
                continue
            cluster_samples = set(mut.sample for mut in cluster_muts)
            # Copy so the cluster dict stored in the input tree is not mutated.
            updated = values.copy()
            updated['mutations'] = cluster_muts
            updated['samples'] = cluster_samples
            # An empty cluster can only survive a cutoff <= 0; avoid dividing by zero.
            updated['fra_uniq_samples'] = (
                len(cluster_samples) / len(cluster_muts) if cluster_muts else 0)
            clusters[cluster] = updated
        filter_clusters_tree.addi(interval[0], interval[1], clusters)

    return filter_clusters_tree
示例#5
0
def load_GTF(gtf_file):
    """
    Load a GTF annotation into per-sequence IntervalTrees.

    Comment lines (starting with "#"), blank lines and lines with fewer than
    the nine tab-separated GTF columns are skipped instead of raising.

    Args:
        gtf_file: Path to the GTF file to load.

    Returns:
        Dictionary mapping each sequence name (column 1) to an IntervalTree.
        Every interval spans [start, end) taken from columns 4 and 5 and
        carries ``[feature, start, end, strand, gene_id]`` as data; the first
        four are the raw column strings and ``gene_id`` is the value token of
        the first attribute in column 9, kept verbatim (quotes included).
        Features whose start equals their end are not indexed.
    """
    # Without a default_factory this behaves exactly like a plain dict; the
    # type is kept so callers receive the same object type as before.
    gtf_index = defaultdict()
    with open(gtf_file) as f:
        for line in f:
            if line.startswith("#"):
                continue
            # Strip the line terminator so it cannot leak into the last column.
            entry = line.rstrip("\r\n").split("\t")
            if len(entry) < 9:
                # Blank or truncated line: not a GTF record.
                continue

            # The GTF spec puts 'gene_id "..."' first in the attribute column;
            # keep the value token of that first attribute.
            gene_id = entry[8].split(";")[0].split(" ")[1]

            feature = entry[2]  # renamed from `type` to avoid shadowing the builtin
            # TYPE (gene, exon, ...), START, END, STRAND, gene_ID
            info = [feature, entry[3], entry[4], entry[6], gene_id]

            # IntervalTree rejects null intervals, hence the start != end check.
            if feature != "" and entry[3] != entry[4]:
                index = gtf_index.get(entry[0])
                if index is None:
                    index = gtf_index[entry[0]] = IntervalTree()
                index.addi(int(info[1]), int(info[2]), info)

    return gtf_index
示例#6
0
def get_multilines(spans):
    """
    Yield Multiline groups built from (start, stop, type) spans.

    Every span becomes a Line. Each Line then receives the level that
    get_free_level picks for the intervals overlapping it (lines are
    processed in input order). The intervals are split at every overlap
    boundary, the resulting pieces are grouped by their (start, stop)
    bounds, and one Multiline per group is yielded in ascending bound
    order, with its lines ordered by level.
    """
    tree = Intervals()
    all_lines = []
    for span_start, span_stop, span_type in spans:
        new_line = Line(span_start, span_stop, span_type, level=None)
        tree.addi(span_start, span_stop, new_line)
        all_lines.append(new_line)

    # Assign levels, one line at a time, from the intervals each overlaps.
    for current in all_lines:
        current.level = get_free_level(tree.search(current.start, current.stop))

    # Cut overlapping intervals into non-overlapping chunks.
    tree.split_overlaps()

    # Collect the lines that share identical chunk bounds.
    by_bounds = defaultdict(list)
    for chunk_start, chunk_stop, chunk_line in tree:
        by_bounds[chunk_start, chunk_stop].append(chunk_line)

    for bounds in sorted(by_bounds):
        members = sorted(by_bounds[bounds], key=lambda item: item.level)
        yield Multiline(bounds[0], bounds[1], members)
 def index_gene_annotation_interval_tree(self):
     """
     Build one IntervalTree per chromosome from self.m_gene_annotation.

     Within a chromosome, every start-position key becomes the begin of an
     interval whose end is element [0][0] of that key's record. The tree
     (intervals carry no data) is stored in self.m_interval_tree[chrm],
     replacing any tree already stored for that chromosome.
     """
     annotation = self.m_gene_annotation
     for chrm in annotation:
         by_start = annotation[chrm]
         tree = IntervalTree()
         for start_pos in by_start:
             tree.addi(start_pos, by_start[start_pos][0][0])
         self.m_interval_tree[chrm] = tree
class MemoryMappedIo:
    """
    Holds statistics of all used address spaces.

    Address spaces live in an IntervalTree keyed on
    [space.Address, space.Address + space.Size); each interval's data is the
    AddressSpaceStatistic collected for that space.
    """

    def __init__(self):
        self._address_spaces = IntervalTree()
        self._number_of_accesses = 0

    def add_mapped_space(self, location, space, timestamp, trace):
        """
        Register ``space`` with a fresh AddressSpaceStatistic.

        Falsy ``space`` values are ignored. ``location`` is accepted for
        interface compatibility but not used here.
        """
        if space:
            self._address_spaces.addi(space.Address,
                                      space.Address + space.Size,
                                      AddressSpaceStatistic(space,
                                                            trace,
                                                            timestamp))

    def add_access(self, event, location):
        """
        Count one access and credit it to the address space containing it.

        The access counter is always incremented; if ``event.value`` falls
        inside no registered space, no statistic is updated.
        """
        self._number_of_accesses += 1
        intervals = self._address_spaces[int(event.value)]
        # Registered address spaces are expected never to overlap.
        assert len(intervals) < 2
        if len(intervals) == 1:
            space_stats = intervals.pop().data
            space_stats.inc_metric(location, event.metric.member.name, event.time)

    def __str__(self):
        # One join instead of repeated += concatenation (quadratic on
        # interpreters without CPython's in-place string optimization).
        return "".join("{}\n\n".format(interval.data)
                       for interval in self._address_spaces)
示例#9
0
 def index_rmsk_annotation_interval_tree(self):
     """
     Build one IntervalTree per chromosome from self.m_rmsk_annotation.

     Each position key of a chromosome becomes the begin of an interval
     whose end is element [0][0] of that position's record. The resulting
     tree (no interval data) is stored in self.m_interval_tree under the
     chromosome key.
     """
     annotation = self.m_rmsk_annotation
     for chrm in annotation:
         per_chrom = annotation[chrm]
         tree = IntervalTree()
         for pos in per_chrom:
             tree.addi(pos, per_chrom[pos][0][0])
         self.m_interval_tree[chrm] = tree
示例#10
0
  def get_merged_variants(self, variants, key=None):
    # type: (List[vcfio.Variant], str) -> Iterable[vcfio.Variant]
    """Merges variants and non-variants, yielding the merged records.

    Every input is first aligned to the window. Non-variant records are
    indexed in an IntervalTree and merged with `_merge_non_variants`; the
    others are grouped by their first merge key and merged with
    `_merge_variants`. Each merged variant overlapping a merged non-variant
    takes that non-variant's calls (from the first overlap returned) and is
    recorded in the splits tree. Merged variants are yielded first, followed
    by the non-variants produced by `_split_non_variants`.
    """
    nv_tree = IntervalTree()
    groups = collections.defaultdict(list)
    for variant in variants:
      self._align_with_window(variant, key)
      if not self._is_non_variant(variant):
        merge_key = next(self._move_to_calls.get_merge_keys(variant))
        groups[merge_key].append(variant)
        continue
      nv_tree.addi(variant.start, variant.end, variant)

    merged_non_variants = self._merge_non_variants(nv_tree)
    merged_variants = self._merge_variants(groups)

    # Re-index the merged non-variants in the (now emptied) tree.
    nv_tree.clear()
    for merged_nv in merged_non_variants:
      nv_tree.addi(merged_nv.start, merged_nv.end, merged_nv)

    splits = IntervalTree()
    for variant in merged_variants:
      overlapping = nv_tree.search(variant.start, variant.end)
      if overlapping:
        donor = next(iter(overlapping)).data
        variant.calls.extend(donor.calls)
        variant.calls = sorted(variant.calls)
        self._update_splits(splits, variant)
      yield variant

    for split_nv in self._split_non_variants(nv_tree, splits):
      yield split_nv
示例#11
0
def load_GTF(gtf_file):
    """
    Load a GTF annotation into per-sequence IntervalTrees.

    Comment lines (starting with "#"), blank lines and lines with fewer than
    the nine tab-separated GTF columns are skipped instead of raising.

    Args:
        gtf_file: Path to the GTF file to load.

    Returns:
        Dictionary mapping each sequence name (column 1) to an IntervalTree.
        Every interval spans [start, end) taken from columns 4 and 5 and
        carries ``[feature, start, end, strand, gene_id]`` as data; the first
        four are the raw column strings and ``gene_id`` is the value token of
        the first attribute in column 9, kept verbatim (quotes included).
        Features whose start equals their end are not indexed.
    """
    # Without a default_factory this behaves exactly like a plain dict; the
    # type is kept so callers receive the same object type as before.
    gtf_index = defaultdict()
    with open(gtf_file) as f:
        for line in f:
            if line.startswith("#"):
                continue
            # Strip the line terminator so it cannot leak into the last column.
            entry = line.rstrip("\r\n").split("\t")
            if len(entry) < 9:
                # Blank or truncated line: not a GTF record.
                continue

            # The GTF spec puts 'gene_id "..."' first in the attribute column;
            # keep the value token of that first attribute.
            gene_id = entry[8].split(";")[0].split(" ")[1]

            feature = entry[2]  # renamed from `type` to avoid shadowing the builtin
            # TYPE (gene, exon, ...), START, END, STRAND, gene_ID
            info = [feature, entry[3], entry[4], entry[6], gene_id]

            # IntervalTree rejects null intervals, hence the start != end check.
            if feature != "" and entry[3] != entry[4]:
                index = gtf_index.get(entry[0])
                if index is None:
                    index = gtf_index[entry[0]] = IntervalTree()
                index.addi(int(info[1]), int(info[2]), info)

    return gtf_index
示例#12
0
# CIGAR operation codes (BAM / pysam numbering).
_CIGAR_ALIGNED_OPS = frozenset((0, 7, 8))  # M, =, X: read bases aligned to the reference
_CIGAR_REF_ONLY_OPS = frozenset((2, 3, 6))  # D, N, P: consume no read bases


def _union_length(spans):
    """Return the number of positions covered by the half-open (begin, end) spans."""
    covered = 0
    run_begin = run_end = None
    for begin, end in sorted(spans):
        if run_end is not None and begin <= run_end:
            # Overlaps or touches the current run: extend it.
            run_end = max(run_end, end)
            continue
        if run_end is not None:
            covered += run_end - run_begin
        run_begin, run_end = begin, end
    if run_end is not None:
        covered += run_end - run_begin
    return covered


def aln_coverage(aln_list):
    """
    Calculate the coverage across the reported alignments for a given read. This will most
    often involve only a single alignment, but also considers non-overlapping alignments
    reported by BWA MEM scavenged from the XP tag. Reports the number of bases covered
    (<=read_len) and the overlap between them (normally 0).

    Positions are measured along the read: each alignment walks its CIGAR from the start
    of the read (forward) or from ``total`` backwards (reverse). Only operations that
    correspond to read bases advance the position; deletions, reference skips and padding
    (D/N/P) do not, so they can no longer push later matches past the end of the read.
    Hard clips (H) still advance it, assuming ``total`` is the full read length including
    clipped bases -- TODO confirm. M, = and X operations count as aligned bases.

    :param aln_list: the list of alignments for a read; each a dict with keys
        'is_reverse', 'total' (used for reverse alignments) and 'cigartuple'
        (list of (op, length) pairs)
    :return: dict {coverage: xx, overlap: yy, has_multi: bool}
    """
    spans = []
    tot = 0
    for ti in aln_list:
        reverse = ti['is_reverse']
        # reversed reads must be tallied from the opposite end
        n = ti['total'] if reverse else 0
        for op, nb in ti['cigartuple']:
            if op in _CIGAR_REF_ONLY_OPS:
                continue
            if reverse:
                begin, end = n - nb, n
                n = begin
            else:
                begin, end = n, n + nb
                n = end
            if op in _CIGAR_ALIGNED_OPS:
                spans.append((begin, end))
                tot += nb
    cov = _union_length(spans)
    return {'coverage': cov, 'overlap': tot - cov, 'has_multi': len(aln_list) > 1}
示例#13
0
class Sequencer:
    """Plays notes stored in an IntervalTree keyed on [start, start + length)."""

    sortkey = lambda n: n.start + n.length

    def __init__(self):
        self.notes = IntervalTree()

    def add(self, note):
        """Schedule ``note`` over [note.start, note.start + note.length)."""
        self.notes.addi(note.start, note.start + note.length, note)

    def remove(self, note):
        """Remove a note previously scheduled with the same start and length."""
        self.notes.removei(note.start, note.start + note.length, note)

    def length(self):
        """Return the end point of the latest-ending note (IntervalTree.end())."""
        return self.notes.end()

    def sample_at(self, t):
        """
        Return the mixed sample value at time ``t``.

        Each note sounding at ``t`` contributes
        sine(position, pitch) * velocity * adsr(position, note length),
        and the sum is divided evenly by the number of sounding notes.
        Returns 0 when no note is sounding.
        """
        current = self.notes.at(t)
        if not current:
            return 0

        # Computed once; 1.0 keeps it a float (1 / n is 0 on Python 2 for n > 1).
        scale = 1.0 / len(current)
        acc = 0
        for note in current:
            note_pos = t - note.begin
            acc += (osc.sine(note_pos, note.data.pitch) * note.data.velocity *
                    adsr(note_pos, note.end - note.begin)) * scale

        return acc
示例#14
0
def cids_to_blocks(cid_tree):
    """
    Using an IntervalTree as returned by generate_random_cids(), create a new IntervalTree where the
    intervals represent regions (blocks) of homogeneous effect. That is, each resulting interval defines a
    region where a fixed set of CIDs is involved.

    Blocks, therefore, do not overlap but are instead perfectly adjacent (zero spacing). For a given block
    the independent CID probabilities are normalized to sum to 1, in preparation for selection by random draw.

    :param cid_tree: an IntervalTree representing CIDs on the chromosome; each interval's data must
        provide a 'prob' entry.
    :return: an IntervalTree of the homogeneous blocks for this set of CIDs. Each block's data is
        {'prob_list': normalized probabilities, 'inv_list': sorted overlapping CID intervals}.
    """
    # Every point where a CID begins or ends marks the end of one block and
    # the beginning of the next; np.unique also sorts them.
    bounds = np.unique([pt for inv in cid_tree for pt in (inv.begin, inv.end)])

    block_tree = IntervalTree()
    # Consecutive boundary pairs delimit the blocks (range-free, so it runs on Python 2 and 3).
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        ovl_invs = sorted(cid_tree[lo:hi])  # the CIDs involved in this range

        # normalize probs for the block.
        p = np.fromiter((inv.data['prob'] for inv in ovl_invs), dtype=float)
        p /= p.sum()

        # a block stores the normalized probabilities and originating CID intervals for quick lookup.
        block_tree.addi(lo, hi, {'prob_list': p, 'inv_list': ovl_invs})

    return block_tree
示例#15
0
def test_insert():
    """Intervals can be inserted via slice assignment, add, addi and update."""
    tree = IntervalTree()

    # Slice assignment inserts an interval that carries data.
    tree[0:1] = "data"
    assert len(tree) == 1
    assert tree.items() == {Interval(0, 1, "data")}

    tree.add(Interval(10, 20))
    assert len(tree) == 2
    assert tree.items() == {Interval(0, 1, "data"), Interval(10, 20)}

    # Float bounds are accepted as well.
    tree.addi(19.9, 20)
    assert len(tree) == 3
    assert tree.items() == {
        Interval(0, 1, "data"),
        Interval(19.9, 20),
        Interval(10, 20),
    }

    tree.update([Interval(19.9, 20.1), Interval(20.1, 30)])
    assert len(tree) == 5
    assert tree.items() == {
        Interval(0, 1, "data"),
        Interval(19.9, 20),
        Interval(10, 20),
        Interval(19.9, 20.1),
        Interval(20.1, 30),
    }
示例#16
0
def test_insert():
    """Each insertion API grows the tree and the item set accordingly."""
    tree = IntervalTree()
    expected = set()

    tree[0:1] = "data"
    expected.add(Interval(0, 1, "data"))
    assert len(tree) == 1
    assert tree.items() == expected

    tree.add(Interval(10, 20))
    expected.add(Interval(10, 20))
    assert len(tree) == 2
    assert tree.items() == expected

    tree.addi(19.9, 20)
    expected.add(Interval(19.9, 20))
    assert len(tree) == 3
    assert tree.items() == expected

    tree.update([Interval(19.9, 20.1), Interval(20.1, 30)])
    expected.update([Interval(19.9, 20.1), Interval(20.1, 30)])
    assert len(tree) == 5
    assert tree.items() == expected
示例#17
0
def test_add_invalid_interval():
    """
    Ensure that begin < end: every way of inserting an inverted or null
    interval raises ValueError.
    """
    itree = IntervalTree()
    bad_inserts = [
        lambda: itree.addi(1, 0),
        lambda: itree.addi(1, 1),
        lambda: itree.__setitem__(slice(1, 0), "value"),
        lambda: itree.__setitem__(slice(1, 1), "value"),
        lambda: itree.__setitem__(slice(1.1, 1.05), "value"),
        lambda: itree.__setitem__(slice(1.1, 1.1), "value"),
        lambda: itree.extend([Interval(1, 0)]),
        lambda: itree.extend([Interval(1, 1)]),
    ]
    for insert in bad_inserts:
        with pytest.raises(ValueError):
            insert()
示例#18
0
 def get_length(self):
     """
     Return the number of positions covered by the union of all exons of
     all transcripts.

     Exon bounds ``(e[0], e[1])`` are treated as inclusive (assumed from the
     ``+ 1`` length arithmetic; TODO confirm with the exon loader). They
     are stored half-open as ``[e[0], e[1] + 1)``. Exons sharing a boundary
     position therefore merge instead of counting it twice, and single-base
     exons (e[0] == e[1]) are not rejected as null intervals.
     """
     gene_tree = IntervalTree()
     for t in self.transcript.values():
         for e in t.exon:
             gene_tree.addi(e[0], e[1] + 1)
     gene_tree.merge_overlaps()
     return sum(x.end - x.begin for x in gene_tree)
class RepeatDb(object):
    """
    Repeat annotations (UCSC ``rmsk`` table) for one window of one contig.
    """

    def __init__(self, assembly, contig, start, end):
        """
        Given a range on a contig, get all the repeats lying strictly inside
        that range (genoStart > start and genoEnd < end).

        Keeps an IntervalTree of element names, and a Counter from element
        name to number of that element in the range.

        Rows are fetched with the ``hgsql`` command-line client. Because the
        SQL text is built by string formatting, the inputs are validated
        first: ``assembly`` must consist of word characters only, ``contig``
        must not contain quote or backslash characters, and ``start``/``end``
        must be convertible with ``int()``. Otherwise ValueError is raised.
        """
        import re

        # hgsql takes raw SQL text, so reject anything that could break out
        # of the database identifier or the quoted literals (SQL injection).
        if not re.match(r"\w+\Z", str(assembly)):
            raise ValueError("invalid assembly name: %r" % (assembly,))
        if re.search(r"['\"\\]", str(contig)):
            raise ValueError("invalid contig name: %r" % (contig,))
        start = int(start)
        end = int(end)

        # Make the interval tree
        self.tree = IntervalTree()

        # Make a counter for repeats with a certain name
        self.counts = collections.Counter()

        command = [
            "hgsql", "-e", "select repName, genoName, genoStart, genoEnd "
            "from {}.rmsk where genoName = '{}' and genoStart > '{}' "
            "and genoEnd < '{}';".format(assembly, contig, start, end)
        ]
        process = subprocess.Popen(command, stdout=subprocess.PIPE)

        try:
            # Skip the header line; every other row is one repeat.
            for parts in itertools.islice(tsv.TsvReader(process.stdout), 1, None):
                # Add the item to the tree covering its range. Store the repeat
                # type name as the interval's data.
                self.tree.addi(int(parts[2]), int(parts[3]), parts[0])

                # Count it
                self.counts[parts[0]] += 1
        finally:
            # Close our end of the pipe and reap the child process.
            process.stdout.close()
            process.wait()

    def get_copies(self, contig, pos):
        """
        Given a contig name and a position, estimate the copy number of that
        position in the genome.

        Return the number of instances expected (1 for non-repetitive sequence):
        the largest in-window count among the repeat names covering ``pos``.
        ``contig`` is currently ignored; the tree only holds the contig given
        to the constructor.
        """

        # TODO: use contig

        # Get the set of overlapping things
        overlaps = self.tree[pos]

        # Keep track of the number of copies of the most numerous repeat
        # observed.
        max_copies = 1

        for interval in overlaps:
            # For each repeat we are in, max in how many copies of it exist
            max_copies = max(max_copies, self.counts[interval.data])

        return max_copies
示例#20
0
def find_candidate(Interval_list,
                   window=10,
                   min_primary=0,
                   min_support=0,
                   secondary_thres=0.0,
                   primary_thres=1.0):
    '''
    Find candidate exon boundaries (i.e. intron boundary pairs) by grouping
    nearby boundary pairs around the best-supported ones.

    Pairs are processed from the best supported downwards. For the current
    best pair, every remaining pair whose begin and end both lie within
    ``window`` of it is taken out of consideration. Those whose support
    exceeds ``secondary_thres * best_support`` are kept as its neighbours,
    and the group is reported only if it has at least one such neighbour.

    Parameter:
        Interval_list:
            iterable of intervals (begin, end, data): begin/end are the
            boundary pair and data is its numeric support
        window:
            maximum absolute difference (inclusive) between the begins and
            between the ends of two pairs for them to be grouped together
        min_primary:
            stop as soon as the best remaining pair has support below this
        min_support:
            a group is reported only if its total support reaches this value
        secondary_thres:
            a neighbour is kept only if its support exceeds
            secondary_thres * support of the group's best pair
        primary_thres:
            a group is reported only if
            best support / total group support <= primary_thres

    Returns:
        list of (best_interval, member_interval) tuples; each reported group
        contributes one tuple per kept neighbour followed by
        (best_interval, best_interval).
    '''
    intervals_tree = IntervalTree()
    for interval in Interval_list:
        intervals_tree.addi(interval.begin, interval.end, interval.data)

    candidate_boundaries = []
    while intervals_tree:
        interval = max(intervals_tree, key=lambda x: x.data)
        best_support = interval.data
        if best_support < min_primary:
            # Everything left is supported even less.
            return candidate_boundaries

        intervals_tree.remove(interval)

        # include surrounding boundaries (envelop returns a new set, so
        # removing from the tree while iterating it is safe)
        enveloped_interval = intervals_tree.envelop(interval.begin - window,
                                                    interval.end + window)
        neighbour_found = []
        for i in enveloped_interval:
            if i.begin <= interval.begin + window and \
                    i.end >= interval.end - window:
                if i.data > secondary_thres * best_support:
                    neighbour_found.append((interval, i))
                # Grouped pairs are consumed whether or not they are kept.
                intervals_tree.remove(i)
        if neighbour_found:
            neighbour_found.append((interval, interval))
            count = sum(x.data for _, x in neighbour_found)
            # float() keeps the ratio fractional for integer supports on Python 2.
            if count >= min_support and float(best_support) / count <= primary_thres:
                candidate_boundaries += neighbour_found
    return candidate_boundaries
class RepeatDb(object):
    """
    Repeat annotations (UCSC ``rmsk`` table) for one window of one contig.
    """

    def __init__(self, assembly, contig, start, end):
        """
        Given a range on a contig, get all the repeats lying strictly inside
        that range (genoStart > start and genoEnd < end).

        Keeps an IntervalTree of element names, and a Counter from element
        name to number of that element in the range.

        Rows are fetched with the ``hgsql`` command-line client. Because the
        SQL text is built by string formatting, the inputs are validated
        first: ``assembly`` must consist of word characters only, ``contig``
        must not contain quote or backslash characters, and ``start``/``end``
        must be convertible with ``int()``. Otherwise ValueError is raised.
        """
        import re

        # hgsql takes raw SQL text, so reject anything that could break out
        # of the database identifier or the quoted literals (SQL injection).
        if not re.match(r"\w+\Z", str(assembly)):
            raise ValueError("invalid assembly name: %r" % (assembly,))
        if re.search(r"['\"\\]", str(contig)):
            raise ValueError("invalid contig name: %r" % (contig,))
        start = int(start)
        end = int(end)

        # Make the interval tree
        self.tree = IntervalTree()

        # Make a counter for repeats with a certain name
        self.counts = collections.Counter()

        command = ["hgsql", "-e", "select repName, genoName, genoStart, genoEnd "
            "from {}.rmsk where genoName = '{}' and genoStart > '{}' "
            "and genoEnd < '{}';".format(assembly, contig, start, end)]
        process = subprocess.Popen(command, stdout=subprocess.PIPE)

        try:
            # Skip the header line; every other row is one repeat.
            for parts in itertools.islice(tsv.TsvReader(process.stdout), 1, None):
                # Add the item to the tree covering its range. Store the repeat
                # type name as the interval's data.
                self.tree.addi(int(parts[2]), int(parts[3]), parts[0])

                # Count it
                self.counts[parts[0]] += 1
        finally:
            # Close our end of the pipe and reap the child process.
            process.stdout.close()
            process.wait()

    def get_copies(self, contig, pos):
        """
        Given a contig name and a position, estimate the copy number of that
        position in the genome.

        Return the number of instances expected (1 for non-repetitive sequence):
        the largest in-window count among the repeat names covering ``pos``.
        ``contig`` is currently ignored; the tree only holds the contig given
        to the constructor.
        """

        # TODO: use contig

        # Get the set of overlapping things
        overlaps = self.tree[pos]

        # Keep track of the number of copies of the most numerous repeat
        # observed.
        max_copies = 1

        for interval in overlaps:
            # For each repeat we are in, max in how many copies of it exist
            max_copies = max(max_copies, self.counts[interval.data])

        return max_copies
示例#22
0
def getDataTree(df):
    """
    Build an IntervalTree with one interval per row of ``df``.

    ``df.itertuples()`` yields the index at position 0, so positions 1, 2,
    3 and 5 are the first, second, third and fifth columns: start, end,
    raised money and score. Each row becomes the interval [start, end)
    with data ``[raised_money, score]``.
    """
    tree = IntervalTree()
    for record in df.itertuples():
        begin, finish, money, rating = record[1], record[2], record[3], record[5]
        tree.addi(begin, finish, [money, rating])
    return tree
示例#23
0
    def countIdealOverlaps(self, nodes):
        """
        Record, for every node, which nodes its ideal span overlaps.

        Each node's ideal span is [node.idealLeft(), node.idealRight()).
        Afterwards ``node.overlaps`` lists the nodes whose spans overlap it
        (the node itself included) and ``node.overlapCount`` is their number.
        ``nodes`` is iterated twice, so it must be a re-iterable collection.
        """
        iTree = IntervalTree()
        for node in nodes:
            iTree.addi(node.idealLeft(), node.idealRight(), data=node)

        for node in nodes:
            # Slice lookup equals search()/overlap() and works on both
            # intervaltree 2.x and 3.x (search() was removed in 3.0).
            overlaps = iTree[node.idealLeft():node.idealRight()]
            node.overlaps = [x.data for x in overlaps]
            node.overlapCount = len(overlaps)
示例#24
0
def test_adding_speed():
    """Insert one million seeded random non-null intervals within [0, 10000]."""
    tree = IntervalTree()
    low, high = 0, 10000
    random.seed(10)
    for _ in range(1000000):
        begin = random.randint(low, high - 1)
        tree.addi(begin, random.randint(begin + 1, high))
示例#25
0
class FlashReaderContext(DebugContext):
    """! @brief Reads flash memory regions from an ELF file instead of the target.

    A read is served from ELF section data only when the whole requested
    range lies inside exactly one flash section; every other read is
    forwarded to the parent context.
    """

    def __init__(self, parent, elf):
        super(FlashReaderContext, self).__init__(parent)
        self._elf = elf

        self._build_regions()

    def _build_regions(self):
        """Index every non-empty ELF section that belongs to a flash region."""
        self._tree = IntervalTree()
        for sect in [s for s in self._elf.sections if (s.region and s.region.is_flash)]:
            start = sect.start
            length = sect.length
            if length == 0:
                # An empty section holds no data and IntervalTree rejects null intervals.
                continue
            sect.data # Go ahead and read the data from the file.
            self._tree.addi(start, start + length, sect)
            LOG.debug("created flash section [%x:%x] for section %s", start, start + length, sect.name)

    def _find_section(self, addr, length):
        """Return the section fully containing [addr, addr + length), or None."""
        matches = self._tree.overlap(addr, addr + length)
        # Must match only one interval (ELF section) ...
        if len(matches) != 1:
            return None
        interval = matches.pop()
        # ... and that section must contain the whole range. A read straddling
        # a section edge would otherwise return short data, or wrong data when
        # the offset into the section is negative.
        if addr < interval.begin or addr + length > interval.end:
            return None
        return interval.data

    def read_memory(self, addr, transfer_size=32, now=True):
        length = transfer_size // 8
        section = self._find_section(addr, length)
        if section is None:
            return self._parent.read_memory(addr, transfer_size, now)
        addr -= section.start

        def read_memory_cb():
            LOG.debug("read flash data [%x:%x] from section %s", section.start + addr, section.start + addr  + length, section.name)
            data = section.data[addr:addr + length]
            if transfer_size == 8:
                return data[0]
            elif transfer_size == 16:
                return conversion.byte_list_to_u16le_list(data)[0]
            elif transfer_size == 32:
                return conversion.byte_list_to_u32le_list(data)[0]
            else:
                raise ValueError("invalid transfer_size (%d)" % transfer_size)

        if now:
            return read_memory_cb()
        else:
            return read_memory_cb

    def read_memory_block8(self, addr, size):
        section = self._find_section(addr, size)
        if section is None:
            return self._parent.read_memory_block8(addr, size)
        addr -= section.start
        data = section.data[addr:addr + size]
        LOG.debug("read flash data [%x:%x]", section.start + addr, section.start + addr  + size)
        return list(data)

    def read_memory_block32(self, addr, size):
        return conversion.byte_list_to_u32le_list(self.read_memory_block8(addr, size))
class BorderModel(QObject):
    """Tracks bordered ranges of indices and the border theme of each."""
    rangeChanged = pyqtSignal([BorderedRange])

    def __init__(self, parent, color_theme=SolarizedColorTheme):
        super(BorderModel, self).__init__(parent)

        # data structure description:
        # _db is an interval tree that indexes on the start and end of bordered ranges;
        # the values are BorderedRange instances. A range [begin, end] (inclusive)
        # is stored as the half-open interval [begin, end + 1).
        # given an index, determining its border is:
        #   intervaltree lookup index in _db (which is O(log <num ranges>) )
        #   iterate containing ranges (worst case, O(<num ranges>), but typically small)
        #     hash lookup on index to fetch border state (which is O(1))
        self._db = IntervalTree()
        self._theme = color_theme

    def border_region(self, begin, end, color=None):
        """Border the inclusive range [begin, end] and emit rangeChanged for it."""
        if color is None:
            color = self._theme.get_accent(len(self._db))
        bordered = BorderedRange(begin, end, BorderTheme(color), compute_region_border(begin, end))
        # note we use (end + 1) to ensure the entire selection gets captured
        self._db.addi(bordered.begin, bordered.end + 1, bordered)
        self.rangeChanged.emit(bordered)

    def _find_exact(self, begin, end):
        """Return the stored intervals for exactly the inclusive region [begin, end]."""
        # Query [begin, end + 1) so single-cell regions (begin == end) are found too.
        return [r for r in self._db[begin:end + 1]
                if r.begin == begin and r.end == end + 1]

    def clear_region(self, begin, end):
        """Remove the border of exactly [begin, end] and emit rangeChanged for it.

        Other ranges of the same length that merely overlap are left alone.
        """
        for r in self._find_exact(begin, end):
            self._db.removei(r.begin, r.end, r.data)
            self.rangeChanged.emit(r.data)

    def get_border(self, index):
        # ranges is a (potentially empty) list of intervaltree.Interval instances
        # we sort them here from shortest length to longest, because we want
        #    the most specific border
        ranges = sorted(self._db[index], key=lambda r: r.end - r.begin)
        if len(ranges) > 0:
            rng = ranges[0].data
            cell = rng.cells.get(index, None)
            if cell is None:
                return None
            ret = BorderData(cell.top, cell.bottom, cell.left, cell.right, rng.theme)
            return ret
        return None

    def is_index_bordered(self, index):
        return len(self._db[index]) > 0

    def is_region_bordered(self, begin, end):
        """Return True if exactly the inclusive region [begin, end] is bordered."""
        return bool(self._find_exact(begin, end))
示例#27
0
    def countIdealOverlaps(self, nodes):
        """
        Annotate each node with the nodes whose ideal spans overlap its own.

        Spans are [idealLeft(), idealRight()). Sets ``overlaps`` (the
        overlapping nodes, the node itself included) and ``overlapCount``
        on every node.
        """
        tree = IntervalTree()
        for item in nodes:
            tree.addi(item.idealLeft(), item.idealRight(), data=item)

        for item in nodes:
            hits = tree.overlap(item.idealLeft(), item.idealRight())
            item.overlaps = [hit.data for hit in hits]
            item.overlapCount = len(hits)
示例#28
0
def plot2C2AScatterTimeSeries(zmwFixture, frameInterval=4096):
    """
    Plot a 2C2A scatter plot for every `frameInterval` frames; overlay
    information about alignment(s), if found in the dataset (HQRegion
    overlay is not implemented yet).

    One pane (C1 vs C2 camera trace values) is drawn per block of
    `frameInterval` frames, six panes per row, all sharing one square axis
    range. Each pane is labelled "/<start>_<end>" with the base span of its
    frames, and panes overlapping an alignment region get a red bar on top.

    Returns the flattened array of matplotlib axes.
    """
    t = zmwFixture.cameraTrace
    df = pd.DataFrame(np.transpose(t), columns=["C1", "C2"])

    # what is the extent of the data?  force a square perspective so
    # we don't distort the spectral angle
    xmin = ymin = min(df.min())
    xmax = ymax = max(df.max())

    def fracX(frac): return xmin + (xmax - xmin) * frac
    def fracY(frac): return ymin + (ymax - ymin) * frac

    numPanes = int(math.ceil(float(zmwFixture.numFrames) / frameInterval))
    numCols = 6
    numRows = int(math.ceil(float(numPanes) / numCols))
    paneSize = np.array([3, 3])

    figsize = np.array([numCols, numRows]) * paneSize
    fig, ax = plt.subplots(numRows, numCols, sharex=True, sharey=True,
                           figsize=figsize)
    axr = ax.ravel()

    details = "" # TODO
    fig.suptitle("%s\n%s" % (zmwFixture.zmwName, details), fontsize=20)

    alnIntervals = IntervalTree()
    for r in zmwFixture.regions:
        if r.regionType == Region.ALIGNMENT_REGION:
            alnIntervals.addi(r.startFrame, r.endFrame)

    def overlapsAln(frameStart, frameEnd):
        # Slice lookup works on intervaltree 2.x and 3.x alike
        # (search() was removed in 3.0).
        return bool(alnIntervals[frameStart:frameEnd])

    # range() instead of xrange() so this also runs on Python 3.
    for i in range(numPanes):
        frameSpan = startFrame, endFrame = i*frameInterval, (i+1)*frameInterval
        axr[i].set_xlim(xmin, xmax)
        axr[i].set_ylim(ymin, ymax)
        axr[i].plot(df.C1[startFrame:endFrame], df.C2[startFrame:endFrame], ".")

        baseSpan = zmwFixture.baseIntervalFromFrames(*frameSpan)
        axr[i].text(fracX(0.6), fracY(0.9), "/%d_%d" %  baseSpan)

        if overlapsAln(*frameSpan):
            axr[i].hlines(fracY(1.0), xmin, xmax, colors=["red"], linewidth=4)

    return axr
示例#29
0
def sorted_complement(tree, start=None, end=None) -> list:
    """
    Return the gaps of ``tree`` inside [start, end) as a sorted list of Intervals.

    ``start`` and ``end`` default to the tree's own begin()/end(). When the
    range is empty (start >= end, which includes an empty tree with default
    bounds) an empty list is returned instead of letting ``addi`` raise
    ValueError on a null interval.
    """
    if start is None:
        start = tree.begin()
    if end is None:
        end = tree.end()
    if start >= end:
        return []

    result = IntervalTree()
    result.addi(start, end)  # the full range, then cut away every interval
    for iv in tree:
        result.chop(iv.begin, iv.end)
    return sorted(result)
示例#30
0
def get_gene_lookup(tx_ref_file):
    '''
    Generate start/end coordinate reference
    for genes and output as an interval tree
    dictionary. Also output dataframe containing
    chromosome, start and ends for all exons.

    Args:
        tx_ref_file: path of a tab-separated, header-less GTF-style
            annotation ('#' lines are ignored). An empty string disables
            the lookup.

    Returns:
        A 3-tuple ``(ref_trees, ex_trees, ex_ref_out)``, or
        ``(None, None, None)`` when ``tx_ref_file`` is '':

        - ref_trees: dict chrom -> IntervalTree of gene spans whose data is
          the gene name (rows with an empty gene name are skipped).
        - ex_trees: dict chrom -> IntervalTree of merged exon spans.
        - ex_ref_out: DataFrame with columns chrom, start, end listing the
          merged exon spans of every chromosome.

        Intervals are built as ``[start - 1, end)`` from the file's 4th and
        5th columns.
    '''
    # Bind all three outputs up front so the early return below is valid.
    ref_trees, ex_trees, ex_ref_out = None, None, None
    if tx_ref_file == '':
        return ref_trees, ex_trees, ex_ref_out

    logging.info('Generating lookup for genes...')
    #TODO: standardise with make_supertranscript for gtf handling
    tx_ref = pd.read_csv(tx_ref_file, comment='#', sep='\t', header=None, low_memory=False)
    tx_ref['gene_id'] = tx_ref[8].apply(lambda x: get_attribute(x, 'gene_id'))
    tx_ref['gene'] = tx_ref[8].apply(lambda x: get_attribute(x, 'gene_name'))

    # create start/end gene lookup, grouping adjacent rows
    # (this prevents merging distant genes with the same IDs)
    gn_ref = tx_ref[[0, 3, 4, 'gene_id', 'gene']]
    gn_ref.columns = ['chrom', 'start', 'end', 'gene_id', 'gene']
    # Run counter: increments whenever gene_id differs from the previous row,
    # so every run of consecutive rows sharing a gene_id is its own group.
    adj_check = (gn_ref.gene_id != gn_ref.gene_id.shift()).cumsum()
    gn_ref = gn_ref.groupby(['chrom', 'gene_id', 'gene', adj_check],
                            as_index=False, sort=False).agg({'start': 'min', 'end': 'max'})
    gn_ref = gn_ref.drop_duplicates()

    # start/end coordinates for gene matching
    ref_trees = {}
    chroms = np.unique(gn_ref.chrom.values)
    for chrom in chroms:
        chr_ref = gn_ref[gn_ref.chrom == chrom].drop_duplicates()
        ref_tree = IntervalTree()
        for s, e, g in zip(chr_ref['start'].values, chr_ref['end'].values, chr_ref['gene'].values):
            if g != '':
                # start - 1: presumably 1-based inclusive GTF start -> half-open interval
                ref_tree.addi(s - 1, e, g)
        ref_trees[chrom] = ref_tree

    # merged exon boundaries for block annotation
    ex_ref = tx_ref[tx_ref[2] == 'exon']
    ex_trees = {}
    ex_frames = []
    for chrom in chroms:
        chr_ref = ex_ref[ex_ref[0] == chrom].drop_duplicates()
        ex_tree = IntervalTree()
        for s, e in zip(chr_ref[3].values, chr_ref[4].values):
            ex_tree.addi(s - 1, e)
        ex_tree.merge_overlaps()
        ex_frames.append(pd.DataFrame([(chrom, iv.begin, iv.end) for iv in ex_tree],
                                      columns=['chrom', 'start', 'end']))
        ex_trees[chrom] = ex_tree
    # Concatenate once rather than re-copying a growing frame every iteration.
    ex_ref_out = pd.concat(ex_frames, ignore_index=True) if ex_frames else pd.DataFrame()

    return ref_trees, ex_trees, ex_ref_out
示例#31
0
def find_stretches(alignments, character):
    '''Finds occurrences of a character in an alignment and builds up
    an interval tree from their start and stop positions.

    Args:
        alignments: iterable of records exposing ``.seq`` and ``.id``
            (presumably Biopython SeqRecords -- confirm with callers).
        character: the character whose consecutive runs are collected. It is
            matched literally, so regex metacharacters such as '.' are safe.

    Returns the interval tree: one half-open interval ``[start, end)`` of
    string offsets per run, with the record's id as data.
    '''
    # Compile once for all records; escape so the character is literal and
    # group it so '+' repeats the whole token.
    find_this = re.compile(r"(?:{})+".format(re.escape(character)))
    tree = IntervalTree()
    for sequence in alignments:
        for m in find_this.finditer(str(sequence.seq)):
            tree.addi(m.start(), m.end(), sequence.id)
    return tree
class ColorModel(QObject):
    """Keeps colored index ranges in an IntervalTree and signals changes.

    Each stored interval spans ``[range_.begin, range_.end)`` and carries the
    ColoredRange itself as its data. ``rangeChanged`` is emitted with the
    affected ColoredRange whenever one is added or removed.
    """
    rangeChanged = pyqtSignal([ColoredRange])

    def __init__(self, parent, color_theme=SolarizedColorTheme):
        super(ColorModel, self).__init__(parent)
        self._db = IntervalTree()
        self._theme = color_theme

    def color_region(self, begin, end, color=None):
        """Color ``[begin, end)`` and return the new ColoredRange.

        Without an explicit ``color``, the theme's accent for the current
        number of stored ranges is used.
        """
        if color is None:
            color = self._theme.get_accent(len(self._db))
        r = ColoredRange(begin, end, color)
        self.color_range(r)
        return r

    def clear_region(self, begin, end):
        """Remove every stored range spanning exactly ``[begin, end)``.

        Ranges that only overlap the region -- including ranges of the same
        length at a different offset -- are left in place.
        """
        exact = [r for r in self._db[begin:end]
                 if r.begin == begin and r.end == end]
        for r in exact:
            self.clear_range(r.data)

    def color_range(self, range_):
        """Store ``range_`` and emit rangeChanged."""
        self._db.addi(range_.begin, range_.end, range_)
        self.rangeChanged.emit(range_)

    def clear_range(self, range_):
        """Remove the stored ``range_`` and emit rangeChanged."""
        self._db.removei(range_.begin, range_.end, range_)
        self.rangeChanged.emit(range_)

    def get_color(self, index):
        """Return the color of the narrowest range containing ``index``, or None.

        The narrowest range is the most specific one, so its color wins.
        """
        ranges = self._db[index]
        if not ranges:
            return None
        return min(ranges, key=lambda r: r.end - r.begin).data.color

    def get_region_colors(self, begin, end):
        """Return the ColoredRange objects overlapping ``[begin, end)``.

        When ``begin == end`` the ranges containing the point ``begin`` are
        used instead. The result comes from ``funcy.pluck_attr``.
        """
        if begin == end:
            results = self._db[begin]
        else:
            results = self._db[begin:end]
        return funcy.pluck_attr("data", results)

    def is_index_colored(self, index):
        """True if any stored range contains ``index``."""
        return len(self._db[index]) > 0

    def is_region_colored(self, begin, end):
        """True if any stored range overlaps ``[begin, end)``."""
        return len(self._db[begin:end]) > 0
示例#33
0
    def __getitem__(self, index):
        """Build one evaluation item for the graph edge at ``index``.

        The edge links a head entity A to a tail entity B. The method looks
        for a mutual neighbour C such that ``sample_relation_statement``
        yields a statement for both A-C and C-B. Candidates are tried in a
        random order drawn inside ``numpy_seed(...)`` (presumably so the
        order is reproducible per seed/epoch/index -- confirm in numpy_seed).

        Returns:
            None if A and B share no neighbour, or if no candidate produced
            both supporting statements. Otherwise a dict with the annotated
            target relation ('target'), the two annotated supporting
            relations ('support') and the entity ids ('entities': A, B, C).
        """
        with numpy_seed('GNNEvalDataset', self.seed, self.epoch, index):
            occupied = IntervalTree()
            edge = self.graph[index]
            head = edge[GraphDataset.HEAD_ENTITY]
            tail = edge[GraphDataset.TAIL_ENTITY]
            # Register the target edge's own span first; the tree is passed to
            # (and returned by) every sample_relation_statement call below.
            occupied.addi(edge[GraphDataset.START_BLOCK],
                          edge[GraphDataset.END_BLOCK])

            candidates = np.intersect1d(self.graph.get_neighbors(head),
                                        self.graph.get_neighbors(tail),
                                        assume_unique=True)
            if len(candidates) == 0:
                return None

            for chosen_mutual in np.random.permutation(candidates):
                support1, occupied = self.sample_relation_statement(
                    head, chosen_mutual, occupied)
                support2, occupied = self.sample_relation_statement(
                    chosen_mutual, tail, occupied)
                if support1 is not None and support2 is not None:
                    break
            else:
                # No candidate produced both supporting statements.
                return None

        annotate = self.annotated_text.annotate_relation
        return {
            'target': annotate(*(edge.numpy())),
            'support': [annotate(*support1), annotate(*support2)],
            'entities': {'A': head, 'B': tail, 'C': chosen_mutual},
        }
示例#34
0
    def _interval_tree(self, mappings, chrom_length):
        """Build an IntervalTree from assembly-mapping records.

        An interval tree answers "which stored intervals overlap this
        interval or point?" efficiently
        (see https://en.wikipedia.org/wiki/Interval_tree).

        Each element of ``mappings`` looks like::

            {'mapped':
               {'assembly': 'GRCh38',
                'coord_system': 'chromosome',
                'end': 1039365,
                'seq_region_name': 'X',
                'start': 1039265,
                'strand': 1},
             'original':
               {'assembly': 'GRCh37',
                'coord_system': 'chromosome',
                'end': 1000100,
                'seq_region_name': 'X',
                'start': 1000000,
                'strand': 1}}

        Returns:
            An IntervalTree keyed on the 'original' region, rewritten as the
            half-open ``[start, end + 1)``. Each interval's data is the
            ``(start, end)`` pair of the 'mapped' region (left inclusive);
            when the mapped strand is not +1 that pair is first reflected onto
            forward-strand coordinates using ``chrom_length``.
        """
        tree = IntervalTree()
        for mapping in mappings:
            source = mapping['original']
            target = mapping['mapped']

            # Inclusive [start, end] -> half-open [start, end + 1).
            lo, hi = source['start'], source['end'] + 1

            if target['strand'] != +1:
                # Reverse strand: position p corresponds to
                # chrom_length - p + 1 on the forward strand, so the two
                # ends swap roles:
                #  1  2  3  4  5  6  7  8  9 10
                #  |  |  |  |  |  |  |  |  |  |
                # 10  9  8  7  6  5  4  3  2  1
                dest = (chrom_length - target['end'] + 1,
                        chrom_length - target['start'] + 1)
            else:
                dest = (target['start'], target['end'])

            tree.addi(lo, hi, dest)
        return tree
示例#35
0
def build_search_tree(data_loader, chrom_number, start_pos, end_pos, source):
    '''
    Build an interval tree of all records on ``chrom_number`` whose start
    position lies in ``[start_pos, end_pos)``.

    Args:
        data_loader: object exposing ``es_tools.scan(query)``, which yields
            hits shaped like ``{"_id": ..., "_source": {...}}`` (presumably an
            Elasticsearch scan helper -- confirm).
        chrom_number: chromosome value matched against ``chrom_number``.
        start_pos, end_pos: bounds on the record's ``start`` ("gte"/"lt").
        source: single-entry dict ``{field: value}``; hits whose
            ``_source[field] == value`` are also collected separately.

    Returns:
        dict with
        - "records_tree": IntervalTree over ``[start, end + 1)`` of each hit,
          with the hit as data (each hit gains ``_source["record_id"]``);
        - "records_from_file": list of hits matching ``source``.

    Raises:
        ValueError: if ``source`` is empty.
    '''
    start_time = timeit.default_timer()
    query = {
        "query": {
            "bool": {
                "must": [{
                    "range": {
                        "start": {
                            "gte": start_pos,
                            "lt": end_pos
                        }
                    }
                }, {
                    "match": {
                        "chrom_number": chrom_number
                    }
                }]
            }
        },
        "fields": ["_source", "_size"]
    }

    results = data_loader.es_tools.scan(query)

    # dict views are not indexable on Python 3; take the (first) entry instead.
    try:
        source_key, source_value = next(iter(source.items()))
    except StopIteration:
        raise ValueError("source must be a non-empty {field: value} dict")

    records_tree = IntervalTree()
    records_from_file = []
    for record in results:
        doc = record["_source"]
        start, end = doc["start"], doc["end"]
        doc["record_id"] = record["_id"]
        # end + 1 so the half-open interval still contains the record's end.
        records_tree.addi(start, end + 1, record)
        if doc[source_key] == source_value:
            records_from_file.append(record)

    end_time = timeit.default_timer()
    logging.debug(
        "Generated record tree with %d records for chromosome %s " +
        "and positions range %d - %d, list of %d records " +
        "from source file %s in %f seconds.",
        len(records_tree), chrom_number, start_pos, end_pos,
        len(records_from_file), str(source), end_time - start_time)
    return {
        "records_tree": records_tree,
        "records_from_file": records_from_file
    }
示例#36
0
def filter_intervals(intervals):
    """Greedily keep the intervals that overlap nothing kept before them.

    Intervals are visited in input order; an interval is kept only if it
    overlaps none of the already-kept ones (IntervalTree.overlap semantics).

    Returns:
        The kept ``(start, end)`` tuples, sorted by start.
    """
    kept_tree = IntervalTree()
    kept = []
    for begin, stop in intervals:
        if not kept_tree.overlap(begin, stop):
            kept_tree.addi(begin, stop, 1)
            kept.append((begin, stop))
    kept.sort(key=lambda pair: pair[0])
    return kept
示例#37
0
def section_markup(markup, mode=HTML):
    """Lay out the dependency arcs of ``markup`` and yield one section per word.

    Args:
        markup: object exposing ``deps`` (iterable of ``(source, target, type)``
            triples of word indices) and ``words``.
        mode: HTML (default) or ASCII. In ASCII mode an arc's stop index is
            included when arcs are indexed for level assignment.

    Yields:
        ``DepMarkupSection(word, arcs)`` for every word of ``markup.words``,
        where ``arcs`` are the ArcSection pieces at that word's index, sorted
        with ``Arc.level_order``.
    """
    arcs = []
    for source, target, type in markup.deps:
        # ROOT dependencies get no arc.
        if type == ROOT:
            continue

        # Normalise so start is the smaller index; direction records whether
        # the dependency runs left-to-right (RIGHT) or right-to-left (LEFT).
        if source < target:
            start, stop = source, target
            direction = RIGHT
        else:
            start, stop = target, source
            direction = LEFT

        arc = Arc(start, stop, direction, type, level=None)
        arcs.append(arc)

    # order
    arcs = sorted(arcs, key=Arc.layout_order)

    # level
    # Index every arc by its index span so overlapping arcs can be queried.
    intervals = Intervals()
    for arc in arcs:
        stop = arc.stop
        if mode == ASCII:
            stop += 1  # in ascii mode include stop
        intervals.addi(arc.start, stop, arc)

    # Assign levels in layout order; arc.level is mutated in place, so later
    # arcs see the levels of earlier ones. get_free_level presumably picks the
    # lowest level unused by the overlapping arcs -- confirm in its definition.
    # NOTE(review): this query uses arc.stop, not the ASCII-extended stop used
    # when indexing above; confirm the asymmetry is intended.
    for arc in arcs:
        selected = intervals.overlap(arc.start, arc.stop)
        arc.level = get_free_level(selected)

    # group
    # Split each arc into one ArcSection per index in [start, stop]
    # (inclusive): BEGIN/END at the two ends depending on direction, INSIDE in
    # between. id(arc) is passed as the parent of every piece of the same arc.
    sections = defaultdict(list)
    for arc in arcs:
        start, stop, direction, type, level = arc
        parent = id(arc)
        for index in range(start, stop + 1):
            if index == start:
                part = BEGIN if direction == RIGHT else END
            elif index == stop:
                part = END if direction == RIGHT else BEGIN
            else:
                part = INSIDE

            section = ArcSection(part, direction, type, level, parent)
            sections[index].append(section)

    # Emit every word with the arc pieces at its index (possibly none).
    for index, word in enumerate(markup.words):
        arcs = sections[index]
        arcs = sorted(arcs, key=Arc.level_order)
        yield DepMarkupSection(word, arcs)
示例#38
0
def main():
    """Count SAM reads overlapping each gene of a GTF reference.

    Usage: ``<script> GTF_FILE SAM_FILE`` (read from ``sys.argv[1:3]``).

    Every 'gene' record of the GTF is indexed in one IntervalTree. For each
    mapped SAM alignment, every gene on the same chromosome whose span
    overlaps the reference bases covered by the read gets its count
    incremented. Genes with at least one read are printed as
    chromosome, gene id, gene name and count.
    """
    gtf_ref_file = sys.argv[1]
    sam_file = sys.argv[2]

    # Splits on whitespace runs, optionally preceded by ';', so GTF attribute
    # pairs (key "value";) and SAM columns are tokenised by the same pattern.
    r = re.compile(r'\s*;?\s+')

    genes = []
    genes_by_position = IntervalTree()

    print("Loading reference...")
    with open(gtf_ref_file, 'r') as fi:
        for raw_line in fi:
            # '#' covers plain comments as well as '#!' header lines.
            if raw_line.startswith('#') or not raw_line.strip():
                continue
            fields = r.split(raw_line)
            if fields[2] != 'gene':
                continue
            chromosome = trim(fields[0])
            gene_id = trim(fields[fields.index('gene_id') + 1])
            gene_name = trim(fields[fields.index('gene_name') + 1])
            start_position = int(fields[3])
            end_position = int(fields[4])
            gene = Gene(chromosome, gene_id, gene_name)
            genes.append(gene)
            # GTF ends are inclusive while IntervalTree intervals are
            # half-open, so store [start, end + 1); this also avoids the
            # null (x, x) intervals IntervalTree rejects.
            genes_by_position.addi(start_position, end_position + 1, gene)

    print("Counting reads...")
    with open(sam_file, 'r') as fi:
        for raw_line in fi:
            # SAM header lines start with '@' and carry no alignment.
            if raw_line.startswith('@') or not raw_line.strip():
                continue
            fields = r.split(raw_line)
            chromosome = fields[2]
            position = int(fields[3])
            cigar = fields[5]
            # Unmapped reads ('*' reference / CIGAR) cover no reference bases.
            if chromosome == '*' or cigar == '*':
                continue
            ref_len = cigar_to_reference_length(cigar)  # reference bases spanned
            # The read covers positions [position, position + ref_len); this
            # half-open query includes its last base and also handles
            # ref_len == 1 (a single-position overlap).
            for interval in genes_by_position[position:position + ref_len]:
                gene = interval.data
                # One tree holds all chromosomes, so filter on the name here.
                if gene.chromosome != chromosome:
                    continue
                gene.count += 1

    for gene in genes:
        # Many genes receive no reads; report only those with at least one.
        if gene.count != 0:
            print("%2s\t%-20s\t\t%-16s\t\t%d" %
                  (gene.chromosome, gene.id, gene.name, gene.count))
示例#39
0
def site_intervaltree(seq, enzyme):
    """
    Build an IntervalTree marking every cut site of ``enzyme`` in ``seq``.

    Whether a position or range touches a cut site can then be queried with
    ``tree[x]`` or ``tree[x1:x2]``.

    :param seq: the sequence to digest
    :param enzyme: the restriction enzyme used in digestion; its ``size``,
        ``fst3`` and ``search(seq)`` are used
    :return: an IntervalTree with one interval of length ``enzyme.size`` per
        position returned by ``enzyme.search(seq)``, starting at
        ``position + enzyme.fst3 - 1``.
    """
    tree = IntervalTree()
    site_len = enzyme.size
    shift = enzyme.fst3
    for hit in enzyme.search(seq):
        begin = hit + shift - 1
        tree.addi(begin, begin + site_len)
    return tree
示例#40
0
def test_interval_insersion_67():
    """Regression test (presumably intervaltree issue #67): inserting these
    intervals one at a time, in exactly this order, must leave a tree that
    passes ``verify()``."""
    insertion_order = (
        (3657433088, 3665821696),
        (2415132672, 2415394816),
        (201326592, 268435456),
        (163868672, 163870720),
        (3301965824, 3303014400),
        (4026531840, 4294967296),
        (3579899904, 3579904000),
        (3439329280, 3443523584),
        (3431201536, 3431201664),
        (3589144576, 3589275648),
        (2531000320, 2531033088),
        (4187287552, 4187291648),
        (3561766912, 3561783296),
        (3046182912, 3046187008),
        (3506438144, 3506962432),
        (3724953872, 3724953888),
        (3518234624, 3518496768),
        (3840335872, 3840344064),
        (3492279181, 3492279182),
        (3447717888, 3456106496),
        (3589390336, 3589398528),
        (3486372962, 3486372963),
        (3456106496, 3472883712),
        (3508595496, 3508595498),
        (3511853376, 3511853440),
        (3452226160, 3452226168),
        (3544510720, 3544510736),
        (3525894144, 3525902336),
        (3524137920, 3524137984),
        (3508853334, 3508853335),
        (3467337728, 3467341824),
        (3463212256, 3463212260),
        (3446643456, 3446643712),
        (3473834176, 3473834240),
        (3487039488, 3487105024),
        (3444686112, 3444686144),
        (3459268608, 3459276800),
        (3483369472, 3485466624),
    )
    tree = IntervalTree()
    for begin, end in insertion_order:
        tree.addi(begin, end)
    tree.verify()
示例#41
0
def test_duplicate_insert():
    """Every insertion API treats an identical Interval (same begin, end and
    data) as already present; the same bounds with different data count as a
    second, distinct Interval."""
    tree = IntervalTree()
    labelled = Interval(-10, 20, "arbitrary data")
    bare = Interval(-10, 20)

    def check(expected):
        assert len(tree) == len(expected)
        assert tree.items() == expected

    # String data: insert via slice assignment, then re-insert through each API.
    tree[-10:20] = "arbitrary data"
    only_labelled = frozenset([labelled])
    check(only_labelled)
    tree.addi(-10, 20, "arbitrary data")
    check(only_labelled)
    tree.add(labelled)
    check(only_labelled)
    tree.update([labelled])
    check(only_labelled)

    # None data: same bounds, but a distinct Interval alongside the first.
    tree[-10:20] = None
    both = frozenset([bare, labelled])
    check(both)
    tree.addi(-10, 20)
    check(both)
    tree.add(bare)
    check(both)
    tree.update([bare, labelled])
    check(both)
示例#42
0
def test_brackets_vs_overlap():
    """Slice lookup ``tree[a:b]`` returns the same set as ``tree.overlap(a, b)``."""
    tree = IntervalTree()
    for begin, end, label in ((1, 3, "dude"), (2, 4, "sweet"), (6, 9, "rad")):
        tree.addi(begin, end, label)
    for iv in tree:
        assert tree[iv.begin:iv.end] == tree.overlap(iv.begin, iv.end)
示例#43
0
def test_original_sequence():
    """Replay the originally reported add/remove sequence with float bounds;
    it must run to completion without raising."""
    tree = IntervalTree()
    steps = [
        (tree.addi, 17.89, 21.89),
        (tree.addi, 11.53, 16.53),
        (tree.removei, 11.53, 16.53),
        (tree.removei, 17.89, 21.89),
        (tree.addi, -0.62, 4.38),
        (tree.addi, 9.24, 14.24),
        (tree.addi, 4.0, 9.0),
        (tree.removei, -0.62, 4.38),
        (tree.removei, 9.24, 14.24),
        (tree.removei, 4.0, 9.0),
        (tree.addi, 12.86, 17.86),
        (tree.addi, 16.65, 21.65),
        (tree.removei, 12.86, 17.86),
    ]
    for operation, begin, end in steps:
        operation(begin, end)
示例#44
0
def test_minimal_sequence():
    """Adding (4.0, 9.0) to a root-plus-right-child tree must keep the tree
    valid (drotate() once failed to promote Intervals and left an empty node)."""
    tree = IntervalTree()
    tree.addi(-0.62, 4.38)  # becomes the root
    tree.addi(9.24, 14.24)  # becomes the root's right child

    # Expected shape, as t.print_structure() would show it:
    # Node<-0.62, depth=2, balance=1>
    #  Interval(-0.62, 4.38)
    # >:  Node<9.24, depth=1, balance=0>
    #      Interval(9.24, 14.24)
    top = tree.top_node
    assert top.s_center == {Interval(-0.62, 4.38)}
    assert top.right_node.s_center == {Interval(9.24, 14.24)}
    assert not top.left_node

    tree.verify()

    # The insertion that used to leave an empty node behind.
    tree.addi(4.0, 9.0)
    tree.print_structure()
    tree.verify()
示例#45
0
def original_print():
    """Reproduce the originally reported usage and print the results.

    First indexes the tree with a tuple ``tree[begin, end]`` (as in the
    report; the slice ``tree[begin:end]`` is the overlap query), then prints
    ``envelop`` results for each interval's bounds.
    """
    tree = IntervalTree()
    for begin, end, label in ((1, 3, "dude"), (2, 4, "sweet"), (6, 9, "rad")):
        tree.addi(begin, end, label)

    for iv in tree:
        # Deliberately a tuple key, as originally reported (noted to give
        # set()); a ':' slice is the intended overlap lookup.
        print(tree[(iv.begin, iv.end)])

    for iv in tree:
        print(tree.envelop(iv.begin, iv.end))
def test_issue5():
    """Issue #5 (https://github.com/konstantint/PyIntervalTree/issues/5):
    after removing three of four nested intervals exactly one remains."""
    from intervaltree import IntervalTree
    tree = IntervalTree()
    for begin, end in ((-46.0, 31.0), (-20.0, 29.0), (1.0, 9.0), (-3.0, 6.0)):
        tree.addi(begin, end, 'test')
    for begin, end in ((1.0, 9.0), (-20.0, 29.0), (-46.0, 31.0)):
        tree.removei(begin, end, 'test')
    assert len(tree) == 1
示例#47
0
def test_small_tree_score():
    """score() is 0.0 for trees holding at most two intervals, and becomes
    non-zero once a node's s_center holds more than one interval."""
    tree = IntervalTree()
    assert tree.score() == 0.0

    for begin, end in ((1, 4), (2, 5)):
        tree.addi(begin, end)
        assert tree.score() == 0.0

    tree.addi(1, 100)  # inefficiency: some s_center now holds several intervals
    assert tree.score() != 0.0
示例#48
0
文件: decoder.py 项目: mesheven/pyOCD
class DwarfAddressDecoder(object):
    """Resolve target addresses to functions and source lines via DWARF info.

    Construction scans the ELF's DWARF data once and builds two IntervalTrees:

    - ``function_tree``: ``[low_pc, high_pc)`` of each linked subprogram,
      with a FunctionInfo as data;
    - ``line_tree``: the address span between consecutive line-program
      states, with a LineInfo (CU, file name, directory, line) as data.
    """

    def __init__(self, elf):
        """@param elf The ELFFile to decode.
        @exception Exception Raised when the ELF carries no DWARF info.
        """
        assert isinstance(elf, ELFFile)
        self.elffile = elf

        if not self.elffile.has_dwarf_info():
            raise Exception("No DWARF debug info available")

        self.dwarfinfo = self.elffile.get_dwarf_info()

        self.subprograms = None
        self.function_tree = None
        self.line_tree = None

        # Build indices.
        self._get_subprograms()
        self._build_function_search_tree()
        self._build_line_search_tree()

    def get_function_for_address(self, addr):
        """Return the FunctionInfo of a function containing ``addr``, or None.

        When several ranges contain ``addr``, the first in sorted Interval
        order is used.
        """
        try:
            return sorted(self.function_tree[addr])[0].data
        except IndexError:
            return None

    def get_line_for_address(self, addr):
        """Return the LineInfo of a line-table span containing ``addr``, or None."""
        try:
            return sorted(self.line_tree[addr])[0].data
        except IndexError:
            return None

    def _get_subprograms(self):
        """Collect the DW_TAG_subprogram DIEs of every CU into self.subprograms."""
        self.subprograms = []
        for CU in self.dwarfinfo.iter_CUs():
            self.subprograms.extend([d for d in CU.iter_DIEs() if d.tag == 'DW_TAG_subprogram'])

    def _build_function_search_tree(self):
        """Index subprograms by ``[low_pc, high_pc)``.

        Subprograms lacking a name, low_pc or high_pc, and those with
        low_pc == 0 (excluded from the link), are skipped.
        """
        self.function_tree = IntervalTree()
        for prog in self.subprograms:
            try:
                name = prog.attributes['DW_AT_name'].value
                low_pc = prog.attributes['DW_AT_low_pc'].value
                high_pc = prog.attributes['DW_AT_high_pc'].value

                # Skip subprograms excluded from the link.
                if low_pc == 0:
                    continue

                # If high_pc is not explicitly an address, then it's an offset from the
                # low_pc value.
                if prog.attributes['DW_AT_high_pc'].form != 'DW_FORM_addr':
                    high_pc = low_pc + high_pc

                fninfo = FunctionInfo(name=name, subprogram=prog, low_pc=low_pc, high_pc=high_pc)

                self.function_tree.addi(low_pc, high_pc, fninfo)
            except KeyError:
                pass

    def _build_line_search_tree(self):
        """Index line-table rows by the address span they cover.

        For each pair of consecutive states of a CU's line program, the range
        ``[previous.address, current.address)`` is mapped to the previous
        state's file and line. A sequence containing ``DW_LNE_set_address 0``
        is skipped until its end_sequence.
        """
        self.line_tree = IntervalTree()
        for cu in self.dwarfinfo.iter_CUs():
            lineprog = self.dwarfinfo.line_program_for_CU(cu)
            comp_dir = self._get_comp_dir(cu)
            prevstate = None
            skipThisSequence = False
            for entry in lineprog.get_entries():
                # Look for a DW_LNE_set_address command with a 0 address. This indicates
                # code that is not actually included in the link.
                #
                # TODO: find a better way to determine the code is really not present and
                #       doesn't have a real address of 0
                if entry.is_extended and entry.command == DW_LNE_set_address \
                        and len(entry.args) == 1 and entry.args[0] == 0:
                    skipThisSequence = True

                # We're interested in those entries where a new state is assigned
                if entry.state is None:
                    continue

                # Looking for a range of addresses in two consecutive states.
                if prevstate and not skipThisSequence:
                    fileinfo = lineprog['file_entry'][prevstate.file - 1]
                    filename = fileinfo.name
                    dirname = self._get_include_dir(lineprog, fileinfo.dir_index, comp_dir)
                    info = LineInfo(cu=cu, filename=filename, dirname=dirname, line=prevstate.line)
                    fromAddr = prevstate.address
                    toAddr = entry.state.address
                    try:
                        if fromAddr != 0 and toAddr != 0:
                            # IntervalTree rejects null intervals; widen to one byte.
                            if fromAddr == toAddr:
                                toAddr += 1
                            self.line_tree.addi(fromAddr, toAddr, info)
                    except Exception:
                        logging.debug("Problematic lineprog:")
                        self._dump_lineprog(lineprog)
                        raise

                if entry.state.end_sequence:
                    prevstate = None
                    skipThisSequence = False
                else:
                    prevstate = entry.state

    @staticmethod
    def _get_comp_dir(cu):
        """Return the CU's DW_AT_comp_dir value, or b'' when it has none."""
        attr = cu.get_top_DIE().attributes.get('DW_AT_comp_dir')
        return attr.value if attr is not None else b''

    @staticmethod
    def _get_include_dir(lineprog, dir_index, comp_dir):
        """Resolve a file entry's directory index (DWARF 2-4 line header).

        Index 0 denotes the compilation directory, which is not stored in
        ``include_directory``; indices >= 1 are 1-based into that list.
        """
        if dir_index == 0:
            return comp_dir
        return lineprog['include_directory'][dir_index - 1]

    @staticmethod
    def _to_text(value):
        """Decode bytes attribute values to str; other values pass through."""
        if isinstance(value, bytes):
            return value.decode('utf-8', 'replace')
        return value

    def _dump_lineprog(self, lineprog):
        """Log every entry of ``lineprog`` at debug level."""
        for i, e in enumerate(lineprog.get_entries()):
            s = e.state
            if s is None:
                logging.debug("%d: cmd=%d ext=%d args=%s", i, e.command, int(e.is_extended), repr(e.args))
            else:
                logging.debug("%d: %06x %4d stmt=%1d block=%1d end=%d file=[%d]%s", i, s.address, s.line, s.is_stmt, int(s.basic_block), int(s.end_sequence), s.file, lineprog['file_entry'][s.file-1].name)

    def dump_subprograms(self):
        """Log name, low/high pc and CU file name of every subprogram (debug level).

        Missing attributes fall back to '<unnamed>', 0 and 0xffffffff.
        """
        for prog in self.subprograms:
            try:
                name = self._to_text(prog.attributes['DW_AT_name'].value)
            except KeyError:
                name = '<unnamed>'
            try:
                low_pc = prog.attributes['DW_AT_low_pc'].value
            except KeyError:
                low_pc = 0
            try:
                high_pc = prog.attributes['DW_AT_high_pc'].value
            except KeyError:
                high_pc = 0xffffffff
            cu_name = self._to_text(prog._parent.attributes['DW_AT_name'].value)
            filename = os.path.basename(cu_name.replace('\\', '/'))
            logging.debug("%s%s%08x %08x %s", name, (' ' * (50-len(name))), low_pc, high_pc, filename)
示例#49
0
文件: cache.py 项目: flit/pyOCD
class MemoryCache(object):
    """! @brief Memory cache.
    
    Maintains a cache of target memory. The constructor is passed a backing DebugContext object that
    will be used to fill the cache.
    
    The cache is invalidated whenever the target has run since the last cache operation (based on run
    tokens). If the target is currently running, all accesses cause the cache to be invalidated.
    
    The target's memory map is referenced. All memory accesses must be fully contained within a single
    memory region, or a MemoryAccessError will be raised. However, if an access is outside of all regions,
    the access is passed to the underlying context unmodified. When an access is within a region, that
    region's cacheability flag is honoured.
    """
    
    def __init__(self, context, core):
        """! @brief Constructor.
        @param context Backing context used to read target memory on cache misses.
        @param core Core whose running state and run token decide when the cache is stale.
        """
        self._context, self._core = context, core
        # -1 presumably never equals a real run token, so the first check resets the cache.
        self._run_token = -1
        self._log = LOG.getChild('memcache')
        self._reset_cache()

    def _reset_cache(self):
        """! @brief Drop all cached data and start a fresh set of metrics."""
        self._cache, self._metrics = IntervalTree(), CacheMetrics()

    def _check_cache(self):
        """! @brief Invalidates the cache if appropriate."""
        if self._core.is_running():
            self._log.debug("core is running; invalidating cache")
            self._reset_cache()
        elif self._run_token != self._core.run_token:
            self._dump_metrics()
            self._log.debug("out of date run token; invalidating cache")
            self._reset_cache()
            self._run_token = self._core.run_token

    def _get_ranges(self, addr, count):
        """! @brief Splits a memory address range into cached and uncached subranges.
        @return A 2-tuple: the set of cached Interval objects overlapping [addr, addr+count),
          and a set of Interval objects for the parts of that range they do not cover.
        """
        cached = self._cache.overlap(addr, addr + count)
        gaps = {Interval(addr, addr + count)}
        for hit in cached:
            remaining = set()
            for gap in gaps:
                if hit.end < gap.begin or hit.begin > gap.end:
                    # Disjoint: the gap survives untouched.
                    remaining.add(gap)
                    continue
                # Keep whatever part of the gap lies before the cached interval ...
                if hit.begin > gap.begin:
                    remaining.add(Interval(gap.begin, hit.begin))
                # ... and whatever part lies after it.
                if gap.end > hit.end:
                    remaining.add(Interval(hit.end, gap.end))
            gaps = remaining
        return cached, gaps

    def _read_uncached(self, uncached):
        """! @brief Reads uncached memory ranges and adds them to the cache.
        @return A list of Interval objects, one per input range, whose @a data attribute
          is a bytearray of the bytes read from target memory.
        """
        fetched = []
        for gap in uncached:
            raw = self._context.read_memory_block8(gap.begin, gap.end - gap.begin)
            block = Interval(gap.begin, gap.end, bytearray(raw))
            self._cache.add(block)  # TODO merge contiguous cached intervals
            fetched.append(block)
        return fetched

    def _update_metrics(self, cached, uncached, addr, size):
        cachedSize = 0
        for iv in cached:
            begin = iv.begin
            end = iv.end
            if iv.begin < addr:
                begin = addr
            if iv.end > addr + size:
                end = addr + size
            cachedSize += end - begin

        uncachedSize = sum((iv.end - iv.begin) for iv in uncached)

        self._metrics.reads += 1
        self._metrics.hits += cachedSize
        self._metrics.misses += uncachedSize

    def _dump_metrics(self):
        if self._metrics.total > 0:
            self._log.debug("%d reads, %d bytes [%d%% hits, %d bytes]; %d bytes written",
                self._metrics.reads, self._metrics.total, self._metrics.percent_hit,
                self._metrics.hits, self._metrics.writes)
        else:
            self._log.debug("no reads")

    def _read(self, addr, size):
        """! @brief Performs a cached read operation of an address range.
        @return A list of Interval objects sorted by address.
        """
        # Get the cached and uncached subranges of the requested read.
        cached, uncached = self._get_ranges(addr, size)
        self._update_metrics(cached, uncached, addr, size)

        # Read any uncached ranges.
        uncachedData = self._read_uncached(uncached)

        # Merged cached with data we just read
        combined = list(cached) + uncachedData
        combined.sort(key=lambda x: x.begin)
        return combined

    def _merge_data(self, combined, addr, size):
        """! @brief Extracts data from the intersection of an address range across a list of interval objects.
        
        The range represented by @a addr and @a size are assumed to overlap the intervals. The first
        and last interval in the list may have ragged edges not fully contained in the address range, in
        which case the correct slice of those intervals is extracted.
        
        @param self
        @param combined List of Interval objects forming a contiguous range. The @a data attribute of
          each interval must be a bytearray.
        @param addr Start address. Must be within the range of the first interval.
        @param size Number of bytes. (@a addr + @a size) must be within the range of the last interval.
        @return A single bytearray object with all data from the intervals that intersects the address
          range.
        """
        result = bytearray()
        resultAppend = bytearray()

        # Check for fully contained subrange.
        if len(combined) and combined[0].begin < addr and combined[0].end > addr + size:
            offset = addr - combined[0].begin
            endOffset = offset + size
            result = combined[0].data[offset:endOffset]
            return result
        
        # Take slice of leading ragged edge.
        if len(combined) and combined[0].begin < addr:
            offset = addr - combined[0].begin
            result += combined[0].data[offset:]
            combined = combined[1:]
        # Take slice of trailing ragged edge.
        if len(combined) and combined[-1].end > addr + size:
            offset = addr + size - combined[-1].begin
            resultAppend = combined[-1].data[:offset]
            combined = combined[:-1]

        # Merge.
        for iv in combined:
            result += iv.data
        result += resultAppend

        return result

    def _update_contiguous(self, cached, addr, value):
        size = len(value)
        end = addr + size
        leadBegin = addr
        leadData = bytearray()
        trailData = bytearray()
        trailEnd = end

        if cached[0].begin < addr and cached[0].end > addr:
            offset = addr - cached[0].begin
            leadData = cached[0].data[:offset]
            leadBegin = cached[0].begin
        if cached[-1].begin < end and cached[-1].end > end:
            offset = end - cached[-1].begin
            trailData = cached[-1].data[offset:]
            trailEnd = cached[-1].end

        self._cache.remove_overlap(addr, end)

        data = leadData + value + trailData
        self._cache.addi(leadBegin, trailEnd, data)

    def _check_regions(self, addr, count):
        """! @return A bool indicating whether the given address range is fully contained within
              one known memory region, and that region is cacheable.
        @exception MemoryAccessError Raised if the access is not entirely contained within a single region.
        """
        regions = self._core.memory_map.get_intersecting_regions(addr, length=count)

        # If no regions matched, then allow an uncached operation.
        if len(regions) == 0:
            return False

        # Raise if not fully contained within one region.
        if len(regions) > 1 or not regions[0].contains_range(addr, length=count):
            raise MemoryAccessError("individual memory accesses must not cross memory region boundaries")

        # Otherwise return whether the region is cacheable.
        return regions[0].is_cacheable

    def read_memory(self, addr, transfer_size=32, now=True):
        """! @brief Read one 8-, 16- or 32-bit value through the cache.
        @param addr Address to read from.
        @param transfer_size Access width in bits: 8, 16 or 32. Wider values are decoded
          little-endian from the bytes returned by read_memory_block8().
        @param now If True return the value; otherwise return a zero-argument callable that
          returns it (the read itself has already been performed).
        @exception ValueError @a transfer_size is not 8, 16 or 32. (Previously this surfaced as an
          UnboundLocalError.)
        """
        # TODO use more optimal underlying read_memory call
        if transfer_size == 8:
            data = self.read_memory_block8(addr, 1)[0]
        elif transfer_size == 16:
            data = conversion.byte_list_to_u16le_list(self.read_memory_block8(addr, 2))[0]
        elif transfer_size == 32:
            data = conversion.byte_list_to_u32le_list(self.read_memory_block8(addr, 4))[0]
        else:
            raise ValueError("invalid transfer size ({})".format(transfer_size))

        if now:
            return data

        def read_cb():
            return data
        return read_cb

    def read_memory_block8(self, addr, size):
        """! @brief Read @a size bytes starting at @a addr, serving them from the cache when allowed.

        Ranges that _check_regions() reports as not cacheable are read directly from the
        underlying context and are not cached.
        @return List of byte values; an empty list when @a size is zero or negative.
        """
        if size <= 0:
            return []

        self._check_cache()

        if self._check_regions(addr, size):
            # Cacheable: combine cached intervals with freshly read ones, then cut out the request.
            pieces = self._read(addr, size)
            data = list(self._merge_data(pieces, addr, size))
            assert len(data) == size, "result size ({}) != requested size ({})".format(len(data), size)
            return data

        self._log.debug("range [%x:%x] is not cacheable", addr, addr+size)
        return self._context.read_memory_block8(addr, size)

    def read_memory_block32(self, addr, size):
        """! @brief Read @a size 32-bit words starting at @a addr through the cache.
        @return List of word values decoded little-endian from the underlying bytes.
        """
        raw = self.read_memory_block8(addr, size * 4)
        return conversion.byte_list_to_u32le_list(raw)

    def write_memory(self, addr, value, transfer_size=32):
        """! @brief Write one 8-, 16- or 32-bit value through the cache.
        @param addr Address to write to.
        @param value Integer to write; 16- and 32-bit values are encoded little-endian.
        @param transfer_size Access width in bits: 8, 16 or 32.
        @return Whatever write_memory_block8() returns.
        @exception ValueError @a transfer_size is not 8, 16 or 32. (Previously the write was
          silently dropped.)
        """
        if transfer_size == 8:
            data = [value]
        elif transfer_size == 16:
            data = conversion.u16le_list_to_byte_list([value])
        elif transfer_size == 32:
            data = conversion.u32le_list_to_byte_list([value])
        else:
            raise ValueError("invalid transfer size ({})".format(transfer_size))
        return self.write_memory_block8(addr, data)

    def write_memory_block8(self, addr, value):
        """! @brief Write a sequence of bytes to the target and keep the cache coherent.

        The target is written first, so a failing write leaves the cache untouched. For
        cacheable ranges the written bytes are then patched into (or added to) the cache.
        @param addr Start address.
        @param value Sequence of byte values.
        @return The underlying context's result, or None when @a value is empty.
        """
        count = len(value)
        if count <= 0:
            return

        self._check_cache()
        use_cache = self._check_regions(addr, count)

        status = self._context.write_memory_block8(addr, value)
        if not use_cache:
            return status

        stop = addr + count
        hits = sorted(self._cache.overlap(addr, stop), key=lambda iv: iv.begin)
        self._metrics.writes += count

        if not hits:
            # Nothing cached in this range yet: cache the whole write as a new interval.
            self._cache.addi(addr, stop, bytearray(value))
        elif hits[0].begin <= addr and stop <= hits[0].end:
            # The write lands entirely inside one cached interval: patch it in place.
            start = addr - hits[0].begin
            hits[0].data[start:start + count] = value
        else:
            self._update_contiguous(hits, addr, bytearray(value))

        return status

    def write_memory_block32(self, addr, data):
        """! @brief Write a list of 32-bit words starting at @a addr, encoded little-endian."""
        raw = conversion.u32le_list_to_byte_list(data)
        return self.write_memory_block8(addr, raw)

    def invalidate(self):
        """! @brief Discard all cached memory contents by resetting the cache."""
        reset = self._reset_cache
        reset()
示例#50
0
def test_debug_sequence():
    """Replay a recorded add/remove sequence, verifying the tree invariants after every step."""
    steps = [
        ('add', 6.37, 11.37),
        ('add', 12.09, 17.09),
        ('add', 5.68, 11.58),
        ('remove', 6.37, 11.37),
        ('add', 13.23, 18.23),
        ('remove', 12.09, 17.09),
        ('add', 4.29, 8.29),
        ('remove', 13.23, 18.23),
        ('add', 12.04, 17.04),
        ('add', 9.39, 13.39),
        ('remove', 5.68, 11.58),
        ('remove', 4.29, 8.29),
        ('remove', 12.04, 17.04),
        ('add', 5.66, 9.66),        # Value inserted here
        ('add', 8.65, 13.65),
        ('remove', 9.39, 13.39),
        ('add', 16.49, 20.83),
        ('add', 11.42, 16.42),
        ('add', 5.38, 10.38),
        ('add', 3.57, 9.47),
        ('remove', 8.65, 13.65),
        ('remove', 5.66, 9.66),     # Deleted here
    ]
    t = IntervalTree()
    for action, begin, end in steps:
        if action == 'add':
            t.addi(begin, end)
        else:
            t.removei(begin, end)
        t.verify()
class Audio(object):
    """Prosodic features (pause, pitch, energy) for the tokens of one talk.

    Non-punctuation tokens are indexed by their time span in ``pitch_interval`` so that
    per-frame pitch/energy readings can be attributed to the token covering that instant.
    """

    def __init__(self):
        self.sentences = []
        self.pitch_interval = IntervalTree()
        self.talk_id = 0
        self.group_name = None

        # Number of non-punctuation tokens; set by build_interval_tree().
        self.token_count = None

        # Pitch readings at or above this value are ignored by parse_pitch_feature().
        self.PITCH_FILTER = 300.0
        # Energy frame i is placed at time i * YAAFE_STEP_SIZE / TED_AUDIO_SAMPLE_RATE
        # (presumably seconds, matching the token time spans).
        self.YAAFE_STEP_SIZE = 512.0
        self.TED_AUDIO_SAMPLE_RATE = 16000.0

    def get_tokens(self):
        """Return all tokens of all sentences, in order."""
        tokens = []
        for sentence in self.sentences:
            tokens.extend(sentence.tokens)
        return tokens

    def add_sentence(self, sentence):
        self.sentences.append(sentence)

    def build_interval_tree(self):
        """Index every non-punctuation token by [begin, begin + duration) and count them."""
        self.token_count = 0
        for token in self.get_tokens():
            if not token.is_punctuation():
                self.token_count += 1
                self.pitch_interval.addi(token.begin, token.begin + token.duration, token)

    @staticmethod
    def _decode_line(raw_line):
        """Return *raw_line* as text, dropping bytes that cannot be decoded.

        Byte strings (what Python 2's ``open(..., "r")`` yields) are decoded as ASCII with
        errors ignored, like Python 2's ``unicode(line, errors='ignore')`` with the default
        codec; text lines (Python 3) are returned unchanged.
        """
        if isinstance(raw_line, bytes):
            return raw_line.decode('ascii', 'ignore')
        return raw_line

    def _token_at(self, second):
        """Return the token whose span contains *second*, or None when no token covers it."""
        matches = self.pitch_interval[second]
        if not matches:
            return None
        return next(iter(matches)).data

    @staticmethod
    def _mean_offset(levels, reference):
        """Return mean(levels) - reference, or None when no levels were collected."""
        if not levels:
            return None
        return sum(levels) / len(levels) - reference

    @staticmethod
    def _mean_and_scale(values):
        """Return (mean, std) of *values*, substituting 1.0 for a zero (or NaN) std."""
        mean = np.mean(values)
        std = np.std(values)
        if not std > 0:
            std = 1.0
        return mean, std

    def parse_pitch_feature(self, filename):
        """Attach pitch readings from *filename* to tokens, then set each token's relative pitch.

        Each line is "<second> <pitch>". Readings below PITCH_FILTER are appended to the token
        covering that second; readings outside every token are skipped. Each non-punctuation
        token's pitch becomes (mean of its readings - sentence average), or 0.0 if it has none.
        """
        with open(filename, "r") as file_:
            for raw_line in file_:
                line_parts = self._decode_line(raw_line).rstrip().split(" ")
                second = float(line_parts[0])
                pitch_level = float(line_parts[1])

                if pitch_level < self.PITCH_FILTER:
                    token = self._token_at(second)
                    if token is not None:
                        token.append_pitch_level(pitch_level)

        for sentence in self.sentences:
            avg_pitch = sentence.get_avg_pitch_level()
            for token in sentence.get_tokens():
                if not token.is_punctuation():
                    pitch = self._mean_offset(token.pitch_levels, avg_pitch)
                    token.pitch = 0.0 if pitch is None else pitch

    def parse_energy_feature(self, filename):
        """Attach energy readings from *filename* to tokens, then set each token's relative energy.

        Lines starting with '%' are skipped and do not advance the frame counter; every other
        line holds one energy value for the next frame. Each non-punctuation token's energy
        becomes (mean of its readings - sentence average), or 0.0 if it has none.
        """
        frame_step = self.YAAFE_STEP_SIZE / self.TED_AUDIO_SAMPLE_RATE

        with open(filename, "r") as file_:
            i = -1
            for raw_line in file_:
                line = self._decode_line(raw_line)

                if line.startswith("%"):
                    continue

                i += 1
                energy_level = float(line.rstrip())

                token = self._token_at(i * frame_step)
                if token is not None:
                    token.append_energy_level(energy_level)

        for sentence in self.sentences:
            avg_energy = sentence.get_avg_energy_level()
            for token in sentence.get_tokens():
                if not token.is_punctuation():
                    energy = self._mean_offset(token.energy_levels, avg_energy)
                    token.energy = 0.0 if energy is None else energy

    def normalize(self):
        """Clamp pauses, then z-score pause, pitch and energy over all non-punctuation tokens.

        Pauses are first limited to at most 2. Both pause_before and pause_after are
        standardised with the mean/std of pause_before. A feature whose standard deviation
        is zero is only centred instead of being divided by zero.
        """
        speech = [token for token in self.get_tokens() if not token.is_punctuation()]
        if not speech:
            return

        for token in speech:
            # restrict pause length to 2 seconds at most
            token.pause_before = min(token.pause_before, 2)
            token.pause_after = min(token.pause_after, 2)

        pause_mean, pause_std = self._mean_and_scale(
            np.array([t.pause_before for t in speech], dtype=np.float32))
        pitch_mean, pitch_std = self._mean_and_scale(
            np.array([t.pitch for t in speech], dtype=np.float32))
        energy_mean, energy_std = self._mean_and_scale(
            np.array([t.energy for t in speech], dtype=np.float32))

        for token in speech:
            token.set_pause_before((token.pause_before - pause_mean) / pause_std)
            token.set_pause_after((token.pause_after - pause_mean) / pause_std)
            token.set_pitch((token.pitch - pitch_mean) / pitch_std)
            token.set_energy((token.energy - energy_mean) / energy_std)

    def __str__(self):
        sentences_str = ''.join(map(str, self.sentences))
        return sentences_str
示例#52
0
文件: decoder.py 项目: mesheven/pyOCD
class ElfSymbolDecoder(object):
    """Look up symbols of an ELF file by address or by name.

    Only STT_FUNC and STT_OBJECT symbols from the '.symtab' section are indexed.
    """

    def __init__(self, elf):
        assert isinstance(elf, ELFFile)
        self.elffile = elf
        symtab = elf.get_section_by_name('.symtab')
        self.symtab = symtab
        self.symcount = symtab.num_symbols()
        self.symbol_tree = None
        self.symbol_dict = {}

        # Populate the address and name lookup structures.
        self._build_symbol_search_tree()
        self._process_arm_type_symbols()

    def get_elf(self):
        return self.elffile

    def get_symbol_for_address(self, addr):
        """Return the SymbolInfo of the smallest-ordered interval containing addr, or None."""
        hits = self.symbol_tree[addr]
        if not hits:
            return None
        return min(hits).data

    def get_symbol_for_name(self, name):
        """Return the SymbolInfo registered under name, or None if there is none."""
        return self.symbol_dict.get(name)

    def _build_symbol_search_tree(self):
        """Fill symbol_tree (address intervals) and symbol_dict (names) with function/object symbols."""
        tree = IntervalTree()
        self.symbol_tree = tree
        for symbol in self.symtab.iter_symbols():
            entry = symbol.entry
            kind = entry['st_info']['type']
            if kind not in ('STT_FUNC', 'STT_OBJECT'):
                continue

            start = entry['st_value']
            length = entry['st_size']
            info = SymbolInfo(name=symbol.name, address=start, size=length, type=kind)
            self.symbol_dict[symbol.name] = info

            # Empty intervals cannot go into the tree, so zero-sized symbols occupy one byte
            # there; their SymbolInfo still records the real size.
            tree.addi(start, start + (length or 1), info)

    def _process_arm_type_symbols(self):
        # The iterator is built but not consumed; nothing is done with the result yet.
        self._get_arm_type_symbol_iter()

    def _get_arm_type_symbol_iter(self):
        """Return an iterator over symbols starting at the first '$m' symbol, or None if absent."""
        for index in range(1, self.symcount):
            marker = self.symtab.get_symbol(index)
            if marker.name == '$m':
                break
        else:
            return None
        # NOTE(review): the slice stop is the '$m' symbol's st_value (an address), used as a
        # symbol index here — confirm this is intended.
        stop = marker['st_value']
        return islice(self.symtab.iter_symbols(), index, stop)
示例#53
0
class virtual:
    def __init__(self):
        """Create an empty virtual space with no view attached, and configure file logging."""
        import os

        # Maps drawable ids to their drawable objects.
        self.idToDrawable = {}

        # Maps drawable ids to the X Interval currently stored for them in intervalTreeX.
        self.idToInterval= {}
        self.tags = {}

        # Holds (X interval, drawable id) entries representing the helper boxes of the
        # elements in virtual space.
        self.intervalTreeX = IntervalTree()

        self.vista = None
        self.currentLocalId = 0

        # Command name -> callable, registered through setCommandString() for file recovery.
        self.stringTofunction = {}
        self.drawableInMemory=None

        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.DEBUG)
        # getLogger(__name__) returns the same logger for every instance, so attach the file
        # handler only once; otherwise each new instance duplicates every log line and keeps
        # another file handle open.
        log_path = os.path.abspath('virtualScreen.log')
        already_attached = any(getattr(handler, 'baseFilename', None) == log_path
                               for handler in self.logger.handlers)
        if not already_attached:
            fh = logging.FileHandler('virtualScreen.log')
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)


    def setCommandString(self,command,function):
        """Register function as the handler for command in the file-recovery command table."""
        self.logger.info('Adding new command %s for file recovery ',command)
        handlers = self.stringTofunction
        handlers[command] = function


    def setView(self,vista):
        """Attach the view vista and register the recovery commands that act on it."""
        self.logger.info('Setting new view ')

        self.vista = vista
        commands = [
            ('setTag', lambda args: self.setTagLast(*args)),
            ('SETID', lambda args: self.placeDrawable(self.drawableInMemory, args[0])),
            ('setViewWidthHeight', lambda args: self.vista.vistaSetWidthHeight(*args)),
            ('placeView', lambda args: self.vista.placeView(*args)),
            ('setViewScaleXY', lambda args: self.vista.setFactorXY(*args)),
            # Creation commands only build the drawable (no id) and stash it in
            # drawableInMemory; the 'SETID' command above places that stashed drawable.
            ('createRectangle', lambda args: self.setLastDrawableInMemory(self.createRectangle(*args, createId=False))),
            ('createLine', lambda args: self.setLastDrawableInMemory(self.createLine(*args, createId=False))),
            ('createGroup', lambda args: self.setLastDrawableInMemory(self.createGroup(*args, createId=False))),
            ('createText', lambda args: self.setLastDrawableInMemory(self.createText(*args, createId=False))),
            ('createPointDraw', lambda args: self.setLastDrawableInMemory(self.createPointDraw(*args, createId=False))),
        ]
        for name, handler in commands:
            self.setCommandString(name, handler)

    def isVisible(self,drawable,intervalosView):
        """Return True when drawable's helper box overlaps the view on both the X and Y axes.

        intervalosView is a pair (X interval, Y interval) describing the view.
        """
        viewX, viewY = intervalosView[0], intervalosView[1]

        xs = tuple(corner[0] for corner in drawable.calcHelperBox())
        ys = tuple(corner[1] for corner in drawable.calcHelperBox())

        if not self.envision(xs, viewX):
            return False
        return self.envision(ys, viewY)

    def envision(self,queryInter,visInterval):
        """Return True when the 1-D interval queryInter overlaps the view interval visInterval.

        Writing 1 for the view bounds and 0 for the object bounds, there are three visible
        cases: 0---1---1----0 (object covers the view), 1-----0-------0-----1 (view covers the
        object) and 1------0------1 (partial overlap). The first is checked directly; the other
        two are caught by an object endpoint lying inside the view.
        """
        viewLow, viewHigh = min(visInterval), max(visInterval)
        if min(queryInter) <= viewLow and viewHigh <= max(queryInter):
            return True
        return viewLow <= queryInter[0] <= viewHigh or viewLow <= queryInter[1] <= viewHigh


    def winfo_height(self):
        """Return the height reported by the attached view."""
        # NOTE(review): the attribute is spelled 'heigth' — confirm the view class defines it that way.
        view = self.vista
        return view.heigth
    def winfo_width(self):
        """Return the width reported by the attached view."""
        view = self.vista
        return view.width

    def setLastDrawableInMemory(self,drawable):
        """Remember drawable as the most recently created one (placed later by 'SETID')."""
        setattr(self, 'drawableInMemory', drawable)

    # Gets every element inside (or crossing) the rectangle spanned by p0 and pf.
    def getSquare(self,p0,pf,tags=None):
        """Return drawables whose helper boxes intersect the rectangle with corners p0 and pf.

        When tags is given, drawables whose tag (per getTagdrawable) is in tags are left out.
        """
        # Intervals in X inside the rectangle (or passing through it); begin must be < end.
        xLow, xHigh = min(p0[0], pf[0]), max(p0[0], pf[0])
        hits = self.intervalTreeX.search(xLow, xHigh)

        found = []
        # Each hit is an Interval whose data (index 2) is the drawable id.
        for hit in hits:
            candidate = self.idToDrawable[hit[2]]
            # Discard drawables that do not also overlap the rectangle on the Y axis.
            ys = tuple(corner[1] for corner in candidate.calcHelperBox())
            if self.envision(ys, (p0[1], pf[1])):
                found.append(candidate)

        if tags is not None:
            # NOTE(review): this *excludes* drawables whose tag is in tags — confirm that is intended.
            return [drawable for drawable in found if self.getTagdrawable(drawable) not in tags]

        return found



    """
    ---------------Funciones de creacion ------------------------------
    """
    def createLine(self,p0,pf,createId=True):
        """Create a Line from p0 to pf; when createId is true it is also placed via placeDrawable."""
        self.logger.info('Creating line in %s %s',p0,pf)
        shape = Line(self,self.vista,p0,pf)
        if not createId:
            return shape
        self.placeDrawable(shape)
        return shape

    def createRectangle(self,p0,pf,createId=True):
        """Create a Rectangle with corners p0 and pf; placed via placeDrawable when createId is true."""
        self.logger.info('Creating rectangle in %s %s',p0,pf)
        shape = Rectangle(self,self.vista,p0,pf)
        if not createId:
            return shape
        self.placeDrawable(shape)
        return shape

    def createGroup(self,listaId=None,createId=True):
        """Create a Group holding the drawables whose ids are in listaId (if given).

        The group is placed via placeDrawable when createId is true.
        """
        self.logger.info('Creating Group from list %s',listaId)
        group = Group(self,self.vista)
        memberIds = listaId if listaId is not None else ()
        for memberId in memberIds:
            group.add(self.idToDrawable[memberId])

        if createId:
            self.placeDrawable(group)
        return group

    def createText(self,p0,texto,createId=True):
        """Create a TextDrawable showing texto at p0; placed via placeDrawable when createId is true."""
        self.logger.info('Creating Text %s in %s',texto,p0)
        label = TextDrawable(self,self.vista,p0,texto)
        if not createId:
            return label
        self.placeDrawable(label)
        return label

    def createPointDraw(self,idGroup=None,createId=True):
        """Create a pointDraw, filled from the group with id idGroup when one is given.

        The result is placed via placeDrawable when createId is true.
        """
        self.logger.info('Creating poinDraw from group %s',idGroup)
        points = pointDraw(self,self.vista)
        if idGroup is not None:
            points.addFromGroup(self.idToDrawable[idGroup])
        if createId:
            self.placeDrawable(points)

        return points

    def placeDrawable(self,drawable,id=None):
        """Give drawable an id, draw it and index its helper box's X range in intervalTreeX.

        id: explicit id to use; a fresh one is generated when it is None.
        """
        self.logger.info('Placing drawable %s',str(drawable))
        drawable.uniqueId = self.__getNewId() if id is None else id
        drawable.draw()

        # The helper box must be correct: the x of its first two corners gives the X interval.
        corners = drawable.calcHelperBox()
        xBegin, xEnd = corners[0][0], corners[1][0]
        uid = drawable.uniqueId
        self.intervalTreeX.addi(xBegin, xEnd, uid)
        self.idToInterval[uid] = Interval(xBegin, xEnd, uid)

        assert(self.idToInterval[uid] == drawable.calcInterval())
        self.idToDrawable[uid] = drawable


    def updatePosition(self,drawable):
        if self.idToDrawable.has_key(drawable.uniqueId):
            self.logger.info('Updating %s drawable %s ',drawable.uniqueId,str(drawable))
            try:
                self.intervalTreeX.remove(self.idToInterval[drawable.uniqueId])
            except Exception,e:
                print 'Error en borrar intervalo'
                self.logger.error('Cant remove interval %s exception %s',self.idToInterval[drawable.uniqueId],str(e))

            self.idToInterval.pop(drawable.uniqueId)

            helperBoxCords = drawable.calcHelperBox()
            self.intervalTreeX.addi(helperBoxCords[0][0],helperBoxCords[1][0],drawable.uniqueId)
            self.idToInterval[drawable.uniqueId] = Interval(helperBoxCords[0][0],helperBoxCords[1][0],drawable.uniqueId)
            assert(self.idToInterval[drawable.uniqueId] == drawable.calcInterval())

            self.logger.debug('New drawable interval %s %s %s ',helperBoxCords[0][0],helperBoxCords[1][0],drawable.uniqueId)

        else:
示例#54
0
class HistorySet(object):
    """A set that remembers when each member was present.

    ``current`` maps each present value to the time it was added. ``history`` is an
    IntervalTree of past memberships, holding one Interval(added, removed, value) per span.
    Methods that change membership take an optional keyword ``time`` (default: ``now()``).
    """
    __slots__ = ('current', 'history')

    def __init__(self, values=(), *, time=None):
        time = time if time is not None else now()
        self.current = {v: time for v in values}
        self.history = IntervalTree()

    @staticmethod
    def from_intervals(intervals):
        """Build a HistorySet from intervals as produced by ``intervals()``."""
        result = HistorySet()
        for iv in intervals:
            result.add_interval(iv)
        return result

    def add_interval(self, iv):
        """Record one membership interval; an interval ending at GreatestValue marks a current member."""
        if iv.end is GreatestValue:
            self.current[iv.data] = iv.begin
        else:
            if iv.data in self.current and self.current[iv.data] <= iv.end:
                del self.current[iv.data]
            self.history.add(iv)

    def refine_history(self):
        """
        Scrub the internal IntervalTree history so that there are a minimum number of intervals.

        Any multiplicity of intervals with the same data value that covers a single contiguous range will
        be replaced with a single interval over that range.

        This is an expensive operation, both in time and memory, that should only be performed when the
        history is being modified carelessly, such as naively merging with the history from another HistorySet
        or adding and removing elements out of chronological order.

        Behavior for the HistorySet should be identical before and after calling refine_history(), but may be
        slightly faster and consume less memory afterwards. The only change will be that it should no longer
        return incorrect values for the effective added date of currently contained items after merging with
        history intervals.
        """
        self.history = IntervalTree(merge_interval_overlaps(self.history, self.current))

    def __getitem__(self, index):
        """Return the members at a time point, or at any moment within a half-open slice.

        ``hs[t]`` gives the values present at time ``t``; ``hs[a:b]`` gives every value present
        at some moment in ``[a, b)``.
        """
        if isinstance(index, slice):
            if index.step is not None:
                raise ValueError("Slice indexing is used for intervals, which do not have a step.")
            # The IntervalTree interprets the slice itself as an overlap query over [start, stop).
            result = {x.data for x in self.history[index]}
            # A current member has been present continuously since it was added, so it
            # overlaps [start, stop) exactly when it was added before stop.
            result.update(item_ for item_, time_ in self.current.items()
                          if index.stop is None or time_ < index.stop)
        else:
            result = {x.data for x in self.history[index]}
            result.update(item_ for item_, time_ in self.current.items() if time_ <= index)
        return result

    def time_slice(self, begin, end):
        """
        Return an iterable over all the intervals intersecting the given half-open interval from begin to end,
        chopped to fit within it
        """
        if begin is None or end is None:
            raise ValueError("Both the beginning and end of the interval must be included")
        if end <= begin:
            raise ValueError("begin must be < end")
        for iv in self.history[begin:end]:
            yield Interval(begin=max(iv.begin, begin), end=min(iv.end, end), data=iv.data)
        for value, added in self.current.items():
            if added < end:
                # Clip the start too, so members added before `begin` stay within [begin, end).
                yield Interval(begin=max(added, begin), end=end, data=value)

    def intervals(self):
        """
        Return an iterator over all the intervals in this set. Currently contained values have intervals
        ending with a GreatestValue object.
        """
        yield from self.history
        end = GreatestValue
        for value, begin in self.current.items():
            yield Interval(begin=begin, end=end, data=value)

    def all_values(self):
        """Return a set of every value that is or ever was a member."""
        result = self.copy()
        for old in self.history:
            result.add(old.data)
        return result

    def item_added_time(self, value):
        return self.current[value]

    def ordered_by_addition(self, *, time=None):
        """Return the members (at ``time``, default now) ordered by when they were added."""
        if time is None:
            result = list(self.current.items())
        else:
            result = [(x.begin, x.data) for x in self.history[time]]
            result.extend((added, item) for item, added in self.current.items() if added <= time)
        result.sort(key=itemgetter(0))
        return [x[1] for x in result]

    def add(self, value, *, time=None):
        time = time if time is not None else now()
        if value not in self.current or self.current[value] > time:
            self.current[value] = time

    def remove(self, value, *, time=None):
        """Move value from current to history; raises KeyError if it is not a member."""
        self.history.addi(self.current.pop(value), time if time is not None else now(), value)

    def discard(self, value, *, time=None):
        if value in self.current:
            self.remove(value, time=time)

    def copy(self, *, time=None):
        """Return a plain set of the members (at ``time`` when given)."""
        if time is None:
            return set(self.current)
        else:
            return self[time]

    def members_in_interval(self, begin, end):
        return self[begin:end]

    def clear(self, *, time=None):
        time = time if time is not None else now()
        for item in self.current.items():
            self.history.addi(item[1], time, item[0])
        self.current.clear()

    def union(self, *others):
        result = self.copy()
        result.update(*others)
        return result

    def difference(self, *others):
        result = self.copy()
        result.difference_update(*others)
        return result

    def symmetric_difference(self, other):
        result = self.copy()
        result.symmetric_difference_update(other)
        return result

    def intersection(self, *others):
        result = self.copy()
        result.intersection_update(*others)
        return result

    def update(self, *others, time=None):
        time = time if time is not None else now()
        for other in others:
            for value in other:
                self.add(value, time=time)

    def difference_update(self, *others, time=None):
        time = time if time is not None else now()
        for other in others:
            for value in other:
                self.discard(value, time=time)

    def symmetric_difference_update(self, other, *, time=None):
        time = time if time is not None else now()
        for value in other:
            if value in self.current:
                self.remove(value, time=time)
            else:
                self.add(value, time=time)

    def intersection_update(self, *others, time=None):
        time = time if time is not None else now()
        toss = self.difference(*others)
        for value in toss:
            self.discard(value, time=time)

    def pop(self, *, time=None):
        time = time if time is not None else now()
        item = self.current.popitem()
        self.history.addi(item[1], time, item[0])
        return item[0]

    def isdisjoint(self, other):
        # noinspection PyUnresolvedReferences
        return self.current.keys().isdisjoint(other)

    def issubset(self, other):
        """Return True if every current member is also in ``other`` (any iterable)."""
        return set(self.current) <= set(other)

    def issuperset(self, other):
        """Return True if every element of ``other`` (any iterable) is a current member."""
        return set(self.current) >= set(other)

    def __iter__(self):
        return iter(self.current)

    def __len__(self):
        return len(self.current)

    def __eq__(self, other):
        if isinstance(other, (set, frozenset)):
            return self.current.keys() == other
        elif isinstance(other, HistorySet):
            return self.current.keys() == other.current.keys()
        return False

    def __lt__(self, other):
        # Proper subset, as for built-in sets.
        return set(self.current) < set(other)

    def __gt__(self, other):
        # Proper superset, as for built-in sets.
        return set(self.current) > set(other)

    def __contains__(self, item):
        return item in self.current

    # In-place operators must return the object itself; otherwise `hs |= x` rebinds hs to None.
    def __ior__(self, other):
        self.update(other)
        return self

    def __iand__(self, other):
        self.intersection_update(other)
        return self

    def __isub__(self, other):
        self.difference_update(other)
        return self

    def __ixor__(self, other):
        self.symmetric_difference_update(other)
        return self

    __le__ = issubset
    __ge__ = issuperset
    __or__ = union
    __and__ = intersection
    __sub__ = difference
    __xor__ = symmetric_difference
示例#55
0
def test_debug_sequence():
    t = IntervalTree()
    t.verify()
    t.addi(17.89,21.89)
    t.verify()
    t.addi(11.53,16.53)
    t.verify()
    t.removei(11.53,16.53)
    t.verify()
    t.removei(17.89,21.89)
    t.verify()
    t.addi(-0.62,4.38)
    t.verify()
    t.addi(9.24,14.24)
    # t.print_structure()
    # Node<-0.62, depth=2, balance=1>
    #  Interval(-0.62, 4.38)
    # >:  Node<9.24, depth=1, balance=0>
    #      Interval(9.24, 14.24)
    t.verify()

    t.addi(4.0,9.0)  # This line breaks the invariants, leaving an empty node
    # t.print_structure()
    t.verify()
    t.removei(-0.62,4.38)
    t.verify()
    t.removei(9.24,14.24)
    t.verify()
    t.removei(4.0,9.0)
    t.verify()
    t.addi(12.86,17.86)
    t.verify()
    t.addi(16.65,21.65)
    t.verify()
    t.removei(12.86,17.86)