# Boolean mask over the rows of readCount marking the peaks with no
# expression change; read from the file given by options.b.
noExprChange = np.loadtxt(options.b, dtype=bool)


def _normalizeByUnchanged(counts, unchanged, targetMean=10):
    """Scale each column of counts so its mean over the unchanged rows is targetMean.

    counts     -- 2-D array of read counts (rows are indexed by the mask,
                  columns are samples)
    unchanged  -- boolean row mask selecting the reference
                  ("no expression change") rows
    targetMean -- value the per-column mean over the reference rows takes
                  after scaling (10 by default)

    Returns a new array with the same shape as counts.
    """
    # Operation order (divide, then multiply) is kept as in the original
    # expression so the floating point results are unchanged.
    return counts/np.mean(counts[unchanged], 0)*targetMean


def _pairwiseDistance(counts, rows):
    """Return getDistanceSpearmanr for every pair of columns of counts.

    rows selects the rows that enter each comparison: slice(None) for all
    rows, or a boolean row mask. Entry [j][i] of the result is
    getDistanceSpearmanr(column i, column j).
    """
    numCols = counts.shape[1]
    return np.array([[getDistanceSpearmanr(counts[rows, i], counts[rows, j])
                      for i in range(numCols)]
                     for j in range(numCols)])


# normalize by mean in no expression change, such that the average count in
# the unchanged peaks is 10 for every sample.
# (A previous variant normalized by the column sums of the unchanged peaks
# instead of their mean.)
readCountNorm = _normalizeByUnchanged(readCount, noExprChange)

# save file of normalized peakCounts; the output prefix defaults to the
# path in options.c with its extension stripped.
outputfile = options.o if options.o else os.path.splitext(options.c)[0]
np.savetxt(outputfile+'.normalize.peakCount', readCountNorm)

# also save file with no replicates. The reduced matrix gets its own name so
# that readCountNorm keeps holding the per-sample normalized counts and does
# not have to be recomputed for the heat maps below.
# NOTE(review): 'norepicates' looks like a typo for 'noreplicates'; the file
# name is left untouched because other scripts may read it -- confirm.
readCountRed = filefun.reduceByReplicates(readCount, parameters.indices_for_reduce_by_replicates)
readCountRedNorm = _normalizeByUnchanged(readCountRed, noExprChange)
np.savetxt(outputfile+'.normalize.norepicates.peakCount', readCountRedNorm)

# plot and save correlation heat map using all peaks
numSamples = readCount.shape[1]
distanceCorr = _pairwiseDistance(readCountNorm, slice(None))
plotHeatMap(distanceCorr, rowlabels=parameters.headers, columnlabels=parameters.headers, fontSize=11, cmap='RdGy_r', vmin=0, vmax=1)
plt.savefig(outputfile+'.correlation_heatmap.all.pdf')

# same heat map restricted to the peaks with no expression change
distanceCorr = _pairwiseDistance(readCountNorm, noExprChange)
plotHeatMap(distanceCorr, rowlabels=parameters.headers, columnlabels=parameters.headers, fontSize=11, cmap='RdGy_r', vmin=0, vmax=1)
plt.savefig(outputfile+'.correlation_heatmap.no_change.pdf')

# Second pass, originally labelled "plot only old samples with replicates".
# The code below still uses every column of readCount; what differs from the
# earlier section is the hard-coded replicate pairing and the column labels
# of the heat map.
# NOTE(review): this pass writes to the same three output paths as the
# earlier section of this script, so it overwrites those files -- confirm
# that this is intended.

# Row mask of the peaks with no expression change (file given by options.b).
noExprChange = np.loadtxt(options.b, dtype=bool)

# Per-sample mean over the unchanged peaks; dividing by it and multiplying
# by 10 makes the average count in those peaks 10 for every sample.
_unchangedMean = np.mean(readCount[noExprChange], 0)
readCountNorm = readCount/_unchangedMean*10

# Output prefix: options.o when given, otherwise options.c without extension.
outputfile = options.o if options.o else os.path.splitext(options.c)[0]

# Write the normalized peak counts.
np.savetxt(outputfile+'.normalize.peakCount', readCountNorm)

# Merge replicates (columns 0+1 and 2+3 are passed as the replicate groups),
# normalize the merged counts the same way and write them out as well.
_replicateGroups = np.array([[0, 1], [2, 3]])
readCountRed = filefun.reduceByReplicates(readCount, _replicateGroups)
_unchangedMeanRed = np.mean(readCountRed[noExprChange], 0)
np.savetxt(outputfile+'.normalize.norepicates.peakCount', readCountRed/_unchangedMeanRed*10)

# Build the sample-by-sample distance matrix over all peaks; entry [j][i]
# is getDistanceSpearmanr(sample i, sample j).
numSamples = readCount.shape[1]
_distanceRows = []
for _j in range(numSamples):
    _distanceRows.append([getDistanceSpearmanr(readCountNorm[:, _i], readCountNorm[:, _j])
                          for _i in range(numSamples)])
distanceCorr = np.array(_distanceRows)

# Draw and save the heat map.
# NOTE(review): rows are labelled with parameters.headers but columns with
# parameters.headers_human -- verify the mismatch is deliberate.
plotHeatMap(
    distanceCorr,
    rowlabels=parameters.headers,
    columnlabels=parameters.headers_human,
    fontSize=11,
    cmap='RdGy_r',
    vmin=0,
    vmax=1,
)
plt.savefig(outputfile+'.correlation_heatmap.all.pdf')

"""
The remainder of the script makes a lot of plots. It requires calls of significant versus non-significant peaks.
optional: run python script 'find_significant_peaks.py'

os.system('python %s -b %s -p %s --indx %s'%('scoring/140815_peaks.coverageCorr.all.bed', outputfile+'.normalize.replicate_red.peakCount', options.b))
"""