def test_merge_equals_reducer_wo_initializer():
    """Check merge_equals() with a data reducer and no data initializer.

    Intervals with a unique range must keep their data untouched; intervals
    sharing a range must be collapsed into one whose data is the reducer's
    result.
    """
    def reducer(old, new):
        return "%s, %s" % (old, new)

    # Case 1: nothing to merge in an empty tree.
    empty_tree = IntervalTree()
    empty_tree.merge_equals(data_reducer=reducer)
    empty_tree.verify()
    assert not empty_tree

    # Case 2: a single interval is left exactly as it was.
    single = IntervalTree.from_tuples([(1, 2, 'hello')])
    single.merge_equals(data_reducer=reducer)
    single.verify()
    assert len(single) == 1
    assert sorted(single) == [Interval(1, 2, 'hello')]

    # Case 3: the 'ivs1' fixture is expected to come out unchanged.
    merged = trees['ivs1']()
    reference = trees['ivs1']()
    merged.merge_equals(data_reducer=reducer)
    merged.verify()
    assert len(merged) == len(reference)
    assert merged == reference

    # Case 4: an extra interval on the range [4, 7) is folded into the
    # fixture's interval on that range, so the size matches the reference
    # while the content differs.
    merged = trees['ivs1']()
    reference = trees['ivs1']()
    merged.addi(4, 7, 'foo')
    merged.merge_equals(data_reducer=reducer)
    merged.verify()
    assert len(merged) == len(reference)
    assert merged != reference
    assert not merged.containsi(4, 7, 'foo')
    assert not merged.containsi(4, 7, '[4,7)')
    assert merged.containsi(4, 7, '[4,7), foo')
def test_merge_equals_reducer_wo_initializer():
    """Check merge_equals() with a data reducer and no data initializer.

    Intervals with a unique range must keep their data untouched; intervals
    sharing a range must be collapsed into one whose data is the reducer's
    result.
    """
    def reducer(old, new):
        return "%s, %s" % (old, new)

    # Case 1: nothing to merge in an empty tree.
    empty_tree = IntervalTree()
    empty_tree.merge_equals(data_reducer=reducer)
    empty_tree.verify()
    assert not empty_tree

    # Case 2: a single interval is left exactly as it was.
    single = IntervalTree.from_tuples([(1, 2, 'hello')])
    single.merge_equals(data_reducer=reducer)
    single.verify()
    assert len(single) == 1
    assert sorted(single) == [Interval(1, 2, 'hello')]

    # Case 3: the ivs1 data set is expected to come out unchanged.
    merged = IntervalTree.from_tuples(data.ivs1.data)
    reference = IntervalTree.from_tuples(data.ivs1.data)
    merged.merge_equals(data_reducer=reducer)
    merged.verify()
    assert len(merged) == len(reference)
    assert merged == reference

    # Case 4: an extra interval on the range [4, 7) is folded into the data
    # set's interval on that range, so the size matches the reference while
    # the content differs.
    merged = IntervalTree.from_tuples(data.ivs1.data)
    reference = IntervalTree.from_tuples(data.ivs1.data)
    merged.addi(4, 7, 'foo')
    merged.merge_equals(data_reducer=reducer)
    merged.verify()
    assert len(merged) == len(reference)
    assert merged != reference
    assert not merged.containsi(4, 7, 'foo')
    assert not merged.containsi(4, 7, '[4,7)')
    assert merged.containsi(4, 7, '[4,7), foo')
Example #3
0
    def merge_weights(self, run1_for_query, run_2_for_query):
        """Merge the per-field weight segments of two runs, document by document.

        Args:
            run1_for_query: iterable of documents; each is indexed with
                ``"doc_id"`` and ``"weights"``.
            run_2_for_query: same shape as ``run1_for_query``; may be empty.

        Each ``doc["weights"]`` is used as a mapping from field name to a
        sequence of segments, where ``segment[0]`` / ``segment[1]`` become the
        interval bounds and ``segment[2]`` is stored as that run's value
        (presumably the weight -- confirm against the producer).

        Returns:
            A ``defaultdict`` mapping doc_id -> {field: list of
            ``(begin, end, data)`` tuples sorted by ``(begin, end)``}, where
            ``data`` is a dict keyed by ``"run1"`` / ``"run2"``.
        """
        # doc_id -> {"run1": weights, "run2": weights}; a missing run defaults
        # to an empty mapping so lookups below never raise KeyError.
        doc_id2weights = defaultdict(lambda: {
            "run1": defaultdict(lambda: []),
            "run2": defaultdict(lambda: [])
        })

        # If the second run is empty, don't bother to merge
        if not run_2_for_query:
            merged_weights = defaultdict(lambda: {})

            # NOTE(review): doc_id2weights has not been filled yet at this
            # point, so both lookups below hit freshly created empty
            # defaults, `fields` is always empty, the inner loop never runs
            # and an empty mapping is returned regardless of run 1. The
            # assignment would also store the whole `doc["weights"]` mapping
            # under every field. Confirm the intended result for this branch.
            for doc in run1_for_query:
                doc_id = doc["doc_id"]
                fields = set(doc_id2weights[doc_id]["run2"].keys()).union(
                    doc_id2weights[doc_id]["run1"])
                fields = set(fields).difference(set(["doc_id"]))

                for field in fields:
                    merged_weights[doc_id][field] = doc["weights"]

            return merged_weights

        # Index both runs by document id.
        for doc in run1_for_query:
            doc_id2weights[doc["doc_id"]]["run1"] = doc["weights"]
        for doc in run_2_for_query:
            doc_id2weights[doc["doc_id"]]["run2"] = doc["weights"]

        merged_weights = defaultdict(lambda: {})
        for doc_id in doc_id2weights:
            # Every field that occurs in either run, except a "doc_id" key.
            fields = set(doc_id2weights[doc_id]["run2"].keys()).union(
                doc_id2weights[doc_id]["run1"])
            fields = set(fields).difference(set(["doc_id"]))
            for field in fields:
                # One tree per (document, field); each segment is tagged with
                # the run it came from.
                t = IntervalTree()
                for segment in doc_id2weights[doc_id]["run1"].get(field, []):
                    t.add(
                        Interval(segment[0], segment[1], {"run1": segment[2]}))
                for segment in doc_id2weights[doc_id]["run2"].get(field, []):
                    t.add(
                        Interval(segment[0], segment[1], {"run2": segment[2]}))
                # Cut overlapping segments at each other's boundaries, then
                # fold intervals with identical bounds into one dict.
                t.split_overlaps()
                # The reducer updates the accumulated dict in place and
                # returns it (dict.update returns None, hence the `or`). The
                # second positional argument is the starting value, so a run
                # that contributes nothing to an interval keeps None.
                t.merge_equals(
                    lambda old_dict, new_dict: old_dict.update(new_dict) or
                    old_dict, {
                        "run1": None,
                        "run2": None
                    })
                merged_intervals = sorted([(i.begin, i.end, i.data)
                                           for i in t],
                                          key=lambda x: (x[0], x[1]))
                merged_weights[doc_id][field] = merged_intervals

        return merged_weights
Example #4
0
def test_merge_equals_reducer_with_initializer():
    """Check merge_equals() with a list-building reducer and ``[]`` initializer.

    With an initializer every surviving interval, merged or not, is expected
    to carry the reduced value: here a list of the original data items.
    """
    def reducer(old, new):
        return old + [new]

    # Case 1: nothing to merge in an empty tree.
    empty_tree = IntervalTree()
    empty_tree.merge_equals(data_reducer=reducer, data_initializer=[])
    empty_tree.verify()
    assert not empty_tree

    # Case 2: a single interval survives, its data wrapped in a list.
    single = IntervalTree.from_tuples([(1, 2, 'hello')])
    single.merge_equals(data_reducer=reducer, data_initializer=[])
    single.verify()
    assert len(single) == 1
    assert sorted(single) == [Interval(1, 2, ['hello'])]

    # Case 3: the ivs1 data set keeps its size, but each interval's data is
    # now a one-element list, so the trees differ.
    merged = IntervalTree.from_tuples(data.ivs1.data)
    reference = IntervalTree.from_tuples(data.ivs1.data)
    merged.merge_equals(data_reducer=reducer, data_initializer=[])
    merged.verify()
    assert len(merged) == len(reference)
    assert merged != reference
    assert sorted(merged) == [
        Interval(begin, end, [payload])
        for begin, end, payload in sorted(reference)
    ]

    # Case 4: an extra interval on the range [4, 7) is merged with the data
    # set's interval on that range into a two-element list.
    merged = IntervalTree.from_tuples(data.ivs1.data)
    reference = IntervalTree.from_tuples(data.ivs1.data)
    merged.addi(4, 7, 'foo')
    merged.merge_equals(data_reducer=reducer, data_initializer=[])
    merged.verify()
    assert len(merged) == len(reference)
    assert merged != reference
    assert not merged.containsi(4, 7, 'foo')
    assert not merged.containsi(4, 7, '[4,7)')
    assert merged.containsi(4, 7, ['[4,7)', 'foo'])
def test_merge_equals_reducer_with_initializer():
    """Check merge_equals() with a list-building reducer and ``[]`` initializer.

    With an initializer every surviving interval, merged or not, is expected
    to carry the reduced value: here a list of the original data items.
    """
    def reducer(old, new):
        return old + [new]

    # Case 1: nothing to merge in an empty tree.
    empty_tree = IntervalTree()
    empty_tree.merge_equals(data_reducer=reducer, data_initializer=[])
    empty_tree.verify()
    assert not empty_tree

    # Case 2: a single interval survives, its data wrapped in a list.
    single = IntervalTree.from_tuples([(1, 2, 'hello')])
    single.merge_equals(data_reducer=reducer, data_initializer=[])
    single.verify()
    assert len(single) == 1
    assert sorted(single) == [Interval(1, 2, ['hello'])]

    # Case 3: the ivs1 data set keeps its size, but each interval's data is
    # now a one-element list, so the trees differ.
    merged = IntervalTree.from_tuples(data.ivs1.data)
    reference = IntervalTree.from_tuples(data.ivs1.data)
    merged.merge_equals(data_reducer=reducer, data_initializer=[])
    merged.verify()
    assert len(merged) == len(reference)
    assert merged != reference
    assert sorted(merged) == [
        Interval(begin, end, [payload])
        for begin, end, payload in sorted(reference)
    ]

    # Case 4: an extra interval on the range [4, 7) is merged with the data
    # set's interval on that range into a two-element list.
    merged = IntervalTree.from_tuples(data.ivs1.data)
    reference = IntervalTree.from_tuples(data.ivs1.data)
    merged.addi(4, 7, 'foo')
    merged.merge_equals(data_reducer=reducer, data_initializer=[])
    merged.verify()
    assert len(merged) == len(reference)
    assert merged != reference
    assert not merged.containsi(4, 7, 'foo')
    assert not merged.containsi(4, 7, '[4,7)')
    assert merged.containsi(4, 7, ['[4,7)', 'foo'])
Example #6
0
    def intersect_cn_trees(self):
        """
        Gets copy number events from segment trees and adds them to samples

        For every chromosome (1-22, X, Y) the CN profiles of all samples are
        put into one interval tree, split at each other's boundaries and
        merged so that each interval carries a list with one entry per
        contributing sample. Only intervals covered by all samples are kept.
        For each allele of each kept interval an event segment is recorded
        when the copy numbers are not all 1 and all lie on the same side of 1.
        Segments with matching copy numbers and mutually compatible CCF
        ranges are marked as neighbors, merged via ``merge_cn_events`` and
        passed to ``self._add_cn_event_to_samples``.

        Side effects:
            Sets ``self.concordant_cn_tree`` to a dict mapping chromosome
            name to the tree of intervals covered by all samples.
        """
        def get_bands(chrom,
                      start,
                      end,
                      cytoband=os.path.dirname(__file__) +
                      '/supplement_data/cytoBand.txt'):
            """
            Gets cytobands hit by a CN event

            Args:
                chrom: chromosome name without the 'chr' prefix
                start: first position of the event
                end: last position of the event
                cytoband: path to a tab-separated file whose columns are read
                    as chromosome, band start, band end, band name

            Returns:
                List of Cytoband objects overlapping [start, end] on chrom
                (empty if none were found)
            """
            bands = []
            on_c = False
            with open(cytoband, 'r') as f:
                for line in f:
                    row = line.strip('\n').split('\t')
                    if row[0].strip('chr') != str(chrom):
                        # A different chromosome after we already collected
                        # bands means we are past the region of interest.
                        if on_c:
                            return bands
                        continue
                    if int(row[1]) <= end and int(row[2]) >= start:
                        bands.append(Cytoband(chrom, row[3]))
                        on_c = True
                    if int(row[1]) > end:
                        return bands
            # The file ended before either early exit fired (e.g. the event
            # reaches the last band listed in the file). Return what was
            # collected instead of falling through to an implicit None, which
            # would make the caller's tuple(bands) raise TypeError.
            return bands

        def merge_cn_events(event_segs,
                            neighbors,
                            R=frozenset(),
                            X=frozenset()):
            """
            Merges copy number events on a single chromosome if they are adjacent and their ccf values are similar

            NOTE(review): the recursion over (R, event_segs, X) with a pivot
            looks like maximal-clique enumeration restricted to adjacent
            segments -- confirm.

            Args:
                event_segs: set of CN segments represented as tuple(bands, CNs, CCF_hats, CCF_highs, CCF_lows, allele)
                neighbors: dict mapping seg to set of neighbors (segs with similar CCFs)
                R: only populated in recursive calls
                X: only populated in recursive calls

            Returns:
                Generator for merged segs, each a tuple
                (bands, CNs, mean CCF_hat, mean CCF_high, mean CCF_low)
            """
            # R is maximal when no remaining or excluded segment is adjacent
            # to it.
            is_max = True
            for s in itertools.chain(event_segs, X):
                if isadjacent(s, R):
                    is_max = False
                    break
            if is_max:
                bands = set.union(*(set(b[0]) for b in R))
                # Copy numbers are taken from an arbitrary member of R;
                # neighbors are only linked when their copy numbers match.
                cns = next(iter(R))[1]
                ccf_hat = np.zeros(len(self.sample_list))
                ccf_high = np.zeros(len(self.sample_list))
                ccf_low = np.zeros(len(self.sample_list))
                for seg in R:
                    ccf_hat += np.array(seg[2])
                    ccf_high += np.array(seg[3])
                    ccf_low += np.array(seg[4])
                # CCF values of the merged event are the per-sample means.
                yield (bands, cns, ccf_hat / len(R), ccf_high / len(R),
                       ccf_low / len(R))
            else:
                # Pick the pivot that leaves the fewest candidates to expand.
                pivot = (event_segs - neighbors[p] for p in event_segs | X)
                for s in min(pivot, key=len):
                    if isadjacent(s, R):
                        for region in merge_cn_events(event_segs
                                                      & neighbors[s],
                                                      neighbors,
                                                      R=R | {s},
                                                      X=X & neighbors[s]):
                            yield region
                        # s has been fully explored: move it to the
                        # excluded set.
                        event_segs = event_segs - {s}
                        X = X | {s}

        def isadjacent(s, R):
            """
            Copy number events are adjacent if the max band of one is the same as
            or adjacent to the min band of the other

            An empty R is adjacent to everything. The two bands must also
            agree on the first character of their ``band`` attribute
            (presumably the chromosome arm -- confirm against Cytoband).
            """
            if not R:
                return True
            Rchain = list(itertools.chain(*(b[0] for b in R)))
            minR = min(Rchain)
            maxR = max(Rchain)
            mins = min(s[0])
            maxs = max(s[0])
            if mins >= maxR:
                return mins - maxR <= 1 and mins.band[0] == maxR.band[0]
            elif maxs <= minR:
                return minR - maxs <= 1 and maxs.band[0] == minR.band[0]
            else:
                return False

        c_trees = {}
        n_samples = len(self.sample_list)
        for chrom in list(map(str, range(1, 23))) + ['X', 'Y']:
            tree = IntervalTree()
            for sample in self.sample_list:
                if sample.CnProfile:
                    tree.update(sample.CnProfile[chrom])
            # Cut all segments at each other's boundaries, then collect the
            # data of identical intervals into one list per interval.
            tree.split_overlaps()
            tree.merge_equals(data_initializer=[],
                              data_reducer=lambda a, c: a + [c])
            # Keep only intervals that every sample contributed to.
            c_tree = IntervalTree(
                filter(lambda s: len(s.data) == n_samples, tree))
            c_trees[chrom] = c_tree
            event_segs = set()
            for seg in c_tree:
                start = seg.begin
                end = seg.end
                bands = get_bands(chrom, start, end)
                cns_a1 = []
                cns_a2 = []
                ccf_hat_a1 = []
                ccf_hat_a2 = []
                ccf_high_a1 = []
                ccf_high_a2 = []
                ccf_low_a1 = []
                ccf_low_a2 = []
                for i, sample in enumerate(self.sample_list):
                    # assumes seg.data[i] belongs to self.sample_list[i] and
                    # that its second element is the dict of CN/CCF values
                    # -- TODO confirm the merged order matches sample_list
                    seg_data = seg.data[i][1]
                    cns_a1.append(seg_data['cn_a1'])
                    cns_a2.append(seg_data['cn_a2'])
                    # CCF values are zeroed wherever the copy number is 1.
                    ccf_hat_a1.append(seg_data['ccf_hat_a1']
                                      if seg_data['cn_a1'] != 1 else 0.)
                    ccf_hat_a2.append(seg_data['ccf_hat_a2']
                                      if seg_data['cn_a2'] != 1 else 0.)
                    ccf_high_a1.append(seg_data['ccf_high_a1']
                                       if seg_data['cn_a1'] != 1 else 0.)
                    ccf_high_a2.append(seg_data['ccf_high_a2']
                                       if seg_data['cn_a2'] != 1 else 0.)
                    ccf_low_a1.append(seg_data['ccf_low_a1']
                                      if seg_data['cn_a1'] != 1 else 0.)
                    ccf_low_a2.append(seg_data['ccf_low_a2']
                                      if seg_data['cn_a2'] != 1 else 0.)
                cns_a1 = np.array(cns_a1)
                cns_a2 = np.array(cns_a2)
                # Allele 1: skip if all copy numbers are 1; record an event
                # if they all lie on one side of 1; warn when they are mixed.
                if np.all(cns_a1 == 1):
                    pass
                elif np.all(cns_a1 >= 1) or np.all(cns_a1 <= 1):
                    event_segs.add(
                        (tuple(bands), tuple(cns_a1), tuple(ccf_hat_a1),
                         tuple(ccf_high_a1), tuple(ccf_low_a1), 'a1'))
                else:
                    logging.warning(
                        'Seg with inconsistent event: {}:{}:{}'.format(
                            chrom, seg.begin, seg.end))
                # Allele 2: same rules.
                if np.all(cns_a2 == 1):
                    pass
                elif np.all(cns_a2 >= 1) or np.all(cns_a2 <= 1):
                    event_segs.add(
                        (tuple(bands), tuple(cns_a2), tuple(ccf_hat_a2),
                         tuple(ccf_high_a2), tuple(ccf_low_a2), 'a2'))
                else:
                    logging.warning(
                        'Seg with inconsistent event: {}:{}:{}'.format(
                            chrom, seg.begin, seg.end))
            # Two segments are neighbors when their copy numbers are equal
            # and each one's CCF_hat lies within the other's [low, high]
            # range in every sample.
            neighbors = {s: set() for s in event_segs}
            for seg1, seg2 in itertools.combinations(event_segs, 2):
                s1_hat = np.array(seg1[2])
                s2_hat = np.array(seg2[2])
                if seg1[1] == seg2[1] and np.all(s1_hat >= np.array(seg2[4])) and np.all(s1_hat <= np.array(seg2[3]))\
                and np.all(s2_hat >= np.array(seg1[4])) and np.all(s2_hat <= np.array(seg1[3])):
                    neighbors[seg1].add(seg2)
                    neighbors[seg2].add(seg1)

            # (category, bands) pairs already reported on this chromosome;
            # repeats are flagged as duplicates.
            event_cache = []
            if event_segs:
                for bands, cns, ccf_hat, ccf_high, ccf_low in merge_cn_events(
                        event_segs, neighbors):
                    # A copy-number sum above the number of samples is
                    # classified as a gain, anything else as a loss.
                    mut_category = 'gain' if sum(cns) > len(
                        self.sample_list) else 'loss'
                    a1 = (mut_category, bands) not in event_cache
                    if a1:
                        event_cache.append((mut_category, bands))
                    self._add_cn_event_to_samples(chrom,
                                                  min(bands),
                                                  max(bands),
                                                  cns,
                                                  mut_category,
                                                  ccf_hat,
                                                  ccf_high,
                                                  ccf_low,
                                                  a1,
                                                  dupe=not a1)
        self.concordant_cn_tree = c_trees
Example #7
0
def gpx_merge_intersect(input_gpx):
    """
    Merge the overlapping parts of the tracks in a GPX file.

    Every segment of every track is placed on a time axis. Stretches of time
    covered by the same set of segments are merged into one segment, and
    each such stretch becomes its own track in the result. The returned
    object is a clone of ``input_gpx`` whose tracks are replaced by the
    merged ones; progress and a final point-coverage check are printed.
    """
    name_pattern = re.compile("OriginalName:(.*)")

    def concatenate(a, b):
        return a + b

    tree = IntervalTree()
    all_points = set()

    # Register each segment under its time span, tagged with its track name.
    for track in input_gpx.tracks:
        for segment in track.segments:
            bounds = segment.get_time_bounds()
            start = bounds.start_time.timestamp()
            end = bounds.end_time.timestamp()
            if start == end:
                # widen a zero-length span by one second
                end += 1
            print(f"Found track {track.name}")
            segment.extensions.append(f"OriginalName:{track.name}")
            tree[start:end] = [segment]
            all_points.update(p.time for p in segment.points)

    # Cut the spans at each other's boundaries, then pool the segments of
    # identical spans into one list per span.
    n_original = len(tree)
    tree.split_overlaps()
    for piece in tree:
        print(piece)
    n_split = len(tree)
    tree.merge_equals(data_reducer=concatenate)
    n_merged = len(tree)
    print(f"Split {n_original} intervals into {n_split} merged into {n_merged}")

    merged_points = set()

    output_gpx = input_gpx.clone()
    output_gpx.tracks = []

    for piece in sorted(tree):
        print(piece)
        span_start = datetime.datetime.fromtimestamp(piece.begin)
        span_end = datetime.datetime.fromtimestamp(piece.end)
        merged = merge_track_segments_within_interval(
            piece.data, span_start, span_end)
        merged_points.update(p.time for p in merged.points)

        source_names = []
        label_counts = Counter()
        unlabeled_names = []  # names that didn't match the label patterns
        kept_extensions = []
        for extension in merged.extensions:
            found = name_pattern.match(extension)
            if found is None:
                # not one of our name tags: stays on the merged segment
                kept_extensions.append(extension)
                continue
            original_name = found.group(1)
            source_names.append(original_name)
            name_labels = list(matching_labels(original_name))
            if name_labels:
                label_counts.update(name_labels)
            else:
                print(f"custom name: {original_name}")
                unlabeled_names.append(original_name)
        merged.extensions = kept_extensions
        print(f"remaining extensions: {merged.extensions}")

        track = gpxpy.gpx.GPXTrack()
        track.segments.append(merged)

        label_parts = [f"{k}: {v}" for k, v in label_counts.most_common()]
        track.name = " ".join(unlabeled_names + label_parts)
        track.description = "\n".join(sorted(source_names))
        output_gpx.tracks.append(track)

    print(f"final tracks: {len(output_gpx.tracks)}")
    # every input point time should show up in one of the merged segments
    print(f"all: {len(all_points)}, merged: {len(merged_points)}")
    print(f"all points appear in merge: {all_points == merged_points}")

    return output_gpx
def test_merge_equals_empty():
    """merge_equals() on an empty tree must leave it valid and empty."""
    tree = IntervalTree()

    tree.merge_equals()
    tree.verify()

    assert len(tree) == 0
Example #9
0
    def intersect_cn_trees(self):
        """
        Gets copy number events from segment trees and adds them to samples

        For every chromosome (1-22, X, Y) the CN profiles of all samples are
        put into one interval tree, split at each other's boundaries and
        merged so that each interval carries a list with one entry per
        contributing sample. Only intervals covered by all samples are kept.
        For each allele of each kept interval an event segment is recorded
        when the copy numbers are not all 1 and all lie on the same side of 1.
        Segments with matching copy numbers and mutually compatible CCF
        ranges are marked as neighbors, merged via ``merge_cn_events`` and
        passed to ``self._add_cn_event_to_samples``.

        Side effects:
            Sets ``self.concordant_cn_tree`` to a dict mapping chromosome
            name to the tree of intervals covered by all samples.
        """
        def get_bands(chrom,
                      start,
                      end,
                      cytoband=os.path.dirname(__file__) +
                      '/supplement_data/cytoBand.txt'):
            """
            Gets cytobands hit by a CN event

            The cytoband file is read as tab-separated rows of chromosome,
            band start, band end, band name. Returns the list of Cytoband
            objects overlapping [start, end] on chrom (empty if none).
            """
            bands = []
            on_c = False
            with open(cytoband, 'r') as f:
                for line in f:
                    row = line.strip('\n').split('\t')
                    if row[0].strip('chr') != str(chrom):
                        # A different chromosome after we already collected
                        # bands means we are past the region of interest.
                        if on_c:
                            return bands
                        continue
                    if int(row[1]) <= end and int(row[2]) >= start:
                        bands.append(Cytoband(chrom, row[3]))
                        on_c = True
                    if int(row[1]) > end:
                        return bands
            # The file ended before either early exit fired (e.g. the event
            # reaches the last band listed in the file). Return what was
            # collected instead of falling through to an implicit None, which
            # would make the caller's tuple(bands) raise TypeError.
            return bands

        def merge_cn_events(event_segs,
                            neighbors,
                            R=frozenset(),
                            X=frozenset()):
            """
            Merges neighboring, adjacent CN segments of one chromosome

            NOTE(review): the recursion over (R, event_segs, X) with a pivot
            looks like maximal-clique enumeration restricted to adjacent
            segments -- confirm.

            Args:
                event_segs: set of CN segments represented as tuple(bands, CNs, CCF_hats, CCF_highs, CCF_lows, allele)
                neighbors: dict mapping seg to the set of its neighbor segs
                R: only populated in recursive calls
                X: only populated in recursive calls

            Returns:
                Generator of tuples
                (bands, CNs, mean CCF_hat, mean CCF_high, mean CCF_low)
            """
            # R is maximal when no remaining or excluded segment is adjacent
            # to it.
            is_max = True
            for s in itertools.chain(event_segs, X):
                if isadjacent(s, R):
                    is_max = False
                    break
            if is_max:
                bands = set.union(*(set(b[0]) for b in R))
                # Copy numbers are taken from an arbitrary member of R;
                # neighbors are only linked when their copy numbers match.
                cns = next(iter(R))[1]
                ccf_hat = np.zeros(len(self.sample_list))
                ccf_high = np.zeros(len(self.sample_list))
                ccf_low = np.zeros(len(self.sample_list))
                for seg in R:
                    ccf_hat += np.array(seg[2])
                    ccf_high += np.array(seg[3])
                    ccf_low += np.array(seg[4])
                # CCF values of the merged event are the per-sample means.
                yield (bands, cns, ccf_hat / len(R), ccf_high / len(R),
                       ccf_low / len(R))
            else:
                # Iterate over the smallest candidate set left by any pivot.
                for s in min(
                    (event_segs - neighbors[p] for p in event_segs.union(X)),
                        key=len):
                    if isadjacent(s, R):
                        for region in merge_cn_events(
                                event_segs.intersection(neighbors[s]),
                                neighbors,
                                R=R.union({s}),
                                X=X.intersection(neighbors[s])):
                            yield region
                        # s has been fully explored: move it to the
                        # excluded set.
                        event_segs = event_segs.difference({s})
                        X = X.union({s})

        def isadjacent(s, R):
            """
            Tells whether segment s touches the band range covered by R

            An empty R is adjacent to everything. Otherwise s must lie
            entirely on one side of R, at most one band step away, and the
            two facing bands must agree on the first character of their
            ``band`` attribute (presumably the chromosome arm -- confirm
            against Cytoband).
            """
            if not R:
                return True
            Rchain = list(itertools.chain(*(b[0] for b in R)))
            minR = min(Rchain)
            maxR = max(Rchain)
            mins = min(s[0])
            maxs = max(s[0])
            if mins >= maxR:
                return mins - maxR <= 1 and mins.band[0] == maxR.band[0]
            elif maxs <= minR:
                return minR - maxs <= 1 and maxs.band[0] == minR.band[0]
            else:
                return False

        c_trees = {}
        n_samples = len(self.sample_list)
        for chrom in list(map(str, range(1, 23))) + ['X', 'Y']:
            tree = IntervalTree()
            for sample in self.sample_list:
                if sample.CnProfile:
                    tree.update(sample.CnProfile[chrom])
            # Cut all segments at each other's boundaries, then collect the
            # data of identical intervals into one list per interval.
            tree.split_overlaps()
            tree.merge_equals(data_initializer=[],
                              data_reducer=lambda a, c: a + [c])
            # Keep only intervals that every sample contributed to.
            c_tree = IntervalTree(
                filter(lambda s: len(s.data) == n_samples, tree))
            c_trees[chrom] = c_tree
            event_segs = set()
            for seg in c_tree:
                start = seg.begin
                end = seg.end
                bands = get_bands(chrom, start, end)
                cns_a1 = []
                cns_a2 = []
                ccf_hat_a1 = []
                ccf_hat_a2 = []
                ccf_high_a1 = []
                ccf_high_a2 = []
                ccf_low_a1 = []
                ccf_low_a2 = []
                # assumes seg.data[i] belongs to self.sample_list[i] and that
                # its second element is the dict of CN/CCF values -- TODO
                # confirm the merged order matches sample_list.
                # CCF values are zeroed wherever the copy number is 1.
                for i, sample in enumerate(self.sample_list):
                    cns_a1.append(seg.data[i][1]['cn_a1'])
                    cns_a2.append(seg.data[i][1]['cn_a2'])
                    ccf_hat_a1.append(seg.data[i][1]['ccf_hat_a1']
                                      if seg.data[i][1]['cn_a1'] != 1 else 0.)
                    ccf_hat_a2.append(seg.data[i][1]['ccf_hat_a2']
                                      if seg.data[i][1]['cn_a2'] != 1 else 0.)
                    ccf_high_a1.append(seg.data[i][1]['ccf_high_a1']
                                       if seg.data[i][1]['cn_a1'] != 1 else 0.)
                    ccf_high_a2.append(seg.data[i][1]['ccf_high_a2']
                                       if seg.data[i][1]['cn_a2'] != 1 else 0.)
                    ccf_low_a1.append(seg.data[i][1]['ccf_low_a1']
                                      if seg.data[i][1]['cn_a1'] != 1 else 0.)
                    ccf_low_a2.append(seg.data[i][1]['ccf_low_a2']
                                      if seg.data[i][1]['cn_a2'] != 1 else 0.)
                cns_a1 = np.array(cns_a1)
                cns_a2 = np.array(cns_a2)
                # Allele 1: skip if all copy numbers are 1; record an event
                # if they all lie on one side of 1; warn when they are mixed.
                if np.all(cns_a1 == 1):
                    pass
                elif np.all(cns_a1 >= 1) or np.all(cns_a1 <= 1):
                    event_segs.add(
                        (tuple(bands), tuple(cns_a1), tuple(ccf_hat_a1),
                         tuple(ccf_high_a1), tuple(ccf_low_a1), 'a1'))
                else:
                    logging.warning(
                        'Seg with inconsistent event: {}:{}:{}'.format(
                            chrom, seg.begin, seg.end))
                # Allele 2: same rules.
                if np.all(cns_a2 == 1):
                    pass
                elif np.all(cns_a2 >= 1) or np.all(cns_a2 <= 1):
                    event_segs.add(
                        (tuple(bands), tuple(cns_a2), tuple(ccf_hat_a2),
                         tuple(ccf_high_a2), tuple(ccf_low_a2), 'a2'))
                else:
                    logging.warning(
                        'Seg with inconsistent event: {}:{}:{}'.format(
                            chrom, seg.begin, seg.end))
            # Two segments are neighbors when their copy numbers are equal
            # and each one's CCF_hat lies within the other's [low, high]
            # range in every sample.
            neighbors = {s: set() for s in event_segs}
            for seg1, seg2 in itertools.combinations(event_segs, 2):
                s1_hat = np.array(seg1[2])
                s2_hat = np.array(seg2[2])
                if seg1[1] == seg2[1] and np.all(s1_hat >= np.array(seg2[4])) and np.all(s1_hat <= np.array(seg2[3]))\
                and np.all(s2_hat >= np.array(seg1[4])) and np.all(s2_hat <= np.array(seg1[3])):
                    neighbors[seg1].add(seg2)
                    neighbors[seg2].add(seg1)

            # (category, bands) pairs already reported on this chromosome;
            # repeats are flagged as duplicates.
            event_cache = []
            if event_segs:
                for bands, cns, ccf_hat, ccf_high, ccf_low in merge_cn_events(
                        event_segs, neighbors):
                    # A copy-number sum above the number of samples is
                    # classified as a gain, anything else as a loss.
                    mut_category = 'gain' if sum(cns) > len(
                        self.sample_list) else 'loss'
                    a1 = (mut_category, bands) not in event_cache
                    if a1:
                        event_cache.append((mut_category, bands))
                    self._add_cn_event_to_samples(chrom,
                                                  min(bands),
                                                  max(bands),
                                                  cns,
                                                  mut_category,
                                                  ccf_hat,
                                                  ccf_high,
                                                  ccf_low,
                                                  a1,
                                                  dupe=not a1)
        self.concordant_cn_tree = c_trees
Example #10
0
class TemporalPathPyObject(PathPyObject):
    """Base class for objects whose attributes are stored as timed events.

    Events live in an interval tree; the data of each interval is the dict
    of attributes that holds during that interval.
    """

    def __init__(self, uid: Optional[str] = None, **kwargs: Any) -> None:
        """Create the object and record ``kwargs`` as its first event."""
        super().__init__(uid=uid)

        # Defaults only: event() below replaces both with the tree's bounds.
        self._start = float('-inf')
        self._end = float('inf')

        # storage for all events of this object
        self._events = IntervalTree()
        self.event(**kwargs)

        # Tree size at the last clean-up; a differing size later on tells
        # _clean_events that there is work to do.
        self._len_events = len(self._events)

    def __iter__(self):
        """Yield ``self`` once per event, in sorted interval order.

        Before each yield the object's attributes are replaced by the event's
        data plus its ``start`` and ``end``; both keys are removed again once
        the iteration is exhausted.
        """
        self._clean_events()

        for begin, stop, data in sorted(self._events):
            self._attributes = {'start': begin, 'end': stop, **data}
            yield self
        for bound in ('start', 'end'):
            self._attributes.pop(bound, None)

    @singledispatchmethod
    def __getitem__(self, key: Any) -> Any:
        """Return attribute ``key`` of the interval that sorts last."""
        self._clean_events()
        latest = self.last()[2]
        return latest.get(key)

    @__getitem__.register(tuple)  # type: ignore
    def _(self, key: tuple) -> Any:
        """Look up attributes within the time range given by ``key[0]``.

        The data of all events in the range is combined in sorted order, so
        later events override earlier ones. A two-element key returns just
        the attribute named by ``key[1]``; otherwise the combined dict is
        returned.
        """
        start, end, _unused = _get_start_end(key[0])
        values: dict = {}
        for interval in sorted(self._events[start:end]):
            values.update(interval.data.items())
        if len(key) == 2:
            return values.get(key[1])
        return values

    @__getitem__.register(slice)  # type: ignore
    @__getitem__.register(int)  # type: ignore
    @__getitem__.register(float)  # type: ignore
    def _(self, key: Union[int, float, slice]) -> Any:
        """Yield ``self`` once per event within the time range ``key``.

        Works like iterating the object, restricted to the given range.
        """
        start, end, _unused = _get_start_end(key)
        self._clean_events()

        for begin, stop, data in sorted(self._events[start:end]):
            self._attributes = {'start': begin, 'end': stop, **data}
            yield self
        for bound in ('start', 'end'):
            self._attributes.pop(bound, None)

    @singledispatchmethod
    def __setitem__(self, key: Any, value: Any) -> None:
        """Set attribute ``key`` over the whole span of the stored events."""
        self.event(start=self._events.begin(),
                   end=self._events.end(),
                   **{key: value})

    @__setitem__.register(tuple)  # type: ignore
    def _(self, key: tuple, value: Any) -> None:
        """Set attribute ``key[1]`` over the time range given by ``key[0]``."""
        start, end, _unused = _get_start_end(key[0])
        self.event(start=start, end=end, **{key[1]: value})

    @property
    def start(self):
        """Start time: the ``start`` attribute if set, else the stored start."""
        return self.attributes.get('start', self._start)

    @property
    def end(self):
        """End time: the ``end`` attribute if set, else the stored end."""
        return self.attributes.get('end', self._end)

    def _clean_events(self):
        """Split overlapping events and merge those covering the same range.

        Nothing happens unless the number of stored intervals has changed
        since the last clean-up.
        """
        # BUG: There is a bug in the intervaltree library
        # merge_equals switches old and new data randomly
        def reducer(old, new):
            return {**old, **new}

        if len(self._events) == self._len_events:
            return

        # cut overlapping intervals at each other's boundaries ...
        self._events.split_overlaps()
        # ... and fold the dicts of identical intervals into one
        self._events.merge_equals(data_reducer=reducer)
        # remember the new size for the next check
        self._len_events = len(self._events)

    def event(self, *args, **kwargs) -> None:
        """Add a temporal event, or remove a time range.

        The keyword ``active`` (default ``True``) selects the mode: an active
        event stores the remaining keyword arguments as attributes for the
        given range, an inactive one chops the range out of the events.
        """
        active = kwargs.pop('active', True)

        # split the arguments into the time range and the attributes
        start, end, kwargs = _get_start_end(*args, **kwargs)

        if not active:
            self._events.chop(start, end)
        else:
            self._events[start:end] = kwargs  # type: ignore
            self._attributes = kwargs.copy()

        # keep the cached bounds in step with the tree
        self._start = self._events.begin()
        self._end = self._events.end()

    def last(self):
        """Return ``(begin, end, data)`` of the interval that sorts last."""
        newest = sorted(self._events)[-1]
        return newest.begin, newest.end, newest.data
Example #11
0
def test_merge_equals_empty():
    """merge_equals() on an empty tree keeps it valid and empty."""
    tree = IntervalTree()
    tree.merge_equals()
    tree.verify()
    assert len(tree) == 0
Example #12
0
    def get_arm_level_cn_events(self):
        """Detect arm-level copy-number events and register them on the samples.

        For every chromosome (1-22, X, Y) the copy-number segments of all samples
        that have a ``CnProfile`` are pooled into one interval tree, cut at every
        boundary, and reduced to the intervals whose merged data list has exactly
        ``n_samples`` entries.  For each allele (``a1`` / ``a2``) a segment whose
        copy numbers are not all 1 but are all ``>= 1`` or all ``<= 1`` becomes a
        candidate event, split into ``'p'`` / ``'q'`` parts at the centromere.
        Candidates on the same arm with identical copy numbers and mutually
        compatible CCF intervals are linked, maximal cliques of linked candidates
        are enumerated, and every clique covering more than half of its arm is
        passed to ``self._add_cn_event_to_samples``.

        Returns nothing; results are reported only through
        ``self._add_cn_event_to_samples``.
        """
        n_samples = len(self.sample_list)
        # ckey is the 0-based chromosome index; chromosome names are paired with CSIZE.
        for ckey, (chrom, csize) in enumerate(zip(list(map(str, range(1, 23))) + ['X', 'Y'], CSIZE)):
            # CENT_LOOKUP is indexed with the 1-based chromosome number.
            centromere = CENT_LOOKUP[ckey + 1]
            # Pool the segments of this chromosome from every sample that has a profile.
            tree = IntervalTree()
            for sample in self.sample_list:
                if sample.CnProfile:
                    tree.update(sample.CnProfile[chrom])
            # Cut at every boundary, then collect the data of equal ranges into a list.
            tree.split_overlaps()
            tree.merge_equals(data_initializer=[], data_reducer=lambda a, c: a + [c])
            # Keep only the intervals whose data list has one entry per sample.
            c_tree = IntervalTree(filter(lambda s: len(s.data) == n_samples, tree))
            # Candidate events as tuples:
            # (start, end, arm, cns, ccf_hat, ccf_high, ccf_low, allele)
            event_segs = set()
            for seg in c_tree:
                start = seg.begin
                end = seg.end
                cns_a1 = []
                cns_a2 = []
                ccf_hat_a1 = []
                ccf_hat_a2 = []
                ccf_high_a1 = []
                ccf_high_a2 = []
                ccf_low_a1 = []
                ccf_low_a2 = []
                for i, sample in enumerate(self.sample_list):
                    # NOTE(review): the merged list is indexed by sample position and
                    # element [1] of each entry is used as the per-sample dict; this
                    # assumes the merged order matches self.sample_list - TODO confirm.
                    seg_data = seg.data[i][1]
                    cns_a1.append(seg_data['cn_a1'])
                    cns_a2.append(seg_data['cn_a2'])
                    # CCF values are forced to 0. wherever the copy number equals 1.
                    ccf_hat_a1.append(seg_data['ccf_hat_a1'] if seg_data['cn_a1'] != 1 else 0.)
                    ccf_hat_a2.append(seg_data['ccf_hat_a2'] if seg_data['cn_a2'] != 1 else 0.)
                    ccf_high_a1.append(seg_data['ccf_high_a1'] if seg_data['cn_a1'] != 1 else 0.)
                    ccf_high_a2.append(seg_data['ccf_high_a2'] if seg_data['cn_a2'] != 1 else 0.)
                    ccf_low_a1.append(seg_data['ccf_low_a1'] if seg_data['cn_a1'] != 1 else 0.)
                    ccf_low_a2.append(seg_data['ccf_low_a2'] if seg_data['cn_a2'] != 1 else 0.)
                cns_a1 = np.array(cns_a1)
                cns_a2 = np.array(cns_a2)
                # Allele a1: copy number 1 in every sample means there is no event.
                if np.all(cns_a1 == 1):
                    pass
                # Every sample deviates from 1 in the same direction (or not at all).
                elif np.all(cns_a1 >= 1) or np.all(cns_a1 <= 1):
                    if start < centromere < end:
                        # The segment spans the centromere: split it into a p and a q part.
                        event_segs.add((start, centromere, 'p', tuple(cns_a1), tuple(ccf_hat_a1), tuple(ccf_high_a1), tuple(ccf_low_a1), 'a1'))
                        event_segs.add((centromere, end, 'q', tuple(cns_a1), tuple(ccf_hat_a1), tuple(ccf_high_a1), tuple(ccf_low_a1), 'a1'))
                    # NOTE(review): a segment ending exactly at the centromere is not
                    # matched here and is labelled 'q' by the else branch below.
                    elif end < centromere:
                        event_segs.add((start, end, 'p', tuple(cns_a1), tuple(ccf_hat_a1), tuple(ccf_high_a1), tuple(ccf_low_a1), 'a1'))
                    else:
                        event_segs.add((start, end, 'q', tuple(cns_a1), tuple(ccf_hat_a1), tuple(ccf_high_a1), tuple(ccf_low_a1), 'a1'))
                else:
                    # Some samples are above 1 and others below 1: log it and skip.
                    logging.warning('Seg with inconsistent event: {}:{}:{}'.format(chrom, seg.begin, seg.end))
                # Allele a2: same rules as for a1.
                if np.all(cns_a2 == 1):
                    pass
                elif np.all(cns_a2 >= 1) or np.all(cns_a2 <= 1):
                    if start < centromere < end:
                        event_segs.add((start, centromere, 'p', tuple(cns_a2), tuple(ccf_hat_a2), tuple(ccf_high_a2), tuple(ccf_low_a2), 'a2'))
                        event_segs.add((centromere, end, 'q', tuple(cns_a2), tuple(ccf_hat_a2), tuple(ccf_high_a2), tuple(ccf_low_a2), 'a2'))
                    elif end < centromere:
                        event_segs.add((start, end, 'p', tuple(cns_a2), tuple(ccf_hat_a2), tuple(ccf_high_a2), tuple(ccf_low_a2), 'a2'))
                    else:
                        event_segs.add((start, end, 'q', tuple(cns_a2), tuple(ccf_hat_a2), tuple(ccf_high_a2), tuple(ccf_low_a2), 'a2'))
                else:
                    logging.warning('Seg with inconsistent event: {}:{}:{}'.format(chrom, seg.begin, seg.end))
            # Link two candidates when they lie on the same arm, carry identical copy
            # numbers, and each one's ccf_hat lies within the other's
            # [ccf_low, ccf_high] range in every sample.
            neighbors = {s: set() for s in event_segs}
            for seg1, seg2 in itertools.combinations(event_segs, 2):
                s1_hat = np.array(seg1[4])
                s2_hat = np.array(seg2[4])
                if seg1[2] == seg2[2] and seg1[3] == seg2[3] and all(s1_hat >= np.array(seg2[6])) and all(s1_hat <= np.array(seg2[5])) \
                        and all(s2_hat >= np.array(seg1[6])) and all(s2_hat <= np.array(seg1[5])):
                    neighbors[seg1].add(seg2)
                    neighbors[seg2].add(seg1)

            # Recursive enumeration of maximal cliques (Bron-Kerbosch recursion without
            # pivoting): P = candidates, R = current clique, X = already handled vertices.
            def _BK(P, neighbors, R=frozenset(), X=frozenset()):
                if not P and not X:
                    yield R
                else:
                    # P is rebound below, not mutated, so this loop keeps iterating
                    # over the set it started with.
                    for v in P:
                        for r in _BK(P & neighbors[v], neighbors, R=R | {v}, X=X & neighbors[v]):
                            yield r
                        P = P - {v}
                        X = X | {v}

            for clique in _BK(event_segs, neighbors):
                # With no candidates the only "clique" is the empty set; skip it.
                if clique:
                    clique_len = 0
                    clique_ccf_hat = np.zeros(n_samples)
                    clique_ccf_high = np.zeros(n_samples)
                    clique_ccf_low = np.zeros(n_samples)
                    n_segs = 0
                    for seg in clique:
                        # Accumulate the covered length and the per-sample CCF sums.
                        clique_len += seg[1] - seg[0]
                        clique_ccf_hat += np.array(seg[4])
                        clique_ccf_high += np.array(seg[5])
                        clique_ccf_low += np.array(seg[6])
                        # Arm, category and copy numbers are taken from the last segment
                        # iterated; linked segments share arm and copy numbers.
                        clique_arm = seg[2]
                        # NOTE(review): 'Arm_gain' requires cn > 1 in every sample; any
                        # other pattern, including a mix of 1 and > 1, is labelled
                        # 'Arm_loss' - confirm this is intended.
                        cn_category = 'Arm_gain' if all(np.array(seg[3]) > 1) else 'Arm_loss'
                        local_cn = seg[3]
                        n_segs += 1
                    # Turn the CCF sums into means over the segments of the clique.
                    clique_ccf_hat /= n_segs
                    clique_ccf_high /= n_segs
                    clique_ccf_low /= n_segs
                    # Arm length: the p arm ends at the centromere, the q arm is the rest.
                    arm_len = centromere if clique_arm == 'p' else csize - centromere
                    # Only report cliques covering more than half of the arm; start and
                    # end are passed as 0, 0.
                    if clique_len > arm_len * .5:
                        self._add_cn_event_to_samples(chrom, 0, 0, clique_arm, local_cn, cn_category, clique_ccf_hat, clique_ccf_high,
                                                      clique_ccf_low)