def match_record_to_tree(self, r):
        """
        Find the record already stored in self.tree that `r` matches.

        r --- GMAPRecord
        tree --- dict of chromosome --> strand --> IntervalTree (read from self.tree)

        If exact match (every exon junction) or, when self.allow_5merge is True,
        5' truncated, return the matching GMAPRecord.
        Otherwise return None.
        *NOTE*: the tree should be non-redundant so can return as soon as a match is found!

        Side effect: `.segments` is set to `.ref_exons` on `r` and on every
        candidate that gets examined, right before compare_junctions is called.
        """
        matches = self.tree[r.chr][r.strand].find(r.start, r.end)
        for r2 in matches:
            r.segments = r.ref_exons
            r2.segments = r2.ref_exons
            if compare_junctions.compare_junctions(
                    r, r2, internal_fuzzy_max_dist=self.internal_fuzzy_max_dist) == 'exact':
                return r2  # is a match!
            if not self.allow_5merge:
                continue

            # a is the one with more exons, b is the other one (a tie makes a = r2)
            if len(r.segments) > len(r2.segments):
                a, b = r, r2
            else:
                a, b = r2, r
            if compare_junctions.compare_junctions(
                    b, a, internal_fuzzy_max_dist=self.internal_fuzzy_max_dist) != 'subset':
                continue

            # we only know that b is a subset of a, verify that it is actually
            # 5' truncated (strand-sensitive!), i.e. both share the 3'-most exon:
            # if + strand, last exon of a should overlap last exon of b
            # if - strand, first exon of a should overlap first exon of b
            # Both sides are taken from .segments (reference coordinates).
            if r.strand == '+':
                shares_3p_exon = compare_junctions.overlaps(a.segments[-1], b.segments[-1])
            elif r.strand == '-':
                shares_3p_exon = compare_junctions.overlaps(a.segments[0], b.segments[0])
            else:
                shares_3p_exon = False
            if shares_3p_exon:
                return r2

        return None
    def match_record_to_tree(self, r):
        """
        Find the record already stored in self.tree that `r` matches.

        r --- GMAPRecord
        tree --- dict of chromosome --> strand --> IntervalTree (read from self.tree)

        If exact match (every exon junction) or, when self.allow_5merge is True,
        5' truncated, return the matching GMAPRecord.
        Otherwise return None.
        *NOTE*: the tree should be non-redundant so can return as soon as a match is found!

        Side effect: `.segments` is set to `.ref_exons` on `r` and on every
        candidate that gets examined, right before compare_junctions is called.
        """
        fuzz_kw = {'internal_fuzzy_max_dist': self.internal_fuzzy_max_dist}
        for r2 in self.tree[r.chr][r.strand].find(r.start, r.end):
            r.segments = r.ref_exons
            r2.segments = r2.ref_exons
            if compare_junctions.compare_junctions(r, r2, **fuzz_kw) == 'exact':
                return r2  # is a match!
            if not self.allow_5merge:
                continue

            # a is the one with more exons, b is the other one (a tie makes a = r2)
            a, b = (r, r2) if len(r.segments) > len(r2.segments) else (r2, r)
            if compare_junctions.compare_junctions(b, a, **fuzz_kw) != 'subset':
                continue

            # we only know that b is a subset of a, verify that it is actually
            # 5' truncated (strand-sensitive!), i.e. both share the 3'-most exon:
            # if + strand, last exon of a should overlap last exon of b
            # if - strand, first exon of a should overlap first exon of b
            # Both sides are taken from .segments (reference coordinates).
            if r.strand == '+':
                exon_idx = -1
            elif r.strand == '-':
                exon_idx = 0
            else:
                continue
            if compare_junctions.overlaps(a.segments[exon_idx], b.segments[exon_idx]):
                return r2

        return None
    def match_record_to_tree(self, r):
        """
        Generator over the records stored in self.tree that `r` matches.

        r --- GMAPRecord
        tree --- dict of chromosome --> strand --> IntervalTree (read from self.tree)

        If exact match (every exon junction) or 5' truncated (allow_5merge is True), YIELD the matching GMAPRecord(s)
        *NOTE/UPDATE*: could have multiple matches!

        A candidate is only yielded when the 3' ends agree: self.max_3_diff is
        None, or the 3' coordinate (`end` on + strand, `start` on - strand)
        differs by at most self.max_3_diff.

        Side effect: `.segments` is set to `.ref_exons` on `r` and on every
        candidate examined.
        """
        matches = self.tree[r.chr][r.strand].find(r.start, r.end)
        for r2 in matches:
            r.segments = r.ref_exons
            r2.segments = r2.ref_exons
            n1 = len(r.segments)
            n2 = len(r2.segments)
            fuzz = self.internal_fuzzy_max_dist

            three_end_is_match = self.max_3_diff is None or \
                (r.strand == '+' and abs(r.end - r2.end) <= self.max_3_diff) or \
                (r.strand == '-' and abs(r.start - r2.start) <= self.max_3_diff)

            # "last junction" = the 3'-most junction of the transcript.
            if n1 == 1 or n2 == 1:
                # no junction to compare: only two single-exon records count as agreeing
                last_junction_match = (n1 == 1 and n2 == 1)
            elif r.strand == '+':
                # + strand: 3'-most junction is between the last two exons.
                # (Comparing the *first* exon's end here would reject every
                # record that is truncated by one or more 5' exons.)
                last_junction_match = \
                    abs(r.segments[-1].start - r2.segments[-1].start) <= fuzz and \
                    abs(r.segments[-2].end - r2.segments[-2].end) <= fuzz
            else:
                # - strand: 3'-most junction is between the first two exons
                last_junction_match = \
                    abs(r.segments[0].end - r2.segments[0].end) <= fuzz and \
                    abs(r.segments[1].start - r2.segments[1].start) <= fuzz

            if compare_junctions.compare_junctions(
                    r, r2, internal_fuzzy_max_dist=self.internal_fuzzy_max_dist
            ) == 'exact':  # is a match!
                if three_end_is_match:
                    yield r2
            elif self.allow_5merge:  # check if the shorter one is a subset of the longer one
                if n1 > n2:
                    a, b = r, r2
                else:
                    a, b = r2, r
                # a is the longer one, b is the shorter one
                if compare_junctions.compare_junctions(
                        b, a, internal_fuzzy_max_dist=self.internal_fuzzy_max_dist
                ) == 'subset':
                    # we only know that b is a subset of a, verify that it is
                    # actually 5' truncated (strand-sensitive!): the 3'-most
                    # junction must agree and the 3' end must not be too different
                    if three_end_is_match and last_junction_match:
                        yield r2
    def match_record_to_tree(self, r: GFF.gmapRecord, check_5_dist: bool,
                             check_3_dist: bool) -> List[str]:
        """
        Match a single record (locus) against the tree and collect the hits.

        Major diff from non-fusion version:
        1. there could be multiple matches, so a list of seqids is returned
        2. no 5merge allowed, only 'exact' junction comparisons count
        3. when check_5_dist / check_3_dist is set, the candidate must also
           pass self.junction_match_check_5 / self.junction_match_check_3.
           This is used for fusion junctions.
        4. need to take care that fusions can be multi-chromosome! write output correctly!!!

        Side effect: `.segments` is set to `.ref_exons` on `r` and on every
        candidate examined.
        """
        candidates = self.tree[r.chr][r.strand].find(r.start, r.end)
        hits = []
        for candidate in candidates:
            r.segments = r.ref_exons
            candidate.segments = candidate.ref_exons
            verdict = compare_junctions.compare_junctions(
                r, candidate,
                internal_fuzzy_max_dist=self.internal_fuzzy_max_dist)
            if not verdict == "exact":
                continue
            # the end checks are only evaluated when requested, 5' before 3'
            if check_5_dist and not self.junction_match_check_5(r, candidate):
                continue
            if check_3_dist and not self.junction_match_check_3(r, candidate):
                continue
            hits.append(candidate.seqid)  # is a match!

        return hits
示例#5
0
def is_fusion_compatible(r1, r2, max_fusion_point_dist, max_exon_end_dist, allow_extra_5_exons):
    """
    Helper function for: merge_fusion_exons()

    Check that:
    (1) r1, r2 are both in the 5', or both in the 3' (decided by whether
        qStart lies in the first half of qLen)
    (2) if single-exon, fusion point must be close by
        if multi-exon, every junction identical (plus below is True)
    (3) if allow_extra_5_exons is False, num exons must be the same
        if allow_extra_5_exons is True, only allow additional 5' exons

    Side effect: sets r1.strand / r2.strand from r1.flag.strand / r2.flag.strand.
    Raises AssertionError if the two records are on different strands.
    """
    assert r1.flag.strand == r2.flag.strand

    # first need to figure out ends: both records must sit in the same half
    in_5_portion = r1.qStart <= .5*r1.qLen
    if in_5_portion:
        other_half = r2.qStart > .5*r2.qLen   # r2 in the 3' portion, reject
    else:
        other_half = r2.qStart <= .5*r2.qLen  # r2 in the 5' portion, reject
    if other_half:
        return False
    plus_is_5end = (r1.flag.strand == '+')

    r1.strand = r1.flag.strand
    r2.strand = r2.flag.strand
    junction_class = compare_junctions(r1, r2)

    if junction_class == 'exact':
        if len(r1.segments) != 1:
            # multi-exon case, must be OK
            return True
        if len(r2.segments) != 1:
            raise Exception("Not possible case for multi-exon transcript and " + \
                    "single-exon transcript to be exact!")
        # single exon case, check fusion point is close enough
        if in_5_portion and plus_is_5end:
            dist = abs(r1.sStart - r2.sStart)
        else:
            dist = abs(r1.sEnd - r2.sEnd)
        return dist <= max_fusion_point_dist

    if not (junction_class == 'super' or junction_class == 'subset'):
        # ex: partial, nomatch, etc...
        return False
    if not allow_extra_5_exons:
        # not OK because number of exons must be the same
        return False
    if not in_5_portion:
        return False

    # check that the 3' junction is identical
    # also check that the 3' end is relatively close
    if plus_is_5end:
        return not (abs(r1.segments[-1].start - r2.segments[-1].start) > max_exon_end_dist or
                    abs(r1.segments[-1].end - r2.segments[-1].end) > max_fusion_point_dist)
    return not (abs(r1.segments[0].end - r2.segments[0].end) > max_exon_end_dist or
                abs(r1.segments[0].start - r2.segments[0].start) > max_fusion_point_dist)
    def check_records_match(self, records1, records2):
        """
        Decide whether two fusion records describe the same fusion.

        records1, records2 are two fusion records (sequences of per-locus records).
        They match iff:
        1. same number of records
        2. each record (a locus) matches its counterpart: same chr and strand,
           'exact' junction comparison, and the required end checks pass

        Side effect: `.segments` is set to `.ref_exons` on every pair compared.
        """
        if len(records1) != len(records2):
            return False

        last_idx = len(records1) - 1
        for idx, (r1, r2) in enumerate(zip(records1, records2)):
            # check: chr, strand, exons match
            if r1.chr != r2.chr or r1.strand != r2.strand:
                return False
            r1.segments = r1.ref_exons
            r2.segments = r2.ref_exons
            verdict = compare_junctions.compare_junctions(
                r1, r2, internal_fuzzy_max_dist=self.internal_fuzzy_max_dist)
            if verdict != 'exact':
                return False

            # first record: only need 3' to agree
            # last record (when it is not also the first): only need 5' to agree
            # anything in between: need both, 5' checked before 3'
            need_5 = idx != 0
            need_3 = idx == 0 or idx != last_idx
            if need_5 and not self.junction_match_check_5(r1, r2):
                return False
            if need_3 and not self.junction_match_check_3(r1, r2):
                return False

        return True
    def check_records_match(self, records1, records2):
        """
        Return True when two fusion records agree locus by locus.

        records1, records2 are two fusion records.
        They match iff:
        1. same number of records
        2. each record (a locus) matches: same chr/strand, junctions compare
           as 'exact', and the end checks required for its position pass

        Side effect: `.segments` is set to `.ref_exons` on every pair compared.
        """
        n_records = len(records1)
        if n_records != len(records2):
            return False

        for position, pair in enumerate(zip(records1, records2)):
            r1, r2 = pair
            # check: chr, strand, exons match
            if r1.chr != r2.chr:
                return False
            if r1.strand != r2.strand:
                return False
            r1.segments = r1.ref_exons
            r2.segments = r2.ref_exons
            if compare_junctions.compare_junctions(
                    r1, r2, internal_fuzzy_max_dist=self.internal_fuzzy_max_dist) != 'exact':
                return False

            if position == 0:
                # first record, only need 3' to agree
                end_checks = (self.junction_match_check_3,)
            elif position == n_records - 1:
                # last record, only need 5' to agree
                end_checks = (self.junction_match_check_5,)
            else:
                end_checks = (self.junction_match_check_5, self.junction_match_check_3)
            for end_check in end_checks:
                if not end_check(r1, r2):
                    return False

        return True
示例#8
0
def is_fusion_compatible(r1, r2, max_fusion_point_dist, max_exon_end_dist, allow_extra_5_exons):
    """
    Helper function for: merge_fusion_exons()

    Check that:
    (1) r1, r2 are both in the 5', or both in the 3' (decided by whether
        qStart lies in the first half of qLen)
    (2) if single-exon, fusion point must be close by
        if multi-exon, every junction identical (plus below is True)
    (3) if allow_extra_5_exons is False, num exons must be the same
        if allow_extra_5_exons is True, only allow additional 5' exons

    Side effect: sets r1.strand / r2.strand from r1.flag.strand / r2.flag.strand.
    Raises AssertionError if the two records are on different strands, and
    Exception if an 'exact' comparison pairs a single-exon with a multi-exon record.
    """
    # first need to figure out ends
    # also check that both are in the same (5' or 3') portion of r1 and r2
    assert r1.flag.strand == r2.flag.strand
    if r1.qStart <= .5*r1.qLen:  # in the 5' portion of r1
        if r2.qStart > .5*r2.qLen:  # in the 3' portion, reject
            return False
        in_5_portion = True
    else:  # in the 3' portion of r1
        if r2.qStart <= .5*r2.qLen:
            return False
        in_5_portion = False
    plus_is_5end = (r1.flag.strand == '+')

    r1.strand = r1.flag.strand
    r2.strand = r2.flag.strand
    match_type = compare_junctions(r1, r2)  # named so the builtin `type` is not shadowed

    if match_type == 'exact':
        if len(r1.segments) != 1:
            # multi-exon case, must be OK
            return True
        if len(r2.segments) != 1:
            # call form of raise: valid on both Python 2 and Python 3
            raise Exception("Not possible case for multi-exon transcript and " +
                            "single-exon transcript to be exact!")
        # single exon case, check fusion point is close enough
        if in_5_portion and plus_is_5end:
            dist = abs(r1.sStart - r2.sStart)
        else:
            dist = abs(r1.sEnd - r2.sEnd)
        return dist <= max_fusion_point_dist

    if match_type == 'super' or match_type == 'subset':
        if not allow_extra_5_exons:
            # not OK because number of exons must be the same
            return False
        # check that the 3' junction is identical
        # also check that the 3' end is relatively close
        if in_5_portion and plus_is_5end:
            if abs(r1.segments[-1].start - r2.segments[-1].start) > max_exon_end_dist:
                return False
            if abs(r1.segments[-1].end - r2.segments[-1].end) > max_fusion_point_dist:
                return False
            return True
        if in_5_portion and (not plus_is_5end):
            if abs(r1.segments[0].end - r2.segments[0].end) > max_exon_end_dist:
                return False
            if abs(r1.segments[0].start - r2.segments[0].start) > max_fusion_point_dist:
                return False
            return True
        return False

    # ex: partial, nomatch, etc...
    return False
    def match_record_to_tree(self, r, check_5_dist, check_3_dist):
        """
        Match a single record (locus) against the tree; return the seqids of all hits.

        Major diff from non-fusion version:
        1. there could be multiple matches!
        2. no 5merge allowed, only 'exact' junction comparisons count
        3. when check_5_dist / check_3_dist is set, the candidate must also
           pass self.junction_match_check_5 / self.junction_match_check_3.
           This is used for fusion junctions.

        Side effect: `.segments` is set to `.ref_exons` on `r` and on every
        candidate examined.
        """
        candidates = self.tree[r.chr][r.strand].find(r.start, r.end)
        hits = []
        for candidate in candidates:
            r.segments = r.ref_exons
            candidate.segments = candidate.ref_exons
            verdict = compare_junctions.compare_junctions(
                r, candidate, internal_fuzzy_max_dist=self.internal_fuzzy_max_dist)
            if not verdict == 'exact':
                continue
            # the end checks are only evaluated when requested, 5' before 3'
            if check_5_dist and not self.junction_match_check_5(r, candidate):
                continue
            if check_3_dist and not self.junction_match_check_3(r, candidate):
                continue
            hits.append(candidate.seqid)  # is a match!

        return hits
    def match_record_to_tree(self, r, check_5_dist, check_3_dist):
        """
        Collect the seqids of every tree entry that matches record `r` (one locus).

        Major diff from non-fusion version:
        1. there could be multiple matches!
        2. no 5merge allowed
        3. additionally checks, when asked to, that the 5'/3' ends don't
           disagree too much (via self.junction_match_check_5 /
           self.junction_match_check_3). This is used for fusion junctions.

        Side effect: `.segments` is set to `.ref_exons` on `r` and on every
        candidate examined.
        """
        def ends_agree(other):
            # 5' is consulted before 3'; each check only runs when requested
            if check_5_dist and not self.junction_match_check_5(r, other):
                return False
            if check_3_dist and not self.junction_match_check_3(r, other):
                return False
            return True

        overlapping = self.tree[r.chr][r.strand].find(r.start, r.end)
        matched_ids = []
        for r2 in overlapping:
            r.segments = r.ref_exons
            r2.segments = r2.ref_exons
            junctions = compare_junctions.compare_junctions(
                r, r2, internal_fuzzy_max_dist=self.internal_fuzzy_max_dist)
            if junctions == 'exact' and ends_agree(r2):  # is a match!
                matched_ids.append(r2.seqid)

        return matched_ids
示例#11
0
def filter_out_subsets(recs, internal_fuzzy_max_dist):
    """
    Remove, in place, records that can be merged into an overlapping record.

    recs --- mutable sequence of records; must be sorted by start becuz
             that's the order they are written
    internal_fuzzy_max_dist --- forwarded to compare_junctions and can_merge

    For every pair (recs[i], recs[j]) with j > i and recs[j].start <= recs[i].end,
    if can_merge() accepts the pair then recs[j] is dropped when the comparison
    is 'super', otherwise recs[i] is dropped. Returns None.

    Side effect: `.segments` is set to `.ref_exons` on every record compared.
    """
    i = 0
    while i < len(recs) - 1:
        anchor_replaced = False
        j = i + 1
        while j < len(recs):
            if recs[j].start > recs[i].end:
                break
            recs[i].segments = recs[i].ref_exons
            recs[j].segments = recs[j].ref_exons
            m = compare_junctions.compare_junctions(recs[i], recs[j],
                                                    internal_fuzzy_max_dist)
            if can_merge(m, recs[i], recs[j], internal_fuzzy_max_dist):
                if m == 'super':  # pop recs[j]
                    recs.pop(j)
                else:
                    # recs[i] goes away and a new record slides into slot i.
                    # Leave j alone (it already points at the next unseen
                    # record) and remember not to advance i, so the new
                    # occupant is compared against everything after it.
                    recs.pop(i)
                    anchor_replaced = True
            else:
                j += 1
        if not anchor_replaced:
            i += 1
示例#12
0
def filter_out_subsets(recs, internal_fuzzy_max_dist):
    """
    Remove, in place, records that can be merged into an overlapping record.

    recs --- mutable sequence of records; must be sorted by start becuz
             that's the order they are written
    internal_fuzzy_max_dist --- forwarded to compare_junctions and can_merge

    Each anchor record is compared with the following records until one starts
    past the anchor's end. When can_merge() accepts a pair, the later record is
    dropped if the comparison is 'super', otherwise the anchor itself is dropped
    and its slot is re-examined. Returns None.

    Side effect: `.segments` is set to `.ref_exons` on every record compared.
    """
    anchor = 0
    while anchor < len(recs) - 1:
        anchor_dropped = False
        probe = anchor + 1
        # stop scanning as soon as a record starts beyond the anchor's end
        while probe < len(recs) and not recs[probe].start > recs[anchor].end:
            left, right = recs[anchor], recs[probe]
            left.segments = left.ref_exons
            right.segments = right.ref_exons
            relation = compare_junctions.compare_junctions(left, right, internal_fuzzy_max_dist)
            if not can_merge(relation, left, right, internal_fuzzy_max_dist):
                probe += 1
            elif relation == 'super':
                recs.pop(probe)   # later record is redundant
            else:
                recs.pop(anchor)  # anchor is redundant; a new one slides into its slot
                anchor_dropped = True
        if not anchor_dropped:
            anchor += 1
def filter_out_subsets(recs: Dict[int, GFF.gmapRecord],
                       internal_fuzzy_max_dist: int) -> None:
    """
    Remove, in place, records that can be merged into an overlapping record.

    `recs` must be sorted by start becuz that's the order they are written.
    NOTE(review): the annotation says Dict, but the body indexes `recs` by
    consecutive positions and relies on pop() shifting later entries down,
    which is list behaviour; confirm what callers actually pass.

    Each anchor record is compared with the following records until one starts
    past the anchor's end. When can_merge() accepts a pair, the later record is
    dropped if the comparison is "super", otherwise the anchor itself is dropped
    and its slot is re-examined.

    Side effect: `.segments` is set to `.ref_exons` on every record compared.
    """
    anchor = 0
    while anchor < len(recs) - 1:
        keep_anchor = True
        nxt = anchor + 1
        while nxt < len(recs):
            first, second = recs[anchor], recs[nxt]
            if second.start > first.end:
                break  # no further record can overlap this anchor
            first.segments = first.ref_exons
            second.segments = second.ref_exons
            relation = compare_junctions.compare_junctions(
                first, second, internal_fuzzy_max_dist)
            mergeable = can_merge(relation, first, second,
                                  internal_fuzzy_max_dist)
            if not mergeable:
                nxt += 1
                continue
            if relation == "super":
                recs.pop(nxt)  # later record is redundant
            else:
                recs.pop(anchor)  # anchor is redundant; re-examine its slot
                keep_anchor = False
        if keep_anchor:
            anchor += 1
def collapse_fuzzy_junctions(gff_filename, group_filename, allow_extra_5exon, internal_fuzzy_max_dist):
    """
    Group records of a collapsed GFF whose junctions fuzzy-match, and write one
    representative per group to `<gff_filename>.fuzzy` plus the pooled members
    to `<group_filename>.fuzzy`.

    gff_filename --- collapsed GFF, read with GFF.collapseGFFReader
    group_filename --- tab-delimited file, one `pbid<TAB>comma-separated members` per line
    allow_extra_5exon --- if True, 'super'/'subset' comparisons may be merged too (see can_merge)
    internal_fuzzy_max_dist --- passed to compare_junctions; also the tolerance
                                on the last acceptor site in can_merge

    Returns the dict: seqid of the first record of each group --> list of seqids in that group.

    NOTE: max_5_diff / max_3_diff are read from a module-level `args` object,
    which must exist when this function is called.
    """
    def get_fl_from_id(members):
        # Sum the integer between 'f' and 'p' in the second '/'-field of each id.
        # ex: 13cycle_1Mag1Diff|i0HQ_SIRV_1d1m|c139597/f1p0/178
        # Returns 0 as soon as one id does not parse as an int.
        try:
            return sum(int(_id.split('/')[1].split('p')[0][1:]) for _id in members)
        except ValueError:
            return 0

    def can_merge(m, r1, r2):
        if m == 'exact':
            return True
        if not allow_extra_5exon:
            return False
        # below is continued only if (a) is 'subset' or 'super' AND (b) allow_extra_5exon is True
        if m == 'subset':
            r1, r2 = r2, r1  # rotate so r1 is always the longer one
        if m == 'super' or m == 'subset':
            n2 = len(r2.ref_exons)
            # check that (a) r1 and r2 end on same 3' exon, that is the last acceptor site agrees
            # AND (b) the 5' start of r2 is sandwiched between the matching r1 exon coordinates
            if r1.strand == '+':
                return abs(r1.ref_exons[-1].start - r2.ref_exons[-1].start) <= internal_fuzzy_max_dist and \
                    r1.ref_exons[-n2].start <= r2.ref_exons[0].start < r1.ref_exons[-n2].end
            else:
                # NOTE(review): upper bound uses exon n2, not n2-1 as the '+' branch
                # would suggest; left as is, confirm intent.
                return abs(r1.ref_exons[0].end - r2.ref_exons[0].end) <= internal_fuzzy_max_dist and \
                    r1.ref_exons[n2-1].start <= r2.ref_exons[-1].end < r1.ref_exons[n2].end
        return False

    d = {}
    recs = defaultdict(lambda: {'+': IntervalTree(), '-': IntervalTree()})  # chr --> strand --> tree
    fuzzy_match = defaultdict(lambda: [])
    for r in GFF.collapseGFFReader(gff_filename):
        d[r.seqid] = r
        has_match = False
        r.segments = r.ref_exons
        for r2 in recs[r.chr][r.strand].find(r.start, r.end):
            r2.segments = r2.ref_exons
            m = compare_junctions.compare_junctions(r, r2, internal_fuzzy_max_dist=internal_fuzzy_max_dist, max_5_diff=args.max_5_diff, max_3_diff=args.max_3_diff)
            if can_merge(m, r, r2):
                fuzzy_match[r2.seqid].append(r.seqid)
                has_match = True
                break
        if not has_match:
            # nothing to merge into: this record starts a new group
            recs[r.chr][r.strand].insert(r.start, r.end, r)
            fuzzy_match[r.seqid] = [r.seqid]

    group_info = {}
    with open(group_filename) as f:
        for line in f:
            pbid, members = line.strip().split('\t')
            group_info[pbid] = members.split(',')

    # Order groups by the numeric fields after the first '.' of the seqid.
    # sorted() plus a list key works on both Python 2 and 3, whereas
    # dict.keys().sort() and a map() key only work on Python 2.
    keys = sorted(fuzzy_match, key=lambda x: [int(tok) for tok in x.split('.')[1:]])

    # pick for each fuzzy group the one that has the most exons (if tie, then most FL)
    # Context managers make sure both outputs are closed even if writing fails.
    with open(gff_filename+'.fuzzy', 'w') as f_gff, open(group_filename+'.fuzzy', 'w') as f_group:
        for k in keys:
            all_members = []
            # NOTE(review): the starting best_size is a member count while the
            # sizes compared below come from get_fl_from_id; confirm this is intended.
            best_pbid, best_size, best_num_exons = fuzzy_match[k][0], len(group_info[fuzzy_match[k][0]]), len(d[fuzzy_match[k][0]].ref_exons)
            all_members += group_info[fuzzy_match[k][0]]
            for pbid in fuzzy_match[k][1:]:
                # note: get_fl_from_id only works on IsoSeq1 and 2 ID formats, will return 0 if IsoSeq3 format or other
                _size = get_fl_from_id(group_info[pbid])
                _num_exons = len(d[pbid].ref_exons)
                all_members += group_info[pbid]
                if _num_exons > best_num_exons or (_num_exons == best_num_exons and _size > best_size):
                    best_pbid, best_size, best_num_exons = pbid, _size, _num_exons
            GFF.write_collapseGFF_format(f_gff, d[best_pbid])
            f_group.write("{0}\t{1}\n".format(best_pbid, ",".join(all_members)))

    return fuzzy_match
示例#15
0
def collapse_fuzzy_junctions(gff_filename, group_filename, allow_extra_5exon,
                             internal_fuzzy_max_dist):
    """
    Group records of a collapsed GFF whose junctions fuzzy-match, and write one
    representative per group to `<gff_filename>.fuzzy` plus the pooled members
    to `<group_filename>.fuzzy`.

    gff_filename --- collapsed GFF, read with GFF.collapseGFFReader
    group_filename --- tab-delimited file, one `pbid<TAB>comma-separated members` per line
    allow_extra_5exon --- if True, 'super'/'subset' comparisons may be merged too (see can_merge)
    internal_fuzzy_max_dist --- passed to compare_junctions; also the tolerance
                                on the last acceptor site in can_merge

    Returns the dict: seqid of the first record of each group --> list of seqids in that group.
    """
    def get_fl_from_id(members):
        # Sum the integer between 'f' and 'p' in the second '/'-field of each id.
        # ex: 13cycle_1Mag1Diff|i0HQ_SIRV_1d1m|c139597/f1p0/178
        return sum(int(_id.split('/')[1].split('p')[0][1:]) for _id in members)

    def can_merge(m, r1, r2):
        if m == 'exact':
            return True
        if not allow_extra_5exon:
            return False
        # below is continued only if (a) is 'subset' or 'super' AND (b) allow_extra_5exon is True
        if m == 'subset':
            r1, r2 = r2, r1  # rotate so r1 is always the longer one
        if m == 'super' or m == 'subset':
            n2 = len(r2.ref_exons)
            # check that (a) r1 and r2 end on same 3' exon, that is the last acceptor site agrees
            # AND (b) the 5' start of r2 is sandwiched between the matching r1 exon coordinates
            if r1.strand == '+':
                return abs(r1.ref_exons[-1].start - r2.ref_exons[-1].start) <= internal_fuzzy_max_dist and \
                    r1.ref_exons[-n2].start <= r2.ref_exons[0].start < r1.ref_exons[-n2].end
            else:
                # NOTE(review): upper bound uses exon n2, not n2-1 as the '+' branch
                # would suggest; left as is, confirm intent.
                return abs(r1.ref_exons[0].end - r2.ref_exons[0].end) <= internal_fuzzy_max_dist and \
                    r1.ref_exons[n2-1].start <= r2.ref_exons[-1].end < r1.ref_exons[n2].end
        return False

    d = {}
    recs = defaultdict(lambda: {
        '+': IntervalTree(),
        '-': IntervalTree()
    })  # chr --> strand --> tree
    fuzzy_match = defaultdict(lambda: [])
    for r in GFF.collapseGFFReader(gff_filename):
        d[r.seqid] = r
        has_match = False
        r.segments = r.ref_exons
        for r2 in recs[r.chr][r.strand].find(r.start, r.end):
            r2.segments = r2.ref_exons
            m = compare_junctions.compare_junctions(
                r, r2, internal_fuzzy_max_dist=internal_fuzzy_max_dist)
            if can_merge(m, r, r2):
                fuzzy_match[r2.seqid].append(r.seqid)
                has_match = True
                break
        if not has_match:
            # nothing to merge into: this record starts a new group
            recs[r.chr][r.strand].insert(r.start, r.end, r)
            fuzzy_match[r.seqid] = [r.seqid]

    group_info = {}
    with open(group_filename) as f:
        for line in f:
            pbid, members = line.strip().split('\t')
            group_info[pbid] = members.split(',')

    # Order groups by the numeric fields after the first '.' of the seqid.
    # sorted() plus a list key works on both Python 2 and 3, whereas
    # dict.keys().sort() and a map() key only work on Python 2.
    keys = sorted(fuzzy_match,
                  key=lambda x: [int(tok) for tok in x.split('.')[1:]])

    # pick for each fuzzy group the one that has the most exons (if tie, then most FL)
    # Context managers make sure both outputs are closed even if writing fails.
    with open(gff_filename + '.fuzzy', 'w') as f_gff, \
            open(group_filename + '.fuzzy', 'w') as f_group:
        for k in keys:
            all_members = []
            # NOTE(review): the starting best_size is a member count while the
            # sizes compared below come from get_fl_from_id; confirm this is intended.
            best_pbid, best_size, best_num_exons = fuzzy_match[k][0], len(
                group_info[fuzzy_match[k][0]]), len(d[fuzzy_match[k][0]].ref_exons)
            all_members += group_info[fuzzy_match[k][0]]
            for pbid in fuzzy_match[k][1:]:
                _size = get_fl_from_id(group_info[pbid])
                _num_exons = len(d[pbid].ref_exons)
                all_members += group_info[pbid]
                if _num_exons > best_num_exons or (_num_exons == best_num_exons
                                                   and _size > best_size):
                    best_pbid, best_size, best_num_exons = pbid, _size, _num_exons
            GFF.write_collapseGFF_format(f_gff, d[best_pbid])
            f_group.write("{0}\t{1}\n".format(best_pbid, ",".join(all_members)))

    return fuzzy_match
    def match_record_to_tree(self, r: GFF.gmapRecord) -> GFF.gmapRecord:
        """
        Generator over the records stored in self.tree that `r` matches.

        r --- GMAPRecord
        tree --- dict of chromosome --> strand --> IntervalTree (read from self.tree)

        If exact match (every exon junction) or 5' truncated (allow_5merge is True), YIELD the matching GMAPRecord(s)
        *NOTE/UPDATE*: could have multiple matches!
        NOTE: because of the `yield`s, calling this returns a generator of
        records, not a single record as the return annotation suggests.

        A candidate is only yielded when the 3' ends agree: self.max_3_diff is
        None, or the 3' coordinate (`end` on + strand, `start` on - strand)
        differs by at most self.max_3_diff.

        Side effect: `.segments` is set to `.ref_exons` on `r` and on every
        candidate examined.
        """
        matches = self.tree[r.chr][r.strand].find(r.start, r.end)
        for r2 in matches:
            r.segments = r.ref_exons  # the incoming entry
            r2.segments = r2.ref_exons  # an existing entry in the tree
            n1 = len(r.segments)  # how many exons?
            n2 = len(r2.segments)
            fuzz = self.internal_fuzzy_max_dist

            # whether the 3' ends "match" (within reason), or no limit was set
            three_end_is_match = (
                self.max_3_diff is None
                or (r.strand == "+" and abs(r.end - r2.end) <= self.max_3_diff)
                or (r.strand == "-"
                    and abs(r.start - r2.start) <= self.max_3_diff)
            )

            # "last junction" = the 3'-most junction of the transcript.
            if n1 == 1 or n2 == 1:
                # no junction to compare: only two single-exon records count as agreeing
                last_junction_match = n1 == 1 and n2 == 1
            elif r.strand == "+":
                # + strand: 3'-most junction is between the last two exons.
                # (Comparing the *first* exon's end here would reject every
                # record that is truncated by one or more 5' exons.)
                last_junction_match = (
                    abs(r.segments[-1].start - r2.segments[-1].start) <= fuzz
                    and abs(r.segments[-2].end - r2.segments[-2].end) <= fuzz
                )
            else:
                # - strand: 3'-most junction is between the first two exons
                last_junction_match = (
                    abs(r.segments[0].end - r2.segments[0].end) <= fuzz
                    and abs(r.segments[1].start - r2.segments[1].start) <= fuzz
                )

            # Order the pair so that `compare_junctions` is run once, not twice:
            # with 5' merging, b is the record with fewer exons (ties: b = r);
            # without it, the incoming record is always b.
            if self.allow_5merge and n1 > n2:
                a, b = r, r2
            else:
                a, b = r2, r

            junct_compare = compare_junctions.compare_junctions(
                b, a, internal_fuzzy_max_dist=self.internal_fuzzy_max_dist)

            if junct_compare == "exact":  # is a match!
                if three_end_is_match:
                    yield r2
            elif self.allow_5merge and junct_compare == "subset":
                # we only know that b is a subset of a, verify that it is
                # actually 5' truncated (strand-sensitive!): the 3'-most
                # junction must agree and the 3' end must not be too different
                if three_end_is_match and last_junction_match:
                    yield r2
def collapse_fuzzy_junctions(
    gff_filename: Union[str, Path],
    group_filename: Union[str, Path],
    allow_extra_5exon: bool,
    internal_fuzzy_max_dist: int,
    max_5_diff: int,
    max_3_diff: int,
) -> defaultdict:
    """
    Group records of a collapsed GFF whose junctions fuzzy-match, and write one
    representative per group to `<gff_filename>.fuzzy` plus the pooled members
    to `<group_filename>.fuzzy`.

    The GFF is read with GFF.collapseGFFReader; the group file is expected to
    hold one `pbid<TAB>comma-separated members` entry per line.
    `internal_fuzzy_max_dist`, `max_5_diff` and `max_3_diff` are forwarded to
    compare_junctions; `allow_extra_5exon` additionally lets "super"/"subset"
    comparisons be merged (see can_merge).

    Returns the mapping: seqid of the first record of each group --> list of
    seqids in that group.
    """
    def can_merge(m, r1, r2):
        if m == "exact":
            return True
        # from here on only 'subset' / 'super' can succeed, and only if allowed
        if not allow_extra_5exon or m not in ("super", "subset"):
            return False
        if m == "subset":
            r1, r2 = r2, r1  # rotate so r1 is always the longer one
        longer, shorter = r1.ref_exons, r2.ref_exons
        n2 = len(shorter)
        # check that (a) r1 and r2 end on same 3' exon, that is the last acceptor site agrees
        # AND (b) the 5' start of r2 is sandwiched between the matching r1 exon coordinates
        if r1.strand == "+":
            if not abs(longer[-1].start - shorter[-1].start) <= internal_fuzzy_max_dist:
                return False
            return longer[-n2].start <= shorter[0].start < longer[-n2].end
        if not abs(longer[0].end - shorter[0].end) <= internal_fuzzy_max_dist:
            return False
        # NOTE(review): upper bound uses exon n2, not n2-1 as the '+' branch
        # would suggest; behaviour kept, confirm intent.
        return longer[n2 - 1].start <= shorter[-1].end < longer[n2].end

    d = {}
    # chr --> strand --> tree
    recs = defaultdict(lambda: {"+": IntervalTree(), "-": IntervalTree()})
    fuzzy_match = defaultdict(lambda: [])
    for r in GFF.collapseGFFReader(gff_filename):
        d[r.seqid] = r
        r.segments = r.ref_exons
        strand_tree = recs[r.chr][r.strand]
        for r2 in strand_tree.find(r.start, r.end):
            r2.segments = r2.ref_exons
            m = compare_junctions(
                r,
                r2,
                internal_fuzzy_max_dist=internal_fuzzy_max_dist,
                max_5_diff=max_5_diff,
                max_3_diff=max_3_diff,
            )
            if can_merge(m, r, r2):
                fuzzy_match[r2.seqid].append(r.seqid)
                break
        else:
            # no stored record accepted r: it starts a new group
            strand_tree.insert(r.start, r.end, r)
            fuzzy_match[r.seqid] = [r.seqid]

    with open(group_filename) as f:
        group_info = {
            pbid: members.split(",")
            for pbid, members in (line.strip().split("\t") for line in f)
        }

    # groups are written in order of the integer after the first "." of the seqid
    keys = sorted(fuzzy_match.keys(), key=lambda x: int(x.split(".")[1]))

    with open(f"{gff_filename}.fuzzy",
              "w") as f_gff, open(f"{group_filename}.fuzzy", "w") as f_group:
        for k in keys:
            # pick for each fuzzy group the one that has the most exons
            # (ties broken by the larger member list; first seen wins otherwise)
            first_pbid = fuzzy_match[k][0]
            best_pbid = first_pbid
            best_size = len(group_info[first_pbid])
            best_num_exons = len(d[first_pbid].ref_exons)
            all_members = list(group_info[first_pbid])
            for pbid in fuzzy_match[k][1:]:
                _num_exons = len(d[pbid].ref_exons)
                _size = len(group_info[pbid])
                all_members.extend(group_info[pbid])
                more_exons = _num_exons > best_num_exons
                tie_but_bigger = _num_exons == best_num_exons and _size > best_size
                if more_exons or tie_but_bigger:
                    best_pbid, best_size, best_num_exons = pbid, _size, _num_exons
            GFF.write_collapseGFF_format(f_gff, d[best_pbid])
            member_field = ",".join(all_members)
            f_group.write(f"{best_pbid}\t{member_field}\n")

    return fuzzy_match