Example #1
0
def get_coords(axes='gene',
               rows=None,
               time_val=None,
               spatial_idxs=None,
               ids=None):
    """Return three coordinate arrays (x, y, z) for the entries picked by ids.

    Parameters
      axes         -- 'gene' or 'space'; selects how coordinates are derived.
      rows         -- ('space' only) sequence of indexable records.
      time_val     -- integer time slot.  In 'gene' mode it selects the
                      column of each gene's 'vals' and str(time_val + 1)
                      must appear in that gene's 'steps'.  In 'space' mode
                      it selects the entry of each spatial index list.
      spatial_idxs -- ('space' only) three sequences of indices into a row;
                      columns 0, 1 and 2 of the result come from them in
                      order.
      ids          -- row selector applied to the coordinate matrix.

    Returns
      (xs, ys, zs): columns 0, 1 and 2 of the coordinate matrix at rows ids.

    Raises
      ValueError if axes is neither 'gene' nor 'space'.
    """
    # Reject unknown modes up front; otherwise the function would only fail
    # at the return statement with an unbound-local error.
    if axes not in ('gene', 'space'):
        raise ValueError(
            "axes must be 'gene' or 'space', got {0!r}".format(axes))

    if axes == 'gene':
        import scipy.sparse as ssp
        import scipy.sparse.linalg as las

        # The expression table is only needed in this mode, so it is loaded
        # here rather than unconditionally.
        bdnet = nio.getBDTNP()
        step = str(time_val + 1)
        gene_matrix = array([v['vals'][:, time_val]
                             for v in bdnet.values()
                             if step in v['steps']])

        # Rank-3 truncated SVD of the transposed matrix; the columns of U
        # are used directly as coordinates.
        n_c = 3
        U, s, Vh = las.svds(ssp.csr_matrix(gene_matrix.T), n_c)
        coords = U
    else:
        # space_space[i, a, t] == rows[i][spatial_idxs[a][t]]
        space_space = array([[[r[idx] for idx in sidxs]
                              for sidxs in spatial_idxs]
                             for r in rows])
        coords = space_space[:, :, time_val]

    return coords[ids, 0], coords[ids, 1], coords[ids, 2]
Example #2
0
def check_bdtnp(**kwargs):
    """Report, for every id, which of three gene tables contain it.

    kwargs are forwarded to id_map through mem.sr.

    Returns a dict mapping each id to a dict holding any of the keys
    'stf', 'btg' and 'stg', each set to True when the id is a key of the
    corresponding table; ids found in no table map to an empty dict.
    """
    ids = id_map(**mem.sr(kwargs))

    btgs = nio.getBDTNP()
    # NOTE(review): other callers of nio.getNet() unpack its result as
    # (targets, tfs); confirm that (stfs, stgs) is the intended order here.
    stfs, stgs = nio.getNet()

    # Checked in this order so the flag dicts are filled exactly as before.
    tables = (('stf', stfs), ('btg', btgs), ('stg', stgs))

    sxs = {}
    for gid in ids:
        flags = {}
        for label, table in tables:
            # `in` replaces dict.has_key, which no longer exists on Python 3.
            if gid in table:
                flags[label] = True
        sxs[gid] = flags

    return sxs
Example #3
0
 def set_term_network(**kwargs):
     """Build an undirected gene network linking genes that share a term group.

     Keyword arguments read here:
       name -- 'bdtnp' (nodes are the keys of nio.getBDTNP()) or
               'kn' (nodes are graphs['kn'].nodes()).
     All kwargs, with name set explicitly, are forwarded through mem.sr to
     term_groups.

     Returns an nx.Graph containing every gene of the chosen list as a node
     and an edge for every pair of members of each term group.

     Raises ValueError when name is not one of the supported sources.
     """
     name = kwargs.get('name')
     if name == 'bdtnp':
         gene_list = nio.getBDTNP().keys()
     elif name == 'kn':
         gene_list = graphs['kn'].nodes()
     else:
         # Fail before the term groups are computed instead of hitting an
         # unbound gene_list afterwards.
         raise ValueError(
             "name must be 'bdtnp' or 'kn', got {0!r}".format(name))

     grps = term_groups(**mem.sr(kwargs,
                                 name = name))
     network = nx.Graph()
     network.add_nodes_from(gene_list)

     for g in grps:
         # g[1] holds the group's entries; element 0 of each is the node id.
         members = [entry[0] for entry in g[1]]
         # NOTE(review): every ordered pair is added, including a member
         # paired with itself, so each grouped gene gets a self-loop.
         network.add_edges_from([[a, b] for a in members for b in members])
     return network
Example #4
0
 def set_term_groups(**kwargs):
     """Group genes by controlled-vocabulary term and keep the largest groups.

     Keyword arguments read here:
       name   -- 'bdtnp' (genes are the keys of nio.getBDTNP()) or
                 'kn' (genes are graphs['kn'].nodes()).
       nterms -- number of groups to keep; -1 keeps every group, and so
                 does None because it is used directly as a slice bound.

     Returns a list of (term, [(gene, term), ...]) pairs ordered from the
     group with the most genes to the one with the fewest.

     Raises ValueError when name is not one of the supported sources.
     """
     nterms = kwargs.get('nterms')
     # `name` was previously referenced without ever being read from kwargs.
     name = kwargs.get('name')
     if name == 'bdtnp':
         gene_list = nio.getBDTNP().keys()
     elif name == 'kn':
         gene_list = graphs['kn'].nodes()
     else:
         raise ValueError(
             "name must be 'bdtnp' or 'kn', got {0!r}".format(name))

     # GET ALL CONTROLLED VOCAB TERMS APPLYING TO A GIVEN GENE LIST
     terms = [(gname, gt)
              for gname in gene_list
              for gt in gene_terms(gname)]

     # groupby only merges adjacent items, hence the sort on the same key.
     by_term = lambda x: x[1]
     term_groups_tmp = [(k, list(g))
                        for k, g in it.groupby(sorted(terms, key = by_term),
                                               key = by_term)]

     # SORT THE TERM GROUPS BY GENE COUNT AND ONLY TAKE TOP N
     if nterms == -1:
         nterms = len(term_groups_tmp)
     # Ascending sort then reversal (rather than reverse=True) keeps the
     # original ordering among groups of equal size.
     ranked = sorted(term_groups_tmp,
                     key = lambda x: len(x[1]))[::-1][:nterms]
     return ranked
Example #5
0
    def set_nx(**kwargs):
        """Build a weighted directed TF -> target graph from a targets table.

        Keyword arguments read here:
          targets       -- mapping of target -> record whose 'tfs' and
                           'weights' entries are parallel sequences.
          restriction   -- 'none' to return the graph as built, or 'bdtnp'
                           to pass it through restricted_graph together
                           with the keys of nio.getBDTNP().
          nodes_allowed -- optional collection of node names; when given,
                           only edges with both endpoints in it are kept.

        Returns the resulting nx.DiGraph.

        Raises Exception('unrecognized restriction ...') for any other
        restriction value.
        """
        targets = kwargs.get('targets')
        restriction = kwargs.get('restriction')
        nodes_allowed = kwargs.get('nodes_allowed')

        # Validate before doing any work so a bad argument fails immediately.
        if restriction not in ('none', 'bdtnp'):
            raise Exception('unrecognized restriction {0}'.format(restriction))

        # A set gives constant-time membership tests in the nested loops;
        # None means "no filtering".
        allowed = None if nodes_allowed is None else set(nodes_allowed)

        def permitted(node):
            return allowed is None or node in allowed

        edges = [(tf, tg, float(wt))
                 for tg, elt in targets.items() if permitted(tg)
                 for tf, wt in zip(elt['tfs'], elt['weights']) if permitted(tf)]

        graph = nx.DiGraph()
        graph.add_weighted_edges_from(edges)
        if restriction == 'bdtnp':
            allowed_nodes = nio.getBDTNP().keys()
            graph = restricted_graph(graph, allowed_nodes)
        return graph
Example #6
0
def get_clusters(cluster_id=4):
    """Split one clustering into tissues and run get_results on them.

    cluster_id indexes the membership table returned by exp.recall_c2()
    and is also used in the result name 'cluster_<cluster_id>'.

    Returns the (mods, genes, tfs) triple produced by get_results.
    """
    all_members, ct_data = exp.recall_c2()
    all_members = array(all_members)
    c = all_members[cluster_id]

    # One tissue per distinct label in c, holding the rows of ct_data that
    # carry that label.
    tissues = {}
    for i, elt in enumerate(set(list(c))):
        tissues['t_{0}'.format(i)] = dict(cts = ct_data[equal(c, elt)])

    # The values below are loaded and assembled but not read again in this
    # function.
    bdnet = nio.getBDTNP()
    gene_cols, misc_cols, rows, row_nns = bdparse.read()
    spatial_idxs = array([misc_cols[axis]['idxs']
                          for axis in ('x', 'y', 'z')])

    # CHOOSE A TIME
    time_val = cluster_id

    # ORGANIZE A BUNCH OF DATA: tissues ordered from most to fewest cts.
    tsrt = sorted(tissues.iteritems(),
                  key = lambda item: len(item[1]['cts']))
    tsrt.reverse()

    result_name = 'cluster_{0}'.format(cluster_id)
    mods, genes, tfs = get_results(tsrt = tsrt, name = result_name)
    return mods, genes, tfs
Example #7
0
def check_network(net_name = 'binding',
                  dataset_name = 'reinitz',
                  data_ofs = 4,
                  max_edges = -1,
                  node_restriction = 'reinitz'):
    """Plot the distribution of expression correlations along network edges.

    net_name         -- 'binding', 'unsup', 'logistic' (looked up in
                        comp.get_graphs()) or 'clusters' (built by
                        get_soheil_network).
    dataset_name     -- 'reinitz', 'bdtnp' or 'tc'; selects the per-gene
                        expression values.
    data_ofs         -- offset passed to get_reinitz_data, or the column
                        taken from the BDTNP 'vals' arrays.
    max_edges        -- forwarded to get_soheil_network and shown in the
                        title and file name.
    node_restriction -- 'reinitz' limits the 'tc' data to reinitz genes.

    Saves a histogram plus scaled KDE of the edge correlation
    coefficients; returns nothing.  Raises Exception for an unsupported
    dataset_name or net_name.
    """
    reinitz_keys = set(get_reinitz_data()[1].keys())

    # ---- expression values, keyed by gene --------------------------------
    if dataset_name == 'reinitz':
        coords, values = get_reinitz_data(ofs = data_ofs)
    elif dataset_name == 'bdtnp':
        data = nio.getBDTNP()
        meta = nio.getBDTNP(misc = True)
        values = dict((gene, rec['vals'][:, data_ofs])
                      for gene, rec in data.iteritems())
        coords = array([meta[axis]['vals'][:, data_ofs]
                        for axis in ('x', 'y')])
    elif dataset_name == 'tc':
        data = nio.getTC()
        if node_restriction == 'reinitz':
            data = dict((gene, rec)
                        for gene, rec in data.iteritems()
                        if gene in reinitz_keys)
        values = data
    else:
        raise Exception('data set {0} not yet implemented'.format(dataset_name))

    # ---- network ---------------------------------------------------------
    nets = comp.get_graphs()
    stored_nets = {'binding': 'bn', 'unsup': 'unsup', 'logistic': 'logistic'}
    if net_name in stored_nets:
        network = nets[stored_nets[net_name]]
    elif net_name == 'clusters':
        network = get_soheil_network(max_edges = max_edges,
                                     node_restriction = values.keys())
    else:
        raise Exception('type not implemented: {0}'.format(net_name))

    nodes = values.keys()
    nodes_allowed = set(nodes)

    f = myplots.fignum(1, (8, 8))
    ax = f.add_subplot(111)

    # Neighbours of each node, restricted to nodes that have expression data.
    targets = {}
    for n in nodes:
        if n in network:
            targets[n] = nodes_allowed.intersection(network[n].keys())
        else:
            targets[n] = []

    xax = linspace(-1, 1, 20)

    edges = [(src, dst)
             for src, nbrs in targets.iteritems()
             for dst in nbrs]

    # Correlation of the two endpoints' expression; NaN results are dropped.
    all_ccofs = [corrcoef(values[tf], values[tg])[0, 1] for tf, tg in edges]
    ccofs = [cc for cc in all_ccofs if not isnan(cc)]

    count, kde = make_kde(ccofs)

    ax.hist(ccofs, xax, label = net_name)
    h = histogram(ccofs, xax)
    # The KDE is scaled by the tallest histogram bin.
    ax.fill_between(xax, kde(xax) * max(h[0]),
                    label = net_name, zorder = 1, alpha = .5)

    detail = '\n{2} data (data offset={0})\n(net_name={1})\n(max_edges={3})'\
        .format(data_ofs, net_name, dataset_name, max_edges)
    myplots.maketitle(ax,
                      'edge correlations kde for {0}'.format(detail),
                      subtitle = 'n_edges = {0}'.format(len(edges)))
    ax.legend()

    figname = 'network_edge_corrs_data_ofs={0}_net={1}_expr={2}_max_edges={3}'\
        .format(data_ofs, net_name, dataset_name, max_edges)
    f.savefig(myplots.figpath(figname))
Example #8
0
def run_sig(genes, show_disc = False, weighted = True):
    """Turn MCMC module results into graphs and compare with the reference net.

    genes     -- sequence of module records; element 0 is the tuple of TFs
                 in the module, element 1 a sequence of target entries
                 whose first two items are (target, count), and element 2
                 must exist.
    show_disc -- when true, only draw the disconnected per-module graph,
                 save it as 'mcmc_disc' and return None.
    weighted  -- when false, every MCMC edge gets weight 1 instead of the
                 summed counts, and figure names gain 'unweighted'.

    Returns a dict of graphs: 'mcmc' (from the modules), 'network' (from
    nio.getNet(), restricted to the nodes seen in the modules), one
    '<name>_thr<prc>%' entry per percentile threshold, and
    '<name>_flt<n_c>' / '<name>_flt<n_c>_thr0' entries for every graph
    present after thresholding.
    """
    # Normalise every record to (tfs, [[target, count], ...], third field).
    genes = [(g[0], [[gelt[0], gelt[1]] for gelt in g[1]], g[2])
             for g in genes]

    modules = [m[0] for m in genes]
    if len(modules[0]) == 2:
        module_type = 'doubles'
    else:
        module_type = 'triples'

    tgs, tfs = nio.getNet()

    # d* : one private copy of each node per (target, module), so modules
    #      stay disconnected from each other.
    # c* : nodes shared across modules, giving one connected graph.
    dnodes, dedges, cedges, cnodes = [], [], [], []
    for m in genes:
        for tginfo in m[1]:
            tg = tginfo[0]
            tg_mcount = tginfo[1]
            dtgnode = '{0}_{1}_mod{2}'.format(tg,tg,m[0])
            ctgnode = '{0}'.format(tg)
            dnodes.append(dtgnode)
            cnodes.append(ctgnode)
            for tf in m[0]:
                dtfnode = '{0}_{1}_mod{2}'.format(tf,tg,m[0])
                ctfnode = '{0}'.format(tf)
                dnodes.append(dtfnode)
                cnodes.append(ctfnode)
                dedges.append((dtfnode, dtgnode, tg_mcount))
                cedges.append((ctfnode, ctgnode, tg_mcount))

    # Only used for membership tests, so a set is enough.
    nodes_allowed = set(cnodes)

    if show_disc:
        dgraph = nx.Graph()
        dgraph.add_nodes_from(list(set(dnodes)))
        dgraph.add_weighted_edges_from(list(set(dedges)))
        f = myplots.fignum(4, (8,8))
        ax = f.add_subplot(111)
        pos = nx.graphviz_layout(dgraph, prog="neato")
        # color nodes the same in each connected subgraph
        for g in nx.connected_component_subgraphs(dgraph):
            c = [random.random()] * nx.number_of_nodes(g)  # random color...
            nx.draw(g,
                    pos,
                    node_size=40,
                    node_color=c,
                    vmin=0.0,
                    vmax=1.0,
                    with_labels=False
                    )
        figtitle = 'mcmc_disc'
        f.savefig(figtemplate.format(figtitle))
        return

    cgraph = nx.DiGraph()
    cgraph.add_nodes_from(cnodes)

    # Merge parallel (tf, tg) edges by summing their counts.
    endpoints = lambda x: (x[0], x[1])
    cedges = [(k[0], k[1], sum([gelt[2] for gelt in g]))
              for k, g in it.groupby(sorted(cedges, key = endpoints),
                                     key = endpoints)]

    if not weighted:
        # Edges are tuples, so they are rebuilt rather than assigned into.
        cedges = [(tf, tg, 1) for tf, tg, _ in cedges]

    cgraph.add_weighted_edges_from(list(set(cedges)))

    # Reference network restricted to the nodes that occur in the modules.
    sfRN = [(tf, tg, float(wt))
            for tg, elt in tgs.items() if tg in nodes_allowed
            for tf, wt in zip(elt['tfs'], elt['weights'])
            if tf in nodes_allowed]
    fg = nx.DiGraph()
    fg.add_nodes_from(cnodes)
    fg.add_weighted_edges_from(sfRN)

    # NOTE(review): cnodes still contains duplicates here, so this may ask
    # for more colours than the graphs have nodes -- confirm that is what
    # nx.draw is expected to receive.
    colors = mycolors.getct(len(cnodes))

    f = myplots.fignum(5, (8,8))
    ax = f.add_subplot(111)
    # The layout is computed on the reference graph and reused for both
    # figures so node positions match.
    pos = nx.graphviz_layout(fg, prog="neato")
    nx.draw(cgraph,
            pos,
            node_size=100,
            node_color=colors,
            vmin=0.0,
            vmax=1.0,
            with_labels=False,
            alpha = 1.
            )
    ax.set_title('connectivity of MCMC for network {0}'.format(module_type))
    figtitle = 'mcmc_network_{0}{1}'.\
        format(module_type,'' if weighted else 'unweighted')
    f.savefig(figtemplate.format(figtitle))

    f = myplots.fignum(5, (8,8))
    ax = f.add_subplot(111)
    nx.draw(fg,
            pos,
            node_size=100,
            node_color=colors,
            vmin=0.0,
            vmax=1.0,
            with_labels=False
            )
    ax.set_title('connectivity of reference for network {0}'.format(module_type))
    figtitle = 'mcmc_ref_network_{0}{1}'.\
        format(module_type,'' if weighted else 'unweighted')
    f.savefig(figtemplate.format(figtitle))

    graphs = {'mcmc':cgraph,'network':fg}

    # Iterate over a snapshot: entries are added to graphs inside the loop.
    for k, g in list(graphs.items()):
        for prc in [1,50,95]:
            thr = percentile([e[2]['weight']
                              for e in nx.to_edgelist(g)], prc)
            graphs['{0}_thr{1}%'.format(k,prc)] = nfu.thr_graph(g,thr)

    tot_edges = len(nx.to_edgelist(fg))
    for k, v in list(graphs.items()):
        for n_c in [2,4,6,8,12,20]:
            # NOTE(review): the keys below do not include max_edges, so each
            # edge budget overwrites the previous one and only the last
            # (2 * tot_edges) is returned.
            for max_edges in array([.5,1.,2.]) * tot_edges:
                gfilt = nfu.filter_graph(v, n_c = n_c)
                gfilt = nfu.top_edges(gfilt, max_edges = max_edges)
                gthr = nfu.thr_graph(gfilt, 1e-8)
                graphs['{0}_flt{1}'.format(k,n_c)] = gfilt
                graphs['{0}_flt{1}_thr0'.format(k,n_c)] = gthr

    return graphs
Example #9
0
def run2(reset = False,
         base_net = 'kn',
         comp_net = 'fn',
         demand_bdtnp = False):
    """Build a base graph and a family of comparison graphs on shared nodes.

    reset        -- accepted but not used in this function.
    base_net     -- network used for the base graph; only 'kn'
                    (nio.getKNet()) is supported.
    comp_net     -- comparison network: 'fn' (nio.getNet()), 'sn'
                    (nio.getSush(), signed weights), 'sna' (nio.getSush(),
                    absolute weights) or 'kn' (nio.getKNet()).
    demand_bdtnp -- when true, nodes must also be keys of nio.getBDTNP().

    Nodes are the targets/TFs present in both nio.getNet() and
    nio.getKNet().

    Returns (bg, cgraphs): the base nx.DiGraph and a dict holding the
    comparison graph under comp_net, its 10th-percentile thresholded
    version, and the filtered / re-thresholded variants of the latter.

    Raises ValueError for an unsupported base_net or comp_net.
    """
    # Validate first: an unknown name previously surfaced only as an
    # unbound edge list after all the data had been loaded.
    if base_net != 'kn':
        raise ValueError(
            "base_net must be 'kn', got {0!r}".format(base_net))
    if comp_net not in ('fn', 'sn', 'sna', 'kn'):
        raise ValueError(
            "comp_net must be one of 'fn', 'sn', 'sna', 'kn'; got {0!r}"
            .format(comp_net))

    bd = nio.getBDTNP()
    ktgs, ktfs = nio.getKNet()
    tgs, tfs = nio.getNet()
    sush = nio.getSush(on_fail = 'compute')

    tg_int = set(tgs.keys()).intersection(ktgs.keys())
    tf_int = set(tfs.keys()).intersection(ktfs.keys())

    if demand_bdtnp:
        tg_int = tg_int.intersection(bd.keys())
        tf_int = tf_int.intersection(bd.keys())

    def weighted_edges(targets, weight = float):
        # (tf, tg, weight) triples of `targets`, both endpoints restricted
        # to the shared node sets.
        return [(tf, tg, weight(wt))
                for tg, elt in targets.items() if tg in tg_int
                for tf, wt in zip(elt['tfs'], elt['weights'])
                if tf in tf_int]

    b_edges = weighted_edges(ktgs)

    if comp_net == 'fn':
        c_edges = weighted_edges(tgs)
    elif comp_net == 'sn':
        # Sushmita network with signed edges
        c_edges = weighted_edges(sush)
    elif comp_net == 'sna':
        # Sushmita network with unsigned edges
        c_edges = weighted_edges(sush, lambda wt: abs(float(wt)))
    else:
        c_edges = weighted_edges(ktgs)

    nodes = array(list(tf_int.union(tg_int)))

    bg = nx.DiGraph()
    bg.add_nodes_from(nodes)
    bg.add_weighted_edges_from(b_edges)

    cg = nx.DiGraph()
    cg.add_nodes_from(nodes)
    cg.add_weighted_edges_from(c_edges)

    cgraphs = {comp_net: cg}

    # Iterate over snapshots: entries are added to cgraphs inside the loops.
    for k, g in list(cgraphs.items()):
        for prc in [10]:
            thr = percentile([e[2]['weight']
                              for e in nx.to_edgelist(g)], prc)
            cgraphs['{0}_thr{1:2.2}'.format(k,thr)] = nfu.thr_graph(g,thr)

    for k, v in list(cgraphs.items()):
        # Only thresholded graphs are filtered further.
        if 'thr' not in k:
            continue
        tot_edges = len(nx.to_edgelist(v))
        for n_c in [1,2,4]:
            for max_edges in array([.2,.5,1.]) * tot_edges:
                gfilt = nfu.filter_graph(v, n_c = n_c)
                gfilt = nfu.top_edges(gfilt, max_edges = max_edges)
                gthr = nfu.thr_graph(gfilt, 1e-8)
                cgraphs['{0}_flt{1}_nedge{2}'.format(k,n_c,max_edges)] = gfilt
                cgraphs['{0}_flt{1}_nedge{2}_thr0'.format(k,n_c,max_edges)] = gthr

    # Design notes:
    #   * When you don't have hidden variables, networks can be modelled by
    #     an information criterion.
    #   * In what settings can you infer causality from datasets?
    #   * You need a prior to limit the number of arrows in your graph.
    #   * The idea: take an approach from computational learning theory and
    #     come up with a model for interventions.
    #   * Spatial, genetic and time data could be used to penalize network
    #     edges.
    #   * Granger causality uses time-varying data (note left unfinished).

    return bg, cgraphs
Example #10
0
def run(reset = False,
        base_net = 'kn',
        comp_net = 'fn',
        demand_bdtnp = False):
    """Build four regulatory graphs on shared nodes plus derived variants.

    reset, base_net, comp_net -- accepted but not used in this function.
    demand_bdtnp -- when true, nodes must also be keys of nio.getBDTNP().

    Nodes are the targets/TFs present in both nio.getNet() and
    nio.getKNet().

    Returns a dict of nx.DiGraph objects:
      'kg'   -- edges from nio.getKNet()
      'fg'   -- edges from nio.getNet()
      'sug'  -- edges from nio.getSush(), signed weights
      'suag' -- edges from nio.getSush(), absolute weights
    plus '<name>_thr<value>' percentile-thresholded versions of 'fg' and
    'suag', and '<key>_flt<n_c>' / '<key>_flt<n_c>_thr0' filtered versions
    of every thresholded graph.
    """
    tgs, tfs = nio.getNet()
    ktgs, ktfs = nio.getKNet()
    bd = nio.getBDTNP()
    sush = nio.getSush(on_fail = 'compute')

    tg_int = set(tgs.keys()).intersection(ktgs.keys())
    tf_int = set(tfs.keys()).intersection(ktfs.keys())

    if demand_bdtnp:
        tg_int = tg_int.intersection(bd.keys())
        tf_int = tf_int.intersection(bd.keys())

    def weighted_edges(targets, weight = float):
        # (tf, tg, weight) triples of `targets`, both endpoints restricted
        # to the shared node sets.
        return [(tf, tg, weight(wt))
                for tg, elt in targets.items() if tg in tg_int
                for tf, wt in zip(elt['tfs'], elt['weights'])
                if tf in tf_int]

    nodes = array(list(tf_int.union(tg_int)))

    # Each graph is paired with its edge list by name.  Pairing through
    # dict.values() depends on dict ordering, which is arbitrary on
    # Python 2 and could attach the wrong edges to a graph.
    edge_lists = [
        ('kg', weighted_edges(ktgs)),
        ('fg', weighted_edges(tgs)),
        # Sushmita network with signed edges
        ('sug', weighted_edges(sush)),
        # Sushmita network with unsigned edges
        ('suag', weighted_edges(sush, lambda wt: abs(float(wt)))),
    ]

    graphs = {}
    for gname, edge_list in edge_lists:
        g = nx.DiGraph()
        g.add_nodes_from(nodes)
        g.add_weighted_edges_from(edge_list)
        graphs[gname] = g

    for gname in ['fg','suag']:
        base = graphs[gname]
        weights = [e[2]['weight'] for e in nx.to_edgelist(base)]
        for prc in [10,50,75,85,90,95,98,99]:
            thr = percentile(weights, prc)
            graphs['{0}_thr{1:2.2}'.format(gname,thr)] = \
                nfu.thr_graph(base,thr)

    tot_edges = len(nx.to_edgelist(graphs['fg']))
    # Iterate over a snapshot: entries are added to graphs inside the loop.
    for k, v in list(graphs.items()):
        # Only thresholded graphs are filtered further.
        if 'thr' not in k:
            continue
        for n_c in [2,4,8,12]:
            # NOTE(review): the keys below do not include max_edges, so each
            # edge budget overwrites the previous one and only the last
            # (5 * tot_edges) is returned.
            for max_edges in array([.5,1.,2.,5.]) * tot_edges:
                gfilt = nfu.filter_graph(v, n_c = n_c)
                gfilt = nfu.top_edges(gfilt, max_edges = max_edges)
                gthr = nfu.thr_graph(gfilt, 1e-8)
                graphs['{0}_flt{1}'.format(k,n_c)] = gfilt
                graphs['{0}_flt{1}_thr0'.format(k,n_c)] = gthr

    return graphs
Example #11
0
def run(num = 2):
    """Plot one gene's expression against its three best-correlated TFs.

    num -- 1-based position, in the iteration order of nio.getBDTNP(), of
           the first gene to consider; earlier genes are skipped.

    The first considered gene that is a key of the target table from
    nio.getNet() is used, provided at least three TFs are present in the
    BDTNP data.  Expression is column 4 of 'vals', sub-sampled every 50th
    row.  The three TFs with the largest absolute correlation to the gene
    are drawn on figure 1, ordered by the gene's expression, and the
    correlations are printed.

    Returns nothing.  When no gene qualifies, figure 1 is left cleared and
    empty.
    """
    import scipy.signal as ss

    trgs, tfs = nio.getNet()
    bdgenes = nio.getBDTNP()
    bdset = set(bdgenes.keys())
    # TFs that also have BDTNP data; this does not depend on the gene, so
    # it is computed once.
    fsub = set(tfs.keys()).intersection(bdset)

    xs, ys, colors, corrs = [[] for i in range(4)]
    gexpr = None
    count = 0
    for k, v in bdgenes.items():
        count += 1
        if count < num:
            continue
        if k not in trgs:
            continue

        gexpr = v['vals'][::50,4].flatten()
        fexpr = [bdgenes[fname]['vals'][::50,4].flatten() for fname in fsub]

        print(shape(fexpr))
        if len(fexpr) < 3:
            continue
        ct = mycolors.getct(len(fexpr))
        for idx, f in enumerate(fexpr):
            c = corrcoef(f, gexpr)[0,1]
            if not isfinite(c):
                c = 0
            corrs.append(c)
            ys.append(gexpr)
            xs.append(f)
            colors.append([ct[idx]] * len(f))
        # Only the first qualifying gene is plotted.
        break

    # TF indices ordered from strongest to weakest absolute correlation.
    cbest = argsort(-1 * abs(array(corrs)))

    f = plt.figure(1)
    f.clear()
    ax = f.add_subplot(111)
    if not xs:
        # No gene qualified, so there is nothing to draw.
        return
    inds = argsort(gexpr)

    for idx in cbest[:3]:
        # A median filter with kernel size 1 leaves the signal unchanged.
        xv = ss.medfilt(xs[idx][inds],1)
        yv = ys[idx][inds]
        print(corrcoef(xv,yv)[0,1])
        print(corrs[idx])

        ax.plot(xv, linewidth = 10, color = colors[idx][0])
        ax.plot(yv, linewidth = 10)
Example #12
0
def view0(modules,
          data_src = 'bdtnp',
          net_src = 'fRN',
          max_rank = 4,
          module_type = 'doubles'):
    '''
    View the sign of interaction coefficients for each transcription
    factor, split per-cluster and per-module size.

    Designed to be run on the output of view_output.modules()

    modules     -- mapping whose keys are tuples of TF names and whose
                   values are records with 'coefs' and 'genes' entries.
                   A one-element key (tf,) is looked up for every TF.
    data_src    -- not referenced in the body.
    net_src     -- not referenced in the body.
    max_rank    -- number of doublet terms and of triplet terms (those
                   with the most genes) drawn for each TF.
    module_type -- 'doubles', 'triples' or 'all'; edges of the selected
                   kind get alpha 1, the others .1.  Also part of the
                   output file name.

    Saves one figure per TF through figtemplate, named
    'tf_<tf>_net_rank<max_rank>_<module_type>'.  Returns nothing.
'''
    #COMPUTE BULK STATISTICS FOR EACH TF
    # bd_data and genes are not read again in this function.
    bd_data = nio.getBDTNP()
    genes = bd_data.keys()

    # Every TF that appears in any module key.
    tfs =sorted(set(it.chain(*[k for k in modules.keys()])))
    # Undirected co-membership graph; it is only used to compute the node
    # layout shared by all per-TF figures.
    tf_net = nx.Graph()
    tf_net.add_nodes_from(tfs)
    tf_edges = it.chain(*[[(e0, e1)
                for e0 in term
                for e1 in term
                if e0 != e1]
                for term in modules.keys()])
    tf_net.add_edges_from(tf_edges)
 
    pos = nx.graphviz_layout(tf_net)
    fig = myplots.fignum(1, (8,8))


    tfnodes = tf_net.nodes()
    tfnames = tfnodes

    # Standard deviation over the coefficients of all modules; used to
    # scale colour intensity.
    all_coefs = \
        list(it.chain(*[t['coefs'] for t in modules.values()]))
    cstd = std(all_coefs)
    def colorfun(coefs):
        # Red when the median coefficient is negative, blue otherwise; the
        # RGB value is scaled by |median / cstd * 2|, capped at 1.
        coef = median(coefs)
        arr = array([1.,0.,0.]) if coef < 0\
            else array([0.,0.,1.])
        return arr*min([1, abs(coef/cstd*2)])

    def widthfun(coefs, maxwid):
        # Width grows with the number of coefficients relative to maxwid and
        # is never below 1.
        return max([1,5.* len(coefs) /maxwid])

        
  

    for tf in tfs:
        fig.clf()
        ax = fig.add_subplot(111)

        # Per-edge drawing attributes, keyed by the edge tuple.
        ecols,ewids, estyles, ealphas =\
            [{} for i in range(4)]
        edges = []
        # Modules containing this TF, split by the number of distinct TFs.
        tf_doublet_terms = [(k, v)
                            for k, v in modules.iteritems()
                            if tf in k and len(set(k)) == 2 ] 

        tf_triplet_terms = [(k, v)
                            for k, v in modules.iteritems()
                            if tf in k and len(set(k)) == 3 ] 
        
        # isdub / istrip are filled in below but not read again in this
        # function.
        isdub = dict([(k,0.) for k in tfnames])
        istrip =dict([(k,0.) for k in tfnames])

        coflens = [len(e[1]['coefs']) 
                   for e in  tf_triplet_terms + tf_doublet_terms]
        # NOTE(review): max() raises ValueError on an empty list, i.e. when
        # this TF has no doublet or triplet term -- confirm that cannot
        # happen for the expected input.
        max_l = max(coflens)

        # Top max_rank triplet terms by gene count; the edge joins the other
        # TFs of the term.
        for tt,v in sorted(tf_triplet_terms,
                         key = lambda x: len(x[1]['genes']))[::-1][:max_rank]:
            partners =tuple( [t for t in set(tt) if not t==tf])
            for p in partners: istrip[p] = 1 #istrip[[tfnodes.index(p) for p in partners]] = 1
            edge = partners 
            ecols[edge] = colorfun(modules[tt]['coefs'])
            ewids[edge] = widthfun(modules[tt]['coefs'],max_l)
            ealphas[edge] = 1 if module_type in ['triples','all'] \
                else .1
            estyles[edge] = 'dotted'
            
            edges.append(edge)
            
                        
        # Top max_rank doublet terms by gene count; the edge runs from this
        # TF to its partner.
        for td,v in sorted(tf_doublet_terms,
                         key = lambda x: len(x[1]['genes']))[::-1][:max_rank]:
            partners = tuple([t for t in set(td) if not t==tf])
            for p in partners: isdub[p] = 1
            edge = (tuple([tf] + list(partners))) 
            ecols[edge] = colorfun(modules[td]['coefs'])
            ewids[edge] = widthfun(modules[td]['coefs'], max_l)
            ealphas[edge] = 1 if module_type in ['doubles','all'] \
                else .1

            estyles[edge] = 'solid'

            edges.append(edge)
        
        
        tf_graph = nx.DiGraph()
        tf_graph.add_nodes_from(tfnodes)
        tf_graph.add_edges_from(edges)
        
        # NOTE(review): ckw is built here, but gd.draw below is passed
        # ckw = {} instead.
        ckw = dict([ (k, dict(color = ecols[k], #array([1,0,0])*isdub[k] +\
                                 # array([0,0,1])*istrip[k],
                              alpha = ealphas[k],
                              linestyle = estyles[k],
                              linewidth = ewids[k],
                              arrowstyle = '-'))
                       for k in tf_graph.edges()])        
        # Per-edge style for the circle overlay drawn further down.
        circlepoints = dict([ (k, dict(facecolor ='white', #array([1,0,0])*isdub[k] +\
                                       # array([0,0,1])*istrip[k],
                                       alpha = round(ealphas[k],2),
                                       edgecolor = ecols[k],
                                       linestyle = estyles[k],
                                       linewidth = 3,))
                       for k in tf_graph.edges()])


        ax.set_title('Top modules for TF: {0}'.format(tf))
        myplots.padded_limits(ax,  *zip(*pos.values()))
        nodes = tf_graph.nodes()
        # The current TF is drawn large and coloured by the coefficients of
        # its one-element module; every other node is a small black marker.
        gd.draw(tf_graph,pos,edges, 
                scatter_nodes =tf_graph.nodes(),
                skw = {'s':[200 if n == tf else 2
                            for n in nodes],
                       'facecolor':[colorfun(modules[tuple([n])]['coefs']) 
                                if n == tf
                                else 'black' 
                                for n in nodes],
                       'edgecolor':'black',
                       'linewidth':2,
                       'alpha':1},#[1 if n == tf else .1 for n in nodes]},
                ckw = {})
        # Attributes for the triangle overlay, which is switched off below
        # (plot_tri is False).
        colors,alphas,tripoints = [{} for i in range(3)]
        for e in edges:
            colors[e] = 'black'
            alphas[e] = .2
            tripoints[e] = array(pos[tf])
        
        plot_tri = False
        plot_circ = True
        if plot_tri:
            gd.overlay(tf_graph, pos, edges,
                       tripoints = tripoints,
                       colors = colors,
                       alphas = alphas)
        if plot_circ:
            gd.overlay(tf_graph, pos, edges,
                       circlepoints =circlepoints)

        # Inset: one marker per coefficient of the TF's one-element module,
        # laid out row by row on a grid ceil(sqrt(l)) wide and coloured in
        # sorted coefficient order.
        ax2 = fig.add_axes([.05,.05,.2,.2])
        ax2.set_xticks([])
        ax2.set_yticks([])
        coefs = modules[tuple([tf])]['coefs']
        l = len(coefs)
        sx = sy = ceil(sqrt(l))
        xs = mod(arange(l), sx)
        ys = floor( arange(l) / sx)
        cs = [colorfun([c/2]) for c in sorted(coefs)]
        ss = 100

        ax2.scatter(xs, ys, s = ss, color = cs)

        fig.savefig(figtemplate.format('tf_{0}_net_rank{1}_{2}'.\
                                           format(tf,max_rank,module_type)))