def distinctDrugBoxplots_PHENOSCORE(who_ps, res, ctrl_points, folder='/media/lalil0u/New/projects/drug_screen/results/'):
    '''
    Here we're looking at phenotypic scores, compared with controls.

    One 4x4 figure is saved per entry of DRUGS. Subplot k shows a boxplot of
    the control scores ctrl_points[:, k] and, on top of it, a jittered scatter
    of res[:, k] for the wells treated with that drug, coloured by dose
    (yellow -> red).

    :param who_ps: experiment identifiers of the rows of res, formatted
        'PLATE--WELL' (WELL must be parseable with int()).
    :param res: array of scores, one row per entry of who_ps and one column per
        entry of CLASSES.
    :param ctrl_points: array of control scores, one column per entry of CLASSES.
    :param folder: directory where the 'phenoscore_<drug>.png' files are written.
    '''
    cm = ColorMap()
    cr = cm.makeColorRamp(256, ["#FFFF00", "#FF0000"])
    degrade = [cm.getColorFromMap(level, cr, 0, 10) for level in range(11)]

    # The well -> drug/dose tables are always read from this fixed location;
    # `folder` is only used for the output images.
    # pickle is acceptable here only because this is a trusted, locally produced file.
    with open('/media/lalil0u/New/projects/drug_screen/results/well_drug_dose.pkl', 'rb') as handle:
        _, drugs, _, doses_cont, _ = pickle.load(handle)

    exposure = []
    doses = []
    for exp in who_ps:
        parts = exp.split('--')
        # Keys of the pickled dictionaries use the well number formatted with
        # '{:>05}' (zero-padded to width 5).
        key = "{}--{:>05}".format(parts[0], int(parts[1]))
        exposure.append(drugs[key])
        doses.append(doses_cont[key])
    exposure = np.array(exposure)
    doses = np.array(doses)

    for drug in DRUGS:
        fig, axes = p.subplots(4, 4, sharex=True, figsize=(24, 12))
        flat_axes = axes.flatten()
        is_drug = exposure == drug
        for k, class_ in enumerate(CLASSES):
            ax = flat_axes[k]
            ax.boxplot(ctrl_points[:, k])
            for dose in range(10):
                where_ = np.where(is_drug & (doses == dose))[0]
                if where_.shape[0] > 0:
                    # Horizontal jitter around the control box drawn at x=1.
                    jitter = np.random.normal(1, 0.05, size=where_.shape[0])
                    ax.scatter(jitter, res[where_, k], color=degrade[dose], alpha=0.8, s=6)
            ax.set_title(class_)

        # p.title() would only title the current axes (the last subplot
        # created); suptitle labels the whole figure.
        fig.suptitle(drug)
        fig.savefig(os.path.join(folder, 'phenoscore_{}.png'.format(drug)))
        # Release each figure once saved instead of keeping one per drug alive.
        p.close(fig)

    p.close('all')
def distinctDrugBoxplots_PERC(who, exposure,doses, perc, phenotypes):
    '''
    Here we're directly looking at phenotype percentages over the whole movie,
    compared with controls and Mitocheck hits.

    We suppose that plate 4 is already removed from the data and from who (the
    original note says `r`; presumably this means `perc`).
    Percentages are in the file
    /media/lalil0u/New/projects/drug_screen/results/all_Mitocheck_DS_phenohit.pkl

    Layout of the single 4x9 grid (shared x and y axes):
      - one panel per phenotype: boxplots, per column of perc, of the rows
        whose exposure equals that phenotype;
      - then one panel per drug in DRUGS: scatter of the drug's rows, coloured
        by dose (yellow -> red);
      - then one 'Control' panel: boxplots of the rows labelled 'empty'.

    :param who: not used by this function (kept for interface compatibility).
    :param exposure: per-row labels (phenotype, drug name or 'empty');
        presumably aligned with the rows of perc — verify against callers.
    :param doses: doses of the drug-treated rows. It is left-padded with -1 for
        the first lim_Mito rows (lim_Mito is defined elsewhere in the module;
        presumably the number of Mitocheck rows — confirm).
    :param perc: 2-D array of percentages; its columns are labelled with CLASSES
        on the control panel.
    :param phenotypes: phenotype labels to display, one panel each.
    '''
    cm = ColorMap()
    cr = cm.makeColorRamp(256, ["#FFFF00", "#FF0000"])
    degrade = [cm.getColorFromMap(level, cr, 0, 10) for level in range(11)]

    fig, axes = p.subplots(4, 9, sharex=True, sharey=True)
    flat_axes = axes.flatten()
    n_classes = perc.shape[1]

    for i, pheno in enumerate(phenotypes):
        # Take [0] from np.where: indexing with the whole tuple,
        # perc[np.where(...), k], yields a 2-D (1, n) selection instead of
        # a 1-D sample per boxplot.
        rows = np.where(exposure == pheno)[0]
        flat_axes[i].boxplot([perc[rows, k] for k in range(n_classes)])
        flat_axes[i].set_title(pheno)
        flat_axes[i].set_ylim(-0.05, 0.9)

    # The first lim_Mito rows carry no dose: pad with -1 so doses lines up with exposure.
    dd = np.zeros(shape=lim_Mito)
    dd.fill(-1)
    doses = np.hstack((dd, doses))

    # Drug panels start after the phenotype panels. Keep the historical slot 8
    # as a minimum so the layout is unchanged for <= 8 phenotypes, but never
    # overwrite a phenotype panel when there are more.
    offset = max(8, len(phenotypes))
    for j, drug in enumerate(DRUGS):
        ax = flat_axes[offset + j]
        for dose in range(10):
            where_ = np.where((exposure == drug) & (doses == dose))[0]
            if where_.shape[0] > 0:
                for k in range(n_classes):
                    ax.scatter(np.repeat(k + 1, where_.shape[0]), perc[where_, k],
                               color=degrade[dose], alpha=0.5, s=5)
        ax.set_title(drug)

    # Control panel right after the last drug panel (also defined when DRUGS is empty).
    ctrl_ax = flat_axes[offset + len(DRUGS)]
    where_ = np.where(exposure == 'empty')[0]
    ctrl_ax.boxplot([perc[where_, k] for k in range(n_classes)])
    ctrl_ax.set_title('Control')
    ctrl_ax.set_xticklabels(CLASSES, rotation='vertical')
    p.show()
def distinctDrug_PHENOSCORE(who_ps, res, folder='/media/lalil0u/New/projects/drug_screen/results/'):
    '''
    Here we're looking at phenotypic scores, compared with controls.

    For each drug in DRUGS, saves one figure with a class-by-dose heat map for
    each of the plates LT0900_01, LT0900_02 and LT0900_03. Column j is the
    j-th dose (in increasing order) found for the drug on that plate, row k is
    CLASSES[k], and the cell holds res[w, k] for the first matching well w.
    Colours use YlOrRd normalised to [-0.2, 0.6].

    :param who_ps: experiment identifiers of the rows of res, formatted
        'PLATE--WELL' (WELL must be parseable with int()).
    :param res: array of scores, one row per entry of who_ps and one column per
        entry of CLASSES.
    :param folder: directory where the 'phenoscore_nice_<drug>.png' files are written.
    '''
    # Import at the top of the function: any `import ... as mpl` inside the
    # body makes `mpl` a local name for the WHOLE function, so using it before
    # the import statement raises UnboundLocalError.
    import matplotlib as mpl

    norm = mpl.colors.Normalize(-0.2, 0.6)

    # pickle is acceptable here only because this is a trusted, locally produced file.
    with open('/media/lalil0u/New/projects/drug_screen/results/well_drug_dose.pkl', 'rb') as handle:
        _, drugs, _, doses_cont, _ = pickle.load(handle)

    plates = np.array([el.split('--')[0] for el in who_ps])
    exposure = []
    doses = []
    for exp in who_ps:
        parts = exp.split('--')
        # Keys use the well number formatted with '{:>05}' (zero-padded to width 5).
        key = "{}--{:>05}".format(parts[0], int(parts[1]))
        exposure.append(drugs[key])
        doses.append(doses_cont[key])
    exposure = np.array(exposure)
    doses = np.array(doses)

    for drug in DRUGS:
        fig, axes = p.subplots(1, 3, figsize=(24, 12))
        flat_axes = axes.flatten()
        for i in range(3):
            ax = flat_axes[i]
            currPl = 'LT0900_0{}'.format(i + 1)
            on_plate = (plates == currPl) & (exposure == drug)
            range_dose = sorted(set(doses[on_plate]))
            ax.set_title(currPl)
            if not range_dose:
                # Drug absent from this plate: nothing to draw.
                continue
            # Keep the first well of each dose. Picking it per selection also
            # works when doses have different numbers of replicates, where the
            # former np.array(...)[:, :, 0] on a ragged array failed.
            currM = np.array([[res[np.where(on_plate & (doses == dose))[0][0], k]
                               for dose in range_dose]
                              for k in range(len(CLASSES))])
            ax.matshow(currM, cmap=mpl.cm.YlOrRd, norm=norm)
            ax.set_xticks(range(len(range_dose)))
            ax.set_xticklabels(range_dose)
            ax.set_yticks(range(len(CLASSES)))
            ax.set_yticklabels(CLASSES)

        fig.suptitle(drug)
        fig.savefig(os.path.join(folder, 'phenoscore_nice_{}.png'.format(drug)))
        # Release each figure once saved instead of keeping one per drug alive.
        p.close(fig)

    p.close('all')
    
# NOTE(review): this `elif` has no matching `if` header anywhere in view, so
# as written the module cannot be parsed (SyntaxError). The `if` branch it
# belongs to has to be restored — TODO confirm what that branch tested.
elif getpass.getuser()=='lalil0u':
    # Per-user set-up: load the R package 'locfit' via `objects.packages.importr`.
    # `objects` is not defined in view — presumably rpy2.robjects; verify.
    locfit = objects.packages.importr('locfit')
    
import matplotlib.pyplot as p
# Presumably R's global environment (rpy2 `globalenv`), taken from the same
# `objects` binding as above — confirm.
globalenv = objects.globalenv

from tracking.histograms.summaries_script import progFolder, scriptFolder, pbsArrayEnvVar, pbsErrDir, pbsOutDir, path_command

from util.plots import couleurs, markers, plotBarPlot
from util.make_movies_mito_cbio import ColorMap
from analyzer import CONTROLS, quality_control, plates, xbL, compoundL

# Module-level colour lookup `cr`, keyed by integer: key 0 is (0, 0, 0), the
# following keys hold, in order, the yellow-to-red ramp returned by
# ColorMap.makeColorRamp(N=10, ...), and key 15 is set to (1, 0, 1).
cm = ColorMap()
_ramp = cm.makeColorRamp(N=10, basic_colors=["#FFFF00", "#FF0000"])
_ramp.insert(0, (0, 0, 0))
cr = dict(enumerate(_ramp))
cr[15] = (1, 0, 1)
del _ramp

def plotRegularizedResults(localRegMeasure=False, 
                loadingFolder='/media/lalil0u/New/projects/Xb_screen/dry_lab_results/MITOSIS/phenotype_analysis_up_down', 
                filename='phenoAnalysis_plateNorm_', 
                pheno_list = ['Anaphase_ch1', 'Apoptosis_ch1', 'Folded_ch1', 'Interphase_ch1', 'Metaphase_ch1',
                              'Prometaphase_ch1', 'WMicronuclei_ch1', 'Frozen_ch1'],
                fig_filename="Difference_with_controls_{}.png"):
    '''
    List, in sorted order, the entries of loadingFolder whose name contains
    `filename`, and prepare one empty defaultdict(dict) per phenotype of
    pheno_list plus an empty well list.

    NOTE(review): as visible, the function stops after this set-up. The file
    list, `result` and `who` are never used, nothing is plotted, and it returns
    None; localRegMeasure and fig_filename are unused. The rest of the body
    looks truncated — confirm against version control.
    '''
    matching = [name for name in os.listdir(loadingFolder) if filename in name]
    file_list = sorted(matching)
    result = dict((pheno, defaultdict(dict)) for pheno in pheno_list)
    who = []  # well list
def heatmap(x, row_header, column_header, row_method,
            column_method, row_metric, column_metric,
            color_gradient, 
            filename, 
            other_data=None, 
            log=False, trad=False, 
            level_row=0.4, level_column=0.5,
            folder=os.getcwd(),
            range_normalization=(-2,2), colorbar_ticks=[-2, 0, 2],
            colorbar_ticklabels=['$ <\mu-2 \sigma$', '$\mu$', '$> \mu+2 \sigma$'], colorbar_title='Feature range',
            title=None,
            save=False,
            show=True):
    """
    Hierarchically cluster the rows and/or columns of x and draw a heat map
    with dendrograms, flat-cluster side bars and a colour legend, saved as PDF.

    Based in large part on the prototype methods:
    http://old.nabble.com/How-to-plot-heatmap-with-matplotlib--td32534593.html
    http://stackoverflow.com/questions/7664826/how-to-get-flat-clustering-corresponding-to-color-clusters-in-the-dendrogram-cre

    WARNING: modified version meant to work with "big data" (starting with
    m=50,000). Instead of storing the full distance matrix in memory, the
    clustering uses fastcluster.linkage_vector, its memory-efficient
    implementation (see http://danifold.net/fastcluster.html).

    :param x: m by n ndarray (m observations/rows, n genes/columns) to cluster.
    :param row_header, column_header: labels of the rows/columns of x;
        non-string labels are converted with str().
    :param row_method, column_method: linkage method passed to
        fastcluster.linkage_vector ('single', 'centroid', 'median', 'ward',
        'complete', 'average', 'weighted'), or None to skip clustering along
        that axis.
    :param row_metric, column_metric: distance metric passed to linkage_vector
        (e.g. 'euclidean', 'cityblock', 'correlation'); row_metric may be None to
        use the fastcluster default.
    :param color_gradient: one of 'red_white_blue', 'red_black_sky', 'OrRd',
        'red_black_blue', 'red_black_green', 'yellow_black_blue', 'seismic',
        'green_white_purple', 'coolwarm', 'YlOrRd'.
    :param filename: tag used in the output file names.
    :param other_data: optional array drawn instead of x. If x is n_row by
        n_columns, other_data should be n_row by m_col. Clustering is still
        computed on x; only the rows of other_data are reordered.
    :param log, trad: not used by the current code; kept for compatibility.
    :param level_row, level_column: fraction of the largest merge distance at
        which each tree is cut into flat clusters.
    :param folder: directory of the output PDF. NB: the default is evaluated once,
        when the module is imported.
    :param range_normalization: (vmin, vmax) of the colour scale. If x has no
        negative value and vmin < 0, vmin is replaced by 0.
    :param colorbar_ticks, colorbar_ticklabels, colorbar_title: colour legend.
    :param title: optional x-axis label of the heat map.
    :param save: if True, pickle [row clusters, column clusters] to
        'Flat_clusters_<filename>_<level_row>.pkl' in the current directory.
    :param show: if True, call pylab.show() after saving.
    :returns: numpy array of the flat row-cluster labels in the ORIGINAL row
        order (an array of 'NA' strings when row_method is None).
    :raises ValueError: if color_gradient is not one of the names above.
    """
    print("\nPerforming hiearchical clustering using %s for columns and %s for rows" % (column_metric, row_metric))
    if numpy.any(numpy.isnan(x)):
        sys.stderr.write("WARNING, there are NaN values in the data. Hence distances with data elements that have NaN values will have value NaN, which might perturb the hierarchical clustering.\n")

    # Labels end up in the figure, so make them strings. all() is also safe on
    # empty headers, where the former numpy `~` on an empty float array raised TypeError.
    if not all(type(s) is str for s in row_header):
        row_header = [str(el) for el in row_header]
    if not all(type(s) is str for s in column_header):
        column_header = [str(el) for el in column_header]

    # Resolve the colour map up front, so an unknown name fails before the
    # expensive clustering rather than as an unbound `cmap` at plotting time.
    # Lambdas defer the lookup of the custom colour-map constructors.
    cmap_factories = {
        'red_white_blue': lambda: pylab.cm.bwr,
        'red_black_sky': lambda: RedBlackSkyBlue(),
        'OrRd': lambda: pylab.cm.OrRd,
        'red_black_blue': lambda: RedBlackBlue(),
        'red_black_green': lambda: RedBlackGreen(),
        'yellow_black_blue': lambda: YellowBlackBlue(),
        'seismic': lambda: pylab.cm.seismic,
        'green_white_purple': lambda: pylab.cm.PiYG_r,
        'coolwarm': lambda: pylab.cm.coolwarm,
        'YlOrRd': lambda: pylab.cm.YlOrRd,
    }
    if color_gradient not in cmap_factories:
        raise ValueError("Unknown color_gradient {!r}; expected one of {}".format(
            color_gradient, sorted(cmap_factories)))
    cmap = cmap_factories[color_gradient]()

    if numpy.any(x < 0):
        norm = mpl.colors.Normalize(range_normalization[0], range_normalization[1])
    elif range_normalization[0] < 0:
        # Non-negative data: do not spend half of the colour scale on negatives.
        norm = mpl.colors.Normalize(0, range_normalization[1])
    else:
        norm = mpl.colors.Normalize(range_normalization[0], range_normalization[1])

    fig = pylab.figure(figsize=(12, 8.5))

    # Axes layout, in figure fractions [x, y, width, height].
    color_bar_w = 0.015  # thickness of the flat-cluster side bars
    gap = 0.004          # spacing between adjacent axes
    # Left dendrogram; its y/height also set the heat map's vertical extent.
    ax1_x, ax1_y, ax1_w, ax1_h = 0.05, 0.22, 0.2, 0.6
    # Row side bar, just right of the left dendrogram.
    axr_x, axr_y, axr_w, axr_h = ax1_x + ax1_w + gap, ax1_y, color_bar_w, ax1_h
    # Heat map, right of the row side bar.
    axm_x, axm_y, axm_w, axm_h = axr_x + axr_w + gap, ax1_y, 0.5, ax1_h
    # Column side bar, just above the heat map.
    axc_x, axc_y, axc_w, axc_h = axm_x, ax1_y + ax1_h + gap, axm_w, color_bar_w
    # Top dendrogram, above the column side bar.
    ax2_x, ax2_y, ax2_w, ax2_h = axm_x, axc_y + axc_h + gap, axm_w, 0.15
    # Colour legend, top left.
    axcb_x, axcb_y, axcb_w, axcb_h = 0.07, 0.88, 0.18, 0.04

    # Top dendrogram (columns). The axes are added first so that
    # sch.dendrogram draws on them as the current axes.
    if column_method is not None:
        start_time = time.time()
        ax2 = fig.add_axes([ax2_x, ax2_y, ax2_w, ax2_h], frame_on=True)
        Y2 = fastcluster.linkage_vector(x.T, method=column_method, metric=column_metric)
        Z2 = sch.dendrogram(Y2)
        # Flat clusters: cut at level_column * (largest merge distance).
        ind2 = sch.fcluster(Y2, level_column * max(Y2[:, 2]), 'distance')
        ax2.set_xticks([])
        ax2.set_yticks([])
        print('Column clustering completed in %s seconds' % round(time.time() - start_time, 1))
    else:
        ind2 = ['NA'] * len(column_header)  # placeholder flat clusters

    # Left dendrogram (rows).
    if row_method is not None:
        start_time = time.time()
        ax1 = fig.add_axes([ax1_x, ax1_y, ax1_w, ax1_h], frame_on=True)
        if row_metric is None:
            Y1 = fastcluster.linkage_vector(x, method=row_method)
        else:
            Y1 = fastcluster.linkage_vector(x, method=row_method, metric=row_metric)
        Z1 = sch.dendrogram(Y1, orientation='right')
        ind1 = sch.fcluster(Y1, level_row * max(Y1[:, 2]), 'distance')
        ax1.set_xticks([])
        ax1.set_yticks([])
        print('Row clustering completed in %s seconds' % round(time.time() - start_time, 1))
    else:
        ind1 = ['NA'] * len(row_header)  # placeholder flat clusters

    if save:
        flat_filename = 'Flat_clusters_{}_{}.pkl'.format(filename, level_row)
        print('Saving flat clusters in {}'.format(flat_filename))
        with open(flat_filename, 'wb') as f:
            pickle.dump([ind1, ind2], f)

    # Returned labels stay in the original row order (captured before reordering).
    ind1_to_return = numpy.array(ind1)

    # Heat map: reorder data (and flat clusters) to follow the dendrogram leaves.
    axm = fig.add_axes([axm_x, axm_y, axm_w, axm_h])
    xt = x
    if column_method is not None:
        idx2 = Z2['leaves']
        xt = xt[:, idx2]
        ind2 = ind2[idx2]
    if row_method is not None:
        idx1 = Z1['leaves']
        xt = xt[idx1, :]
        if other_data is not None:
            other_data = other_data[idx1, :]
        ind1 = ind1[idx1]
    shown = xt if other_data is None else other_data
    axm.matshow(shown, aspect='auto', origin='lower', cmap=cmap, norm=norm)
    axm.set_xticks([])
    axm.set_yticks([])

    # Row labels, just right of the displayed matrix (skipped above 200 rows).
    if len(row_header) < 200:
        row_order = idx1 if row_method is not None else range(x.shape[0])
        row_prefix = '  ' if row_method is not None else ' '
        for i, row in enumerate(row_order):
            axm.text(shown.shape[1] - 0.5, i, '{}{}'.format(row_prefix, row_header[row]), fontsize=6)

    # Column labels, below the matrix (skipped above 200 columns).
    if len(column_header) < 200:
        for i in range(shown.shape[1]):
            col = idx2[i] if column_method is not None else i
            axm.text(i, -0.9, '{}'.format(column_header[col]), rotation=270, verticalalignment="top", fontsize=6)

    # Column side bar: one yellow-to-red colour per column flat-cluster label.
    if column_method is not None:
        print('Number of clusters for columns {}'.format(numpy.bincount(ind2)))
        axc = fig.add_axes([axc_x, axc_y, axc_w, axc_h])
        col_cm = ColorMap()
        col_ramp = col_cm.makeColorRamp(256, ["#FFFF00", "#FF0000"])
        # Distinct loop name: in Python 2 a list comprehension leaks its
        # variable, and the former `x` rebound the data matrix.
        degrade = [col_cm.getColorFromMap(level, col_ramp, 0, 10)
                   for level in range(len(numpy.bincount(ind2)))]
        cmap_c = mpl.colors.ListedColormap(degrade)
        dc = numpy.array(ind2, dtype=int).reshape(1, len(ind2))
        axc.matshow(dc, aspect='auto', origin='lower', cmap=cmap_c)
        axc.set_xticks([])
        axc.set_yticks([])

    # Row side bar: rainbow colours per row flat-cluster label.
    if row_method is not None:
        print('Number of clusters for rows {}'.format(numpy.bincount(ind1)))
        axr = fig.add_axes([axr_x, axr_y, axr_w, axr_h])
        dr = numpy.array(ind1, dtype=int).reshape(len(ind1), 1)
        axr.matshow(dr, aspect='auto', origin='lower', cmap=mpl.cm.gist_rainbow)
        axr.set_xticks([])
        axr.set_yticks([])

    # Colour legend.
    axcb = fig.add_axes([axcb_x, axcb_y, axcb_w, axcb_h], frame_on=False)
    axcb.set_title(colorbar_title, fontsize=15)
    cb = mpl.colorbar.ColorbarBase(axcb, cmap=cmap, norm=norm, orientation='horizontal',
                                   ticks=colorbar_ticks)
    cb.ax.set_xticklabels(colorbar_ticklabels, fontsize=15)

    pdf_filename = os.path.join(folder, 'Clust_%s_%s_%s.pdf' % (filename[:10], column_method, row_method))
    pylab.rcParams['font.size'] = 15
    if title is not None:
        axm.set_xlabel(title)
    fig.savefig(pdf_filename)
    print('Exporting: {}'.format(pdf_filename))
    if show:
        pylab.show()
    return ind1_to_return