def annotate_gene_and_tss_ids(path_info_list, strand,
                              gene_id_value_obj,
                              tss_id_value_obj):
    """
    Set the tss_id and gene_id attributes on every item of path_info_list.

    The TSS coordinate of a path is the end of its last element when
    strand == NEG_STRAND and the start of its first element otherwise.
    Paths with an equal TSS coordinate receive the same tss_id.  Paths
    reported in the same ClusterTree region (built from each path's
    overall start/end) receive the same gene_id.  Fresh ids are obtained
    by calling .next() on the corresponding value object.
    """
    span_tree = ClusterTree(0, 1)
    # TSS coordinate -> id already handed out for that coordinate
    ids_by_tss = {}
    for idx, info in enumerate(path_info_list):
        first = info.path[0].start
        last = info.path[-1].end
        span_tree.insert(first, last, idx)
        if strand == NEG_STRAND:
            tss = last
        else:
            tss = first
        # only draw a new id the first time a coordinate is seen
        if tss not in ids_by_tss:
            ids_by_tss[tss] = tss_id_value_obj.next()
        info.tss_id = ids_by_tss[tss]
    # one gene id per region of clustered paths
    for _, _, members in span_tree.getregions():
        shared_gene_id = gene_id_value_obj.next()
        for idx in members:
            path_info_list[idx].gene_id = shared_gene_id
def annotate_gene_and_tss_ids(path_info_list, strand,
                              gene_id_value_obj,
                              tss_id_value_obj):
    """
    Label each path with a TSS id and a gene id.

    path_info_list    -- objects exposing a 'path' sequence whose elements
                         have 'start' and 'end'; 'tss_id' and 'gene_id'
                         are written onto each object
    strand            -- compared against NEG_STRAND to decide whether the
                         TSS is the path end (negative) or the path start
    gene_id_value_obj -- object whose .next() yields a new gene id
    tss_id_value_obj  -- object whose .next() yields a new TSS id
    """
    tree = ClusterTree(0, 1)
    known_tss = {}
    for index, path_info in enumerate(path_info_list):
        left = path_info.path[0].start
        right = path_info.path[-1].end
        # register the whole transcript span for gene clustering
        tree.insert(left, right, index)
        position = right if strand == NEG_STRAND else left
        try:
            assigned = known_tss[position]
        except KeyError:
            # first path seen at this TSS position
            assigned = tss_id_value_obj.next()
            known_tss[position] = assigned
        path_info.tss_id = assigned
    # paths falling in the same region become one gene
    for region in tree.getregions():
        region_members = region[2]
        new_gene_id = gene_id_value_obj.next()
        for index in region_members:
            path_info_list[index].gene_id = new_gene_id
Example #3
0
def cluster_gtf_features(gtf_files, source=None):
    """
    Merge transcripts from several GTF files into genes.

    Transcripts are binned by chromosome and strand.  Within a bin,
    transcripts whose exons fall in a common ClusterTree region are put in
    the same gene, and genes that share a transcript are merged.  Each
    gene gets a new gene_id of the form '<source><11-digit counter>'; the
    previous value is kept under the 'orig_gene_id' attribute.

    gtf_files -- iterable of GTF file names handed to Feature.parse_gtf
    source    -- gene id prefix, also forwarded to to_gtf_features;
                 'merge' is used when None

    Yields the str() of each GTF feature, chromosomes in sorted order and
    transcripts ordered by their 'start' attribute.
    """
    # chrom -> strand -> list of transcripts
    transcripts_by_chrom = collections.defaultdict(
        lambda: collections.defaultdict(list))
    for gtf_file in gtf_files:
        logging.debug('Parsing gtf file: %s' % (gtf_file))
        for feature in Feature.parse_gtf(gtf_file):
            transcripts_by_chrom[feature.chrom][feature.strand].append(
                feature)
    if source is None:
        source = 'merge'
    logging.debug('Clustering transcripts into genes')
    gene_number = 1
    for transcripts_by_strand in transcripts_by_chrom.itervalues():
        for transcripts in transcripts_by_strand.itervalues():
            # every transcript starts out as a single-member gene
            members_of = {}
            exon_tree = ClusterTree(0, 1)
            for idx, transcript in enumerate(transcripts):
                members_of[idx] = set((idx, ))
                for exon_start, exon_end in transcript.exons:
                    exon_tree.insert(exon_start, exon_end, idx)
            # fuse the genes of all transcripts that share an exon region;
            # all members point at the same set object afterwards
            for _, _, region_indexes in exon_tree.getregions():
                fused = set()
                for idx in region_indexes:
                    fused.update(members_of[idx])
                for idx in fused:
                    members_of[idx] = fused
            del exon_tree
            # collapse to the distinct genes
            distinct_genes = set(
                frozenset(members) for members in members_of.itervalues())
            del members_of
            # hand out one new gene id per gene
            for gene_members in distinct_genes:
                gene_id = '%s%011d' % (source, gene_number)
                for idx in gene_members:
                    transcript = transcripts[idx]
                    transcript.attrs['orig_gene_id'] = \
                        transcript.attrs['gene_id']
                    transcript.attrs['gene_id'] = gene_id
                gene_number += 1
    # emit transcripts chromosome by chromosome
    logging.debug('Writing transcripts')
    by_start = operator.attrgetter('start')
    for chrom in sorted(transcripts_by_chrom):
        chrom_transcripts = []
        for transcripts in transcripts_by_chrom[chrom].itervalues():
            chrom_transcripts.extend(transcripts)
        chrom_transcripts.sort(key=by_start)
        for transcript in chrom_transcripts:
            for gtf_feature in transcript.to_gtf_features(source=source):
                yield str(gtf_feature)
Example #4
0
def cluster_gtf_features(gtf_files, source=None):
    """
    Read transcripts from the given GTF files, regroup them into genes
    and yield the resulting GTF records as strings.

    Grouping happens independently for every (chromosome, strand) pair:
    transcripts reported together in a ClusterTree region of exons end up
    in one gene, and overlapping groups are unioned.  The old 'gene_id'
    attribute is copied to 'orig_gene_id' before being replaced with
    '<source>' followed by a zero-padded 11-digit running number.
    When source is None the string 'merge' is used.
    """
    features_by_chrom = collections.defaultdict(
        lambda: collections.defaultdict(list))
    for gtf_file in gtf_files:
        logging.debug('Parsing gtf file: %s' % (gtf_file))
        for parsed in Feature.parse_gtf(gtf_file):
            # bin by chromosome, then strand
            features_by_chrom[parsed.chrom][parsed.strand].append(parsed)
    if source is None:
        source = 'merge'
    logging.debug('Clustering transcripts into genes')
    next_id_number = 1
    for features_by_strand in features_by_chrom.itervalues():
        for bin_features in features_by_strand.itervalues():
            groups = {}
            tree = ClusterTree(0, 1)
            for pos, feat in enumerate(bin_features):
                # singleton group until merged below
                groups[pos] = set((pos,))
                for exon in feat.exons:
                    tree.insert(exon[0], exon[1], pos)
            for region in tree.getregions():
                # union of every group touched by this region
                union = set()
                for pos in region[2]:
                    union |= groups[pos]
                for pos in union:
                    groups[pos] = union
            del tree
            # several keys share one set object; keep each group once
            unique_groups = set()
            for group in groups.itervalues():
                unique_groups.add(frozenset(group))
            del groups
            for group in unique_groups:
                replacement_id = '%s%011d' % (source, next_id_number)
                next_id_number += 1
                for pos in group:
                    feat = bin_features[pos]
                    feat.attrs['orig_gene_id'] = feat.attrs['gene_id']
                    feat.attrs['gene_id'] = replacement_id
    logging.debug('Writing transcripts')
    for chrom in sorted(features_by_chrom):
        merged = [feat
                  for bin_features in features_by_chrom[chrom].itervalues()
                  for feat in bin_features]
        merged.sort(key=operator.attrgetter('start'))
        for feat in merged:
            for record in feat.to_gtf_features(source=source):
                yield str(record)
def read_reference_gtf(ref_gtf_file):
    """
    Build a per-chromosome index of the genes in a reference GTF file.

    Features are grouped into Gene objects by their 'gene_id' attribute.
    Each gene's start/end is widened to cover all of its features, 'exon'
    features contribute exon intervals, a 'CDS' feature marks the gene as
    coding, and gene names and annotation sources are accumulated.

    Genes are then clustered into loci per chromosome.  Each gene's exons
    are replaced by the (start, end) regions reported by a ClusterTree
    built from them.

    Returns a defaultdict mapping chromosome -> IntervalTree.  Every
    interval in that tree spans one locus and carries as its value an
    IntervalTree of the exon intervals of the genes in the locus; each
    exon interval's value is the owning Gene.
    """
    gene_map = {}
    # the context manager closes the GTF file once parsing is finished
    with open(ref_gtf_file) as gtf_fileh:
        for f in GTFFeature.parse(gtf_fileh):
            # get gene by id
            gene_id = f.attrs["gene_id"]
            if gene_id not in gene_map:
                g = Gene()
                g.gene_id = gene_id
                g.chrom = f.seqid
                g.strand = f.strand
                g.gene_start = f.start
                g.gene_end = f.end
                gene_map[gene_id] = g
            else:
                g = gene_map[gene_id]
            # update gene
            g.gene_start = min(g.gene_start, f.start)
            g.gene_end = max(g.gene_end, f.end)
            if f.feature_type == "exon":
                g.exons.add((f.start, f.end))
            elif f.feature_type == "CDS":
                g.is_coding = True
            if "gene_name" in f.attrs:
                g.gene_names.add(f.attrs["gene_name"])
            g.annotation_sources.add(f.source)
    logging.info("Sorting genes")
    genes = sorted(gene_map.values(),
                   key=operator.attrgetter('chrom', 'gene_start'))
    del gene_map
    # cluster loci
    logging.debug("Building interval index")
    locus_cluster_trees = collections.defaultdict(lambda: ClusterTree(0, 1))
    for i, g in enumerate(genes):
        locus_cluster_trees[g.chrom].insert(g.gene_start, g.gene_end, i)
    locus_trees = collections.defaultdict(lambda: IntervalTree())
    for chrom, locus_cluster_tree in locus_cluster_trees.items():
        regions = locus_cluster_tree.getregions()
        for locus_start, locus_end, gene_indexes in regions:
            # cluster gene exons and add to interval tree
            exon_tree = IntervalTree()
            for i in gene_indexes:
                g = genes[i]
                # separate name so the locus-level tree and its index list
                # are not rebound while they are still being iterated
                exon_cluster_tree = ClusterTree(0, 1)
                for start, end in g.exons:
                    exon_cluster_tree.insert(start, end, 1)
                # replace the gene's exons with the clustered regions
                g.exons = [(exon_start, exon_end)
                           for exon_start, exon_end, _
                           in exon_cluster_tree.getregions()]
                for start, end in g.exons:
                    exon_tree.insert_interval(Interval(start, end, value=g))
            # add to locus interval tree
            locus_trees[chrom].insert_interval(
                Interval(locus_start, locus_end, value=exon_tree))
    logging.debug("Done indexing reference GTF file")
    return locus_trees
Example #6
0
def cluster_isoforms(transcripts):
    """
    Return the (start, end) regions reported by a ClusterTree into which
    every exon of every transcript has been inserted.
    """
    exon_tree = ClusterTree(0, 1)
    for transcript in transcripts:
        for exon in transcript.exons:
            exon_tree.insert(exon.start, exon.end, 1)
    # keep only the coordinates; the member ids are all the constant 1
    return [(region_start, region_end)
            for region_start, region_end, _ in exon_tree.getregions()]
def main():
    """
    Write locus, intron and intergenic BED files for a GTF annotation.

    Command line: gtf_file chrom_sizes output_prefix

    Produces three files:
      <output_prefix>.locus.bed       one line per locus yielded by
                                      parse_gtf (chrom, start, end)
      <output_prefix>.intron.bed      gaps between consecutive clustered
                                      exon regions of each locus, sorted
      <output_prefix>.intergenic.bed  stdout of 'bedtools complement' run
                                      on the locus file and chrom_sizes
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    parser = argparse.ArgumentParser()
    parser.add_argument('gtf_file')
    parser.add_argument('chrom_sizes')
    parser.add_argument("output_prefix")
    args = parser.parse_args()
    # read one locus at a time
    locus_file = args.output_prefix + '.locus.bed'
    intergenic_file = args.output_prefix + '.intergenic.bed'
    intron_file = args.output_prefix + '.intron.bed'
    introns = set()
    logging.info('Parsing transcripts by locus')
    # both files are closed even if parsing raises part-way through
    with open(locus_file, 'w') as locus_fileh, \
            open(args.gtf_file) as gtf_fileh:
        for locus_transcripts in parse_gtf(gtf_fileh):
            # find borders of locus
            locus_chrom = locus_transcripts[0].chrom
            locus_start = min(t.start for t in locus_transcripts)
            locus_end = max(t.end for t in locus_transcripts)
            locus_fileh.write('\t'.join(
                [locus_chrom, str(locus_start), str(locus_end)]) + '\n')
            logging.debug(
                "[LOCUS] %s:%d-%d %d transcripts" %
                (locus_chrom, locus_start, locus_end,
                 len(locus_transcripts)))
            # cluster locus exons
            cluster_tree = ClusterTree(0, 1)
            for t in locus_transcripts:
                for e in t.exons:
                    cluster_tree.insert(e.start, e.end, 1)
            exon_clusters = [(start, end) for start, end, _
                             in cluster_tree.getregions()]
            # introns are the gaps between neighbouring exon clusters;
            # pairing the list with itself shifted by one yields nothing
            # for zero or one cluster, so a locus without exons is
            # skipped instead of raising IndexError
            for left, right in zip(exon_clusters, exon_clusters[1:]):
                introns.add((locus_chrom, left[1], right[0]))
    # write introns
    logging.info('Writing introns')
    with open(intron_file, 'w') as intron_fileh:
        for chrom, start, end in sorted(introns):
            intron_fileh.write(
                '\t'.join([chrom, str(start), str(end)]) + '\n')
    # take complement of locus file to find intergenic regions
    logging.info('Complementing locus intervals to find intergenic regions')
    # distinct name: 'args' keeps referring to the parsed namespace
    cmd = ['bedtools', 'complement', '-i', locus_file,
           '-g', args.chrom_sizes]
    with open(intergenic_file, 'w') as f:
        retcode = subprocess.call(cmd, stdout=f)
    if retcode != 0:
        # surface the failure; the intergenic file may be incomplete
        logging.error('bedtools complement exited with status %d', retcode)
def main():
    """
    Derive locus, intron and intergenic intervals from a GTF file.

    Positional command-line arguments: gtf_file, chrom_sizes,
    output_prefix.  Output goes to '<output_prefix>.locus.bed',
    '<output_prefix>.intron.bed' and '<output_prefix>.intergenic.bed';
    the last one is the stdout of 'bedtools complement' applied to the
    locus file with chrom_sizes as the genome file.
    """
    logging.basicConfig(level=logging.DEBUG,
                        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    parser = argparse.ArgumentParser()
    parser.add_argument('gtf_file')
    parser.add_argument('chrom_sizes')
    parser.add_argument("output_prefix")
    args = parser.parse_args()
    # read one locus at a time
    locus_file = args.output_prefix + '.locus.bed'
    intergenic_file = args.output_prefix + '.intergenic.bed'
    intron_file = args.output_prefix + '.intron.bed'
    introns = set()
    logging.info('Parsing transcripts by locus')
    # context managers guarantee the handles are released on any error
    with open(locus_file, 'w') as locus_fileh:
        with open(args.gtf_file) as gtf_fileh:
            for locus_transcripts in parse_gtf(gtf_fileh):
                # find borders of locus
                locus_chrom = locus_transcripts[0].chrom
                locus_start = min(t.start for t in locus_transcripts)
                locus_end = max(t.end for t in locus_transcripts)
                fields = [locus_chrom, str(locus_start), str(locus_end)]
                locus_fileh.write('\t'.join(fields) + '\n')
                logging.debug("[LOCUS] %s:%d-%d %d transcripts" %
                              (locus_chrom, locus_start, locus_end,
                               len(locus_transcripts)))
                # cluster locus exons
                cluster_tree = ClusterTree(0,1)
                for t in locus_transcripts:
                    for e in t.exons:
                        cluster_tree.insert(e.start, e.end, 1)
                exon_clusters = []
                for start, end, indexes in cluster_tree.getregions():
                    exon_clusters.append((start, end))
                # get intronic regions: the gap between each pair of
                # neighbouring exon clusters.  Iterating over pairs means
                # a locus with no exon clusters contributes nothing
                # rather than failing on exon_clusters[0].
                for e1, e2 in zip(exon_clusters, exon_clusters[1:]):
                    introns.add((locus_chrom, e1[1], e2[0]))
    # write introns
    logging.info('Writing introns')
    with open(intron_file, 'w') as intron_fileh:
        for chrom, start, end in sorted(introns):
            intron_fileh.write('\t'.join([chrom, str(start), str(end)]) + '\n')
    # take complement of locus file to find intergenic regions
    logging.info('Complementing locus intervals to find intergenic regions')
    # do not reuse the name 'args' for the command line
    bedtools_cmd = ['bedtools', 'complement',
                    '-i', locus_file,
                    '-g', args.chrom_sizes]
    with open(intergenic_file, 'w') as f:
        status = subprocess.call(bedtools_cmd, stdout=f)
    if status != 0:
        # report the failure instead of silently leaving a bad file
        logging.error('bedtools complement exited with status %d', status)
Example #9
0
def get_gtf_metadata(gtf_file,
                     omit_attrs=None,
                     group_by="gene_id",
                     feature_type="exon"):
    """
    Summarise the features of a GTF file, one metadata row per group.

    gtf_file     -- path of the GTF file to read
    omit_attrs   -- attribute names to leave out of the output columns
    group_by     -- attribute whose value identifies a group
    feature_type -- only features of this type are considered

    Returns (metadata_fields, metadata_dict).  metadata_fields is the list
    of column names: 'tracking_id', 'locus', 'strand', 'num_exons',
    'transcript_length', then the sorted attribute names that were seen
    and not omitted.  metadata_dict maps each group id to its row (a list
    aligned with metadata_fields).  Attribute columns hold the sorted,
    comma-joined distinct values seen in the group, or 'NA' when the group
    has none.  The dict itself is not ordered.
    """
    if omit_attrs is None:
        omit_attrs = []
    # read gtf file and group features
    gene_feature_map = collections.defaultdict(list)
    gene_attrs_set = set()
    # the context manager closes the GTF file once parsing is finished
    with open(gtf_file) as gtf_fileh:
        for feature in GTFFeature.parse(gtf_fileh):
            if feature.feature_type != feature_type:
                continue
            gene_feature_map[feature.attrs[group_by]].append(feature)
            gene_attrs_set.update(feature.attrs.keys())
    gene_attrs_set.difference_update(omit_attrs)
    gene_attrs_list = sorted(gene_attrs_set)
    metadata_fields = ["tracking_id", "locus", "strand", "num_exons",
                       "transcript_length"] + gene_attrs_list
    # column name -> position in a row
    metadata_inds = dict((x, i) for i, x in enumerate(metadata_fields))
    metadata_dict = {}
    for feature_id, features in gene_feature_map.items():
        # collect attribute values for this group
        attrdict = collections.defaultdict(set)
        # cluster the group's features together
        cluster_tree = ClusterTree(0, 1)
        for i, f in enumerate(features):
            cluster_tree.insert(f.start, f.end, i)
            for k, v in f.attrs.items():
                if k in gene_attrs_set:
                    # some attributes have multiple values separated by a
                    # comma
                    attrdict[k].update(v.split(','))
        # the clustered regions are the group's exon clusters
        exon_clusters = [(start, end)
                         for start, end, _ in cluster_tree.getregions()]
        del cluster_tree
        transcript_length = sum(end - start for start, end in exon_clusters)
        chrom = features[0].seqid
        locus_start = min(e[0] for e in exon_clusters)
        locus_end = max(e[1] for e in exon_clusters)
        locus_string = "%s:%d-%d" % (chrom, locus_start, locus_end)
        strand = features[0].strand
        num_exons = len(exon_clusters)
        # make metadata row; attribute columns default to 'NA'
        metadata = ([feature_id, locus_string, strand, num_exons,
                     transcript_length] + ['NA'] * len(gene_attrs_list))
        for k, vals in attrdict.items():
            metadata[metadata_inds[k]] = ','.join(map(str, sorted(vals)))
        metadata_dict[feature_id] = metadata
    return metadata_fields, metadata_dict
Example #10
0
def cluster_tss(tss_dict, chrom_sizes, upstream=0, downstream=0):
    """
    Cluster TSS windows per chromosome and yield one tuple per TSS.

    tss_dict    -- maps tss_id -> (chrom, strand, tstart, tend)
    chrom_sizes -- maps chrom -> size, used to cap negative-strand windows
    upstream / downstream -- window extension on either side of the TSS

    For strand == NEG_STRAND the window is anchored at tend - 1 and
    limited to [tstart, chrom_sizes[chrom]].  Otherwise it is anchored at
    tstart and limited to [0, tend].  Windows are clustered with one
    ClusterTree per chromosome.

    Yields (tss_id, p_id, chrom, start, end), where p_id is 'P' plus a
    7-digit counter shared by all TSSs in one cluster region and
    start/end are that region's bounds.
    """
    chrom_trees = collections.defaultdict(lambda: ClusterTree(0,1))
    ordered_ids = sorted(tss_dict)
    for idx, current_id in enumerate(ordered_ids):
        chrom, strand, tstart, tend = tss_dict[current_id]
        if strand == NEG_STRAND:
            window_start = max(tstart, tend - 1 - downstream)
            window_end = min(tend - 1 + upstream, chrom_sizes[chrom])
        else:
            window_start = max(0, tstart - upstream)
            window_end = min(tend, tstart + downstream)
        chrom_trees[chrom].insert(window_start, window_end, idx)
    # walk the clusters in chromosome order, numbering them consecutively
    cluster_number = 1
    for chrom in sorted(chrom_trees):
        for start, end, members in chrom_trees[chrom].getregions():
            p_id = 'P%07d' % (cluster_number)
            for idx in members:
                yield ordered_ids[idx], p_id, chrom, start, end
            cluster_number += 1
def read_reference_gtf(ref_gtf_file):
    """
    Parse a reference GTF file into an interval index of genes by locus.

    A Gene object is created per distinct 'gene_id'.  Its span covers all
    of its features, 'exon' features add exon intervals, a 'CDS' feature
    sets is_coding, and gene names and sources are collected.  Genes are
    clustered into loci per chromosome, and each gene's exons are replaced
    by the regions a ClusterTree reports for them.

    Returns a defaultdict: chromosome -> IntervalTree of locus intervals.
    The value of a locus interval is an IntervalTree of exon intervals,
    and the value of each exon interval is the Gene it belongs to.
    """
    gene_map = {}
    # open the file in a 'with' block so the handle is always closed
    with open(ref_gtf_file) as gtf_fileh:
        for f in GTFFeature.parse(gtf_fileh):
            # get gene by id
            gene_id = f.attrs["gene_id"]
            if gene_id not in gene_map:
                g = Gene()
                g.gene_id = gene_id
                g.chrom = f.seqid
                g.strand = f.strand
                g.gene_start = f.start
                g.gene_end = f.end
                gene_map[gene_id] = g
            else:
                g = gene_map[gene_id]
            # update gene
            g.gene_start = min(g.gene_start, f.start)
            g.gene_end = max(g.gene_end, f.end)
            if f.feature_type == "exon":
                g.exons.add((f.start, f.end))
            elif f.feature_type == "CDS":
                g.is_coding = True
            if "gene_name" in f.attrs:
                g.gene_names.add(f.attrs["gene_name"])
            g.annotation_sources.add(f.source)
    logging.info("Sorting genes")
    genes = sorted(gene_map.values(),
                   key=operator.attrgetter('chrom', 'gene_start'))
    del gene_map
    # cluster loci
    logging.debug("Building interval index")
    locus_cluster_trees = collections.defaultdict(lambda: ClusterTree(0, 1))
    locus_trees = collections.defaultdict(lambda: IntervalTree())
    for i, g in enumerate(genes):
        locus_cluster_trees[g.chrom].insert(g.gene_start, g.gene_end, i)
    for chrom, chrom_cluster_tree in locus_cluster_trees.items():
        for locus in chrom_cluster_tree.getregions():
            locus_start, locus_end, locus_gene_indexes = locus
            # cluster gene exons and add to interval tree
            exon_tree = IntervalTree()
            for i in locus_gene_indexes:
                g = genes[i]
                # the per-gene tree has its own name; the original reused
                # 'cluster_tree' and 'indexes' and thereby clobbered the
                # variables of the enclosing loops
                gene_exon_tree = ClusterTree(0, 1)
                for start, end in g.exons:
                    gene_exon_tree.insert(start, end, 1)
                # update exons with the clustered regions
                exon_clusters = []
                for start, end, _ in gene_exon_tree.getregions():
                    exon_clusters.append((start, end))
                g.exons = exon_clusters
                for start, end in exon_clusters:
                    exon_tree.insert_interval(Interval(start, end, value=g))
            # add to locus interval tree
            locus_trees[chrom].insert_interval(
                Interval(locus_start, locus_end, value=exon_tree))
    logging.debug("Done indexing reference GTF file")
    return locus_trees
Example #12
0
def get_locus_genes(features):
    """
    Group features into Gene objects and return them sorted by gene_start.

    Features are grouped by their 'gene_id' attribute.  For each gene the
    span is the min start / max end over its features, 'exon' features
    are collected as (start, end) pairs, a 'CDS' feature sets is_coding,
    and gene names and feature sources are accumulated.  Finally each
    gene's exons are replaced by the list of (start, end) regions that a
    ClusterTree reports for them.
    """
    genes_by_id = {}
    for feature in features:
        gene_id = feature.attrs["gene_id"]
        if gene_id in genes_by_id:
            gene = genes_by_id[gene_id]
        else:
            # first feature of this gene: seed it from the feature
            gene = Gene()
            gene.gene_id = gene_id
            gene.strand = feature.strand
            gene.gene_start = feature.start
            gene.gene_end = feature.end
            genes_by_id[gene_id] = gene
        # widen the gene span to include this feature
        gene.gene_start = min(gene.gene_start, feature.start)
        gene.gene_end = max(gene.gene_end, feature.end)
        kind = feature.feature_type
        if kind == "exon":
            gene.exons.add((feature.start, feature.end))
        elif kind == "CDS":
            gene.is_coding = True
        if "gene_name" in feature.attrs:
            gene.gene_names.add(feature.attrs["gene_name"])
        gene.annotation_sources.add(feature.source)
    genes = list(genes_by_id.values())
    for gene in genes:
        # replace the raw exons with their clustered regions
        exon_tree = ClusterTree(0,1)
        for exon_start, exon_end in gene.exons:
            exon_tree.insert(exon_start, exon_end, 1)
        gene.exons = [(region_start, region_end)
                      for region_start, region_end, _
                      in exon_tree.getregions()]
    genes.sort(key=operator.attrgetter('gene_start'))
    return genes
Example #13
0
def build_locus_trees(gtf_file):
    """
    Index the reference transcripts of a GTF file by locus.

    Only transcripts whose GTFAttr.REF attribute converts to a non-zero
    integer are kept.  Their spans are clustered per chromosome, and each
    resulting region becomes one interval in that chromosome's tree.

    Returns a defaultdict mapping chromosome -> IntervalTree.  The value
    of each interval is the list of transcripts in that locus.
    """
    transcripts = []
    locus_cluster_trees = collections.defaultdict(lambda: ClusterTree(0,1))
    # the context manager closes the GTF file once parsing is finished
    with open(gtf_file) as gtf_fileh:
        for locus_transcripts in parse_gtf(gtf_fileh):
            for t in locus_transcripts:
                is_ref = bool(int(t.attrs[GTFAttr.REF]))
                if not is_ref:
                    continue
                # the transcript's position in 'transcripts' is its id in
                # the cluster tree
                i = len(transcripts)
                transcripts.append(t)
                locus_cluster_trees[t.chrom].insert(t.start, t.end, i)
    # build interval trees of loci
    locus_trees = collections.defaultdict(lambda: IntervalTree())
    for chrom, cluster_tree in locus_cluster_trees.items():
        for locus_start, locus_end, indexes in cluster_tree.getregions():
            # insert each locus exactly once; the original wrapped this in
            # an extra 'for i in indexes' loop and so added one identical
            # interval per transcript in the locus
            locus_transcripts = [transcripts[i] for i in indexes]
            locus_trees[chrom].insert_interval(
                Interval(locus_start, locus_end, value=locus_transcripts))
    return locus_trees
Example #14
0
def main():
    """
    Print the total exonic footprint of a BED file.

    Command line: bed_file

    The exons of every feature are inserted into one ClusterTree per
    chromosome.  The footprint is the sum of (end - start) over all
    regions the trees report, printed as 'total <footprint>'.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    parser = argparse.ArgumentParser()
    parser.add_argument('bed_file')
    args = parser.parse_args()

    trees = collections.defaultdict(lambda: ClusterTree(0, 1))
    # the context manager closes the BED file once parsing is finished
    with open(args.bed_file) as bed_fileh:
        for f in BEDFeature.parse(bed_fileh):
            tree = trees[f.chrom]
            for start, end in f.exons:
                tree.insert(start, end, 1)
    footprint = 0
    for chrom in sorted(trees):
        # length covered by the clustered regions of this chromosome
        chromprint = sum(end - start
                         for start, end, _ in trees[chrom].getregions())
        footprint += chromprint
    # a single pre-formatted string prints identically on Python 2 and 3
    print('total %s' % (footprint,))
def partition_transcripts_by_strand(transcripts):
    """
    uses information from stranded transcripts to infer strand for 
    unstranded transcripts

    Reference transcripts (GTFAttr.REF attribute converting to a non-zero
    int) only label nodes by strand and are collected separately.  Other
    transcripts with a strand other than NO_STRAND add their score to the
    nodes they cover.  Transcripts with NO_STRAND are resolved in two
    passes: first individually with resolve_strand(), then by clustering
    the remaining unresolved nodes and picking the strand with the most
    overlapping bases.  The 'strand' attribute of unstranded transcripts
    is modified in place.

    Returns (strand_transcript_maps, strand_ref_transcripts):
    strand_transcript_maps is a list of three dicts indexed by strand,
    each mapping transcript id -> transcript; strand_ref_transcripts is a
    list of two lists of reference transcripts indexed by strand.
    """
    def add_transcript(t, nodes_iter, transcript_maps, node_data):
        # add the transcript's score to each node it covers, under its
        # current strand, then file the transcript under that strand
        for n in nodes_iter:
            node_data[n]['scores'][t.strand] += t.score
        t_id = t.attrs[GTFAttr.TRANSCRIPT_ID]         
        transcript_maps[t.strand][t_id] = t
    # divide transcripts into independent regions of
    # transcription with a single entry and exit point    
    boundaries = find_exon_boundaries(transcripts)
    # per-node record, both lists indexed by strand value
    # NOTE(review): presumably POS_STRAND/NEG_STRAND/NO_STRAND are 0/1/2,
    # matching the 2- and 3-element lists -- confirm against definitions
    node_data_func = lambda: {'ref_strands': [False, False],
                              'scores': [0.0, 0.0, 0.0]}
    node_data = collections.defaultdict(node_data_func)
    strand_transcript_maps = [{}, {}, {}]
    strand_ref_transcripts = [[], []]
    unresolved_transcripts = []
    for t in transcripts:
        # a missing REF attribute is treated as "0" (not a reference)
        is_ref = bool(int(t.attrs.get(GTFAttr.REF, "0")))
        if is_ref:
            # label nodes by ref strand
            for n in split_exons(t,boundaries):
                node_data[n]['ref_strands'][t.strand] = True
            strand_ref_transcripts[t.strand].append(t)
        elif t.strand != NO_STRAND:
            add_transcript(t, split_exons(t, boundaries), 
                           strand_transcript_maps, node_data)
        else:
            unresolved_transcripts.append(t)
    # resolve unstranded transcripts
    logging.debug("\t\t%d unstranded transcripts" % 
                  (len(unresolved_transcripts)))
    # keep track of remaining unresolved nodes
    unresolved_nodes = set()
    if len(unresolved_transcripts) > 0:
        resolved = []
        still_unresolved_transcripts = []
        for t in unresolved_transcripts:
            nodes = list(split_exons(t,boundaries))            
            t.strand = resolve_strand(nodes, node_data)
            if t.strand != NO_STRAND:
                resolved.append(t)
            else:
                unresolved_nodes.update(nodes)
                still_unresolved_transcripts.append(t)
        # scores of newly resolved transcripts are added only after every
        # unstranded transcript has been examined, so resolutions within
        # this pass do not influence one another
        for t in resolved:
            add_transcript(t, split_exons(t, boundaries), 
                           strand_transcript_maps, node_data)
        unresolved_transcripts = still_unresolved_transcripts
    if len(unresolved_transcripts) > 0:
        logging.debug("\t\t%d unresolved transcripts" % 
                      (len(unresolved_transcripts)))
        # if there are still unresolved transcripts then we can try to
        # extrapolate and assign strand to clusters of nodes at once, as
        # long as some of the nodes have a strand assigned
        # cluster unresolved nodes
        unresolved_nodes = sorted(unresolved_nodes)
        cluster_tree = ClusterTree(0,1)
        for i,n in enumerate(unresolved_nodes):
            cluster_tree.insert(n[0], n[1], i)
        # try to assign strand to clusters of nodes
        node_strand_map = {}
        for start, end, indexes in cluster_tree.getregions():
            nodes = [unresolved_nodes[i] for i in indexes]
            strand = resolve_strand(nodes, node_data)
            for n in nodes:
                node_strand_map[n] = strand
        # for each transcript assign strand to the cluster with 
        # the best overlap
        unresolved_count = 0
        for t in unresolved_transcripts:
            # bases of this transcript lying in nodes resolved to each of
            # the two strands
            strand_bp = [0,0]
            nodes = list(split_exons(t, boundaries))        
            for n in nodes:
                strand = node_strand_map[n]
                if strand != NO_STRAND:
                    strand_bp[strand] += (n[1] - n[0])
            total_strand_bp = sum(strand_bp)
            if total_strand_bp > 0:
                # ties go to the positive strand
                if strand_bp[POS_STRAND] >= strand_bp[NEG_STRAND]:
                    t.strand = POS_STRAND
                else:
                    t.strand = NEG_STRAND
            else:
                unresolved_count += 1
            # transcripts that could not be resolved are still added,
            # under the NO_STRAND value left in t.strand
            add_transcript(t, nodes, strand_transcript_maps, node_data)
        logging.debug("\t\tCould not resolve %d transcripts" % 
                      (unresolved_count))
        del cluster_tree    
    return strand_transcript_maps, strand_ref_transcripts
def partition_transcripts_by_strand(transcripts):
    """
    uses information from stranded transcripts to infer strand for 
    unstranded transcripts

    Reference transcripts (GTFAttr.REF attribute converting to a non-zero
    int) only label nodes by strand and are collected separately.  Other
    transcripts with a strand other than NO_STRAND add their score to the
    nodes they cover.  Transcripts with NO_STRAND are resolved in two
    passes: first individually with resolve_strand(), then by clustering
    the remaining unresolved nodes and picking the strand with the most
    overlapping bases.  The 'strand' attribute of unstranded transcripts
    is modified in place.

    Returns (strand_transcript_lists, strand_ref_transcripts):
    strand_transcript_lists is a list of three lists of transcripts
    indexed by strand; strand_ref_transcripts is a list of two lists of
    reference transcripts indexed by strand.
    """
    def add_transcript(t, nodes_iter, transcript_lists, node_data):
        # add the transcript's score to each node it covers, under its
        # current strand, then append the transcript to that strand's list
        for n in nodes_iter:
            node_data[n]['scores'][t.strand] += t.score
        # NOTE(review): t_id is assigned but never used in this version
        t_id = t.attrs[GTFAttr.TRANSCRIPT_ID]
        transcript_lists[t.strand].append(t)

    # divide transcripts into independent regions of
    # transcription with a single entry and exit point
    boundaries = find_exon_boundaries(transcripts)
    # per-node record, both lists indexed by strand value
    # NOTE(review): presumably POS_STRAND/NEG_STRAND/NO_STRAND are 0/1/2,
    # matching the 2- and 3-element lists -- confirm against definitions
    node_data_func = lambda: {
        'ref_strands': [False, False],
        'scores': [0.0, 0.0, 0.0]
    }
    node_data = collections.defaultdict(node_data_func)
    strand_transcript_lists = [[], [], []]
    strand_ref_transcripts = [[], []]
    unresolved_transcripts = []
    for t in transcripts:
        # a missing REF attribute is treated as "0" (not a reference)
        is_ref = bool(int(t.attrs.get(GTFAttr.REF, "0")))
        if is_ref:
            # label nodes by ref strand
            for n in split_exons(t, boundaries):
                node_data[n]['ref_strands'][t.strand] = True
            strand_ref_transcripts[t.strand].append(t)
        elif t.strand != NO_STRAND:
            add_transcript(t, split_exons(t, boundaries),
                           strand_transcript_lists, node_data)
        else:
            unresolved_transcripts.append(t)
    # resolve unstranded transcripts
    logging.debug("\t\t%d unstranded transcripts" %
                  (len(unresolved_transcripts)))
    # keep track of remaining unresolved nodes
    unresolved_nodes = set()
    if len(unresolved_transcripts) > 0:
        resolved = []
        still_unresolved_transcripts = []
        for t in unresolved_transcripts:
            nodes = list(split_exons(t, boundaries))
            t.strand = resolve_strand(nodes, node_data)
            if t.strand != NO_STRAND:
                resolved.append(t)
            else:
                unresolved_nodes.update(nodes)
                still_unresolved_transcripts.append(t)
        # scores of newly resolved transcripts are added only after every
        # unstranded transcript has been examined, so resolutions within
        # this pass do not influence one another
        for t in resolved:
            add_transcript(t, split_exons(t, boundaries),
                           strand_transcript_lists, node_data)
        unresolved_transcripts = still_unresolved_transcripts
    if len(unresolved_transcripts) > 0:
        logging.debug("\t\t%d unresolved transcripts" %
                      (len(unresolved_transcripts)))
        # if there are still unresolved transcripts then we can try to
        # extrapolate and assign strand to clusters of nodes at once, as
        # long as some of the nodes have a strand assigned
        # cluster unresolved nodes
        unresolved_nodes = sorted(unresolved_nodes)
        cluster_tree = ClusterTree(0, 1)
        for i, n in enumerate(unresolved_nodes):
            cluster_tree.insert(n[0], n[1], i)
        # try to assign strand to clusters of nodes
        node_strand_map = {}
        for start, end, indexes in cluster_tree.getregions():
            nodes = [unresolved_nodes[i] for i in indexes]
            strand = resolve_strand(nodes, node_data)
            for n in nodes:
                node_strand_map[n] = strand
        # for each transcript assign strand to the cluster with
        # the best overlap
        unresolved_count = 0
        for t in unresolved_transcripts:
            # bases of this transcript lying in nodes resolved to each of
            # the two strands
            strand_bp = [0, 0]
            nodes = list(split_exons(t, boundaries))
            for n in nodes:
                strand = node_strand_map[n]
                if strand != NO_STRAND:
                    strand_bp[strand] += (n[1] - n[0])
            total_strand_bp = sum(strand_bp)
            if total_strand_bp > 0:
                # ties go to the positive strand
                if strand_bp[POS_STRAND] >= strand_bp[NEG_STRAND]:
                    t.strand = POS_STRAND
                else:
                    t.strand = NEG_STRAND
            else:
                unresolved_count += 1
            # transcripts that could not be resolved are still added,
            # under the NO_STRAND value left in t.strand
            add_transcript(t, nodes, strand_transcript_lists, node_data)
        logging.debug("\t\tCould not resolve %d transcripts" %
                      (unresolved_count))
        del cluster_tree
    return strand_transcript_lists, strand_ref_transcripts