Пример #1
1
def category_bar_charts():
    """Render three stacked bar charts of account sums and return them as a PNG response.

    The panels show, top to bottom, the pairs returned by
    ``retirement_class_sums()``, ``account_by_type_sums()`` and
    ``account_by_owner_sums()``. Each helper's pairs are sorted in descending
    order; element 1 of every pair is plotted on the x axis and element 0 on
    the y axis (presumably (sum, label) pairs — confirm against the helpers).

    :returns: the object built by ``make_response`` with the PNG bytes as body
        and ``Content-Type: image/png``.
    """
    sns.set(style="whitegrid")
    sns.set_context("notebook", font_scale=1.5, rc={"lines.linewidth": 2.5})
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 7), sharex=False)

    panels = (
        (retirement_class_sums, "YlOrRd_d", ax1, "By Category"),
        (account_by_type_sums, "BuGn_d", ax2, "By Type"),
        (account_by_owner_sums, "Blues_d", ax3, "By Owner"),
    )
    for sums_fn, palette, ax, ylabel in panels:
        pairs = sorted(sums_fn(), reverse=True)
        x = [pair[1] for pair in pairs]
        y = [pair[0] for pair in pairs]
        # seaborn >= 0.12 accepts only `data` positionally; x/y must be keywords.
        sns.barplot(x=x, y=y, palette=palette, ax=ax)
        ax.set_ylabel(ylabel)

    sns.despine(left=True)
    fig.tight_layout()
    try:
        # Render the figure created above rather than whatever pyplot considers current.
        canvas = FigureCanvas(fig)
        png_output = io.BytesIO()
        canvas.print_png(png_output)
    finally:
        # pyplot keeps every figure alive until closed; without this each call leaks one.
        plt.close(fig)
    response = make_response(png_output.getvalue())
    response.headers['Content-Type'] = 'image/png'
    return response
Пример #2
0
def plot_hist_algo(wave_hist_algor, pulse_hist_algor, multi_wave_hist_algor):
    """Plot the PSD-based fish sorting next to histograms of wave and pulse PSD amplitudes.

    Left panel: violin plot of the three samples of PSD proportions. Middle
    and right panels: density-normalised 50-bin histograms of the values
    loaded from ``wave_psd_data.npy`` and ``pulse_psd_data.npy`` in the
    current working directory.

    :param wave_hist_algor: PSD proportions of the wave-type recordings.
    :param pulse_hist_algor: PSD proportions of the pulse-type recordings.
    :param multi_wave_hist_algor: PSD proportions of the multi-wave recordings.
    :returns: the created figure (the function previously returned None).
    """
    inch_factor = 2.54
    sns.set_context("poster")
    # sns.axes_style() only returns a style dict (or acts as a context manager);
    # set_style() is what actually applies the 'white' style.
    sns.set_style('white')

    fig4 = plt.figure(figsize=(35. / inch_factor, 20. / inch_factor))
    ax1 = fig4.add_subplot(2, 3, (1, 4))
    # Each argument becomes a row; transpose so every EOD type is one column.
    dafr = pd.DataFrame([wave_hist_algor, multi_wave_hist_algor, pulse_hist_algor]).transpose()
    dafr.columns = ['wave', 'multi-wave', 'pulse']
    # `col` is not a violinplot option; `palette` assigns one colour per column.
    sns.violinplot(data=dafr, ax=ax1, palette=["blue", "green", "red"])
    ax1.set_ylabel('psd_proportion')
    ax1.set_xlabel('EOD-type')
    ax1.set_title('Fishsorting based on PSD')

    # Row 0 is compared against 1500 and row 1 is histogrammed; presumably
    # frequencies and PSD amplitudes, with row 0 sorted ascending — TODO confirm.
    wave_psd_data = np.load('wave_psd_data.npy')
    wave_hist_data = wave_psd_data[1][:len(wave_psd_data[0][wave_psd_data[0] < 1500])]
    ax3 = fig4.add_subplot(2, 3, (2, 5))
    # `normed` was removed in Matplotlib 3.1; `density` is its replacement.
    ax3.hist(wave_hist_data, 50, color='blue', alpha=0.7, density=True)
    ax3.set_ylabel('counts in histogram bin')
    ax3.set_xlabel('amplitude of PSD')
    ax3.set_title('Histogram of wavefish PSD')

    pulse_psd_data = np.load('pulse_psd_data.npy')
    pulse_hist_data = pulse_psd_data[1][:len(pulse_psd_data[0][pulse_psd_data[0] < 1500])]
    ax2 = fig4.add_subplot(2, 3, (3, 6))
    ax2.hist(pulse_hist_data, 50, color='red', alpha=0.7, density=True)
    ax2.set_ylabel('counts in histogram bin')
    ax2.set_xlabel('amplitude of PSD')
    ax2.set_title('Histogram of pulsefish PSD')

    fig4.tight_layout()
    return fig4
Пример #3
0
def TuningResponseArea(tuningCurves, unitKey='', figPath=None):
	"""Plot the tuning response area (intensity x frequency filled contour) for tuning-curve data.

	:param tuningCurves: experimental data loaded from an Excel spreadsheet; must contain
		'intensity', 'freq' and 'response' columns
	:type tuningCurves: pandas.core.DataFrame
	:param unitKey: identifying string for data, possibly unit name/number and test number
	:type unitKey: str
	:param figPath: prefix for the saved PNG; it is concatenated directly with the file
		name, so a directory should end with a path separator. Nothing is saved when
		empty or None.
	:type figPath: str
	:raises IndexError: if an (intensity, frequency) pair has no non-NaN response
	"""
	plt.figure()
	I = np.unique(np.array(tuningCurves['intensity']))
	# Unique, sorted frequencies: the 'freq' column repeats values (presumably once per
	# intensity), and duplicate x values would corrupt the contour grid.
	F = np.unique(np.array(tuningCurves['freq']))
	intensity = tuningCurves['intensity']
	freq = tuningCurves['freq']
	response = tuningCurves['response']
	R = np.zeros((len(I), len(F)))
	for ci, i in enumerate(I):
		for cf, fr in enumerate(F):
			# First non-NaN response recorded for this (intensity, frequency) pair.
			R[ci, cf] = response[(intensity == i) & (freq == fr)].dropna().values[0]
	sns.set_context(rc={"figure.figsize": (7, 4)})
	# Fixed colour limits of (-10, 10.1) were once applied via vmin/vmax/levels; left disabled.
	plt.contourf(F, I, R)
	plt.colorbar()
	plt.xlabel('Frequency (kHz)', fontsize=14)
	plt.ylabel('Intensity (dB)', fontsize=14)
	if figPath:
		plt.savefig(figPath + 'tuningArea_' + unitKey + '.png')
Пример #4
0
    def plot_results(self):
        """
        Plot the portfolio's equity curve, per-period returns and drawdowns
        as three stacked panels against ``self.timeseries``, then show them.
        """
        sns.set_palette("deep", desat=.6)
        sns.set_context(rc={"figure.figsize": (8, 4)})

        fig = plt.figure()
        fig.patch.set_facecolor('white')

        # Put every series on the shared time index before plotting.
        df = pd.DataFrame()
        for column, values in (("equity", self.equity),
                               ("equity_returns", self.equity_returns),
                               ("drawdowns", self.drawdowns)):
            df[column] = pd.Series(values, index=self.timeseries)

        # One panel per series, each drawn in the next colour of the palette.
        panels = (
            (311, 'Equity Value', "equity"),
            (312, 'Equity Returns', 'equity_returns'),
            (313, 'Drawdowns', 'drawdowns'),
        )
        for colour_idx, (position, label, column) in enumerate(panels):
            axis = fig.add_subplot(position, ylabel=label)
            df[column].plot(ax=axis, color=sns.color_palette()[colour_idx])

        # Tilt the date tick labels so they do not overlap.
        fig.autofmt_xdate()
        plt.show()
Пример #5
0
 def addRandomThemes(self):
     """Apply a seaborn style and context chosen from ``self.numFig``.

     Despite the name the choice is deterministic: the style index is
     ``numFig % 5`` over five themes, and the context index is that value
     modulo 4.
     """
     style_index = self.numFig % 5
     sns.set_style(("darkgrid", "whitegrid", "dark", "white", "ticks")[style_index])
     sns.set_context(("paper", "notebook", "talk", "poster")[style_index % 4])
Пример #6
0
    def plot_results(self):
        """
        Plot the balance of the portfolio ("equity curve"), the period returns
        and the drawdowns as functions of time, then show the figure.

        Reads ``output.csv`` from ``settings.OUTPUT_DIR``, which must be set in
        the project settings. The first CSV column is used as the index (parsed
        as dates where possible); the ``Equity``, ``Returns`` and ``Drawdown``
        columns are plotted.
        """
        sns.set_palette("deep", desat=0.6)
        sns.set_context(rc={"figure.figsize": (8, 4)})

        equity_file = os.path.join(settings.OUTPUT_DIR, "output.csv")
        # pd.read_csv is the public spelling of the internal pd.io.parsers.read_csv.
        equity = pd.read_csv(equity_file, parse_dates=True, header=0, index_col=0)

        # Plot three charts: Equity curve, period returns, drawdowns
        fig = plt.figure()
        fig.patch.set_facecolor("white")  # Set the outer colour to white

        # Plot the equity curve
        ax1 = fig.add_subplot(311, ylabel="Portfolio value")
        equity["Equity"].plot(ax=ax1, color=sns.color_palette()[0])

        # Plot the returns
        ax2 = fig.add_subplot(312, ylabel="Period returns")
        equity["Returns"].plot(ax=ax2, color=sns.color_palette()[1])

        # Plot the drawdowns
        ax3 = fig.add_subplot(313, ylabel="Drawdowns")
        equity["Drawdown"].plot(ax=ax3, color=sns.color_palette()[2])

        # Plot the figure
        plt.show()
Пример #7
0
def UseSeaborn(palette='deep'):
    """Configure seaborn as the global plotting style.

    Starts from 'whitegrid' with a 1.5x font scale and framed legends, then
    switches to the 'ticks' style with inward-pointing ticks and applies
    the requested colour palette.

    :param palette: any palette name accepted by ``sns.set_palette``; the
        special value 'xkcd' uses ``sns.xkcd_palette(xkcdcolors)`` instead.
    """
    import seaborn as sns
    # Base theme: larger fonts and a frame drawn around legends.
    sns.set(style='whitegrid', font_scale=1.5, rc={'legend.frameon': True})
    # 'ticks' overrides the 'whitegrid' style.
    sns.set_style('ticks')
    inward_ticks = {"xtick.direction": "in", "ytick.direction": "in"}
    sns.set_style(inward_ticks)

    if palette != 'xkcd':
        sns.set_palette(palette)
    else:
        # NOTE(review): `xkcdcolors` is not defined in this function; presumably a
        # module-level list of xkcd colour names — confirm it exists.
        sns.set_palette(sns.xkcd_palette(xkcdcolors))

    # Thin marker edges; the original author added this to fix an "invisible marker" bug.
    sns.set_context(rc={'lines.markeredgewidth': 0.1})
Пример #8
0
 def decode_uniform_samples_from_latent_space(_):
     """Decode a uniform grid of 2-D latent points and tile the decoded images.

     ``nx * ny`` points evenly spaced over [-3, 3] x [-3, 3] are each decoded
     with ``sess.run(prior_model(latent=...)[0])``. Each result is reshaped
     to a 28x28 tile of one canvas. The canvas is saved to
     ``FLAGS.viz_dir/P(X|z).png`` and shown.

     :param _: ignored.
     :returns: ``(fig, ax)`` of the canvas figure.
     """
     nx = ny = 20
     extent_x = extent_y = [-3, 3]
     extent = numpy.array(extent_x + extent_y)
     x_values = numpy.linspace(*(extent_x + [nx]))
     y_values = numpy.linspace(*(extent_y + [ny]))
     full_extent = extent * (nx + 1) / float(nx)
     canvas = numpy.empty((28 * ny, 28 * nx))
     for ii, yi in enumerate(y_values):
         for j, xi in enumerate(x_values):
             n = ii * nx + j + 1
             sys.stdout.write("\rsampling p(X|z), sample %d/%d" % (n, nx*ny))
             sys.stdout.flush()
             np_z = numpy.array([[xi, yi]])
             # Positional shape: the `newshape=` keyword is deprecated in recent NumPy.
             x_mean = sess.run(prior_model(latent=numpy.reshape(np_z, (1, LATENT_DIM)))[0])
             # Row ii is filled from the bottom so that larger y appears higher in the image.
             canvas[(ny - ii - 1) * 28:(ny - ii) * 28, j * 28:(j + 1) * 28] = x_mean[0].reshape(28, 28)
     with seaborn.axes_style('ticks'):
         # Note: set_context is global and is not undone when the with-block exits.
         seaborn.set_context(context='notebook', font_scale=1.75)
         fig, ax = plt.subplots(figsize=(12, 9))
     ax.imshow(canvas, extent=full_extent)
     ax.xaxis.set_ticks(numpy.linspace(*(extent_x + [nx])))
     ax.yaxis.set_ticks(numpy.linspace(*(extent_y + [ny])))
     ax.set_xlabel('z_1')
     ax.set_ylabel('z_2')
     ax.set_title('P(X|z); decoding latent space; (CONV, BNAE, IND_ERROR) = (%d,%d,%d)' % (CONV, BNAE, IND_ERROR))
     # Save before show(): once an interactive window is closed, saving writes a blank image.
     fig.savefig(os.path.join(FLAGS.viz_dir, 'P(X|z).png'))
     plt.show()
     return fig, ax
def _events_frame(events):
    """Return *events* as a DataFrame; None or an empty collection gives an empty frame."""
    if events is None or len(events) == 0:
        return pd.DataFrame()
    return pd.DataFrame(events)


def _annotate_events(events_df, team, y, facecolor):
    """Write a vertical 'TEAM CATEGORY EVENT' label at x=row.event for each row.

    Rows whose ``category`` is 'OTHER' are skipped.
    """
    for _, row in events_df.iterrows():
        if row.category == 'OTHER':
            continue
        plt.text(row.event, y, "{} {} {}".format(team, row.category, row.event),
                 rotation='vertical', color='black',
                 bbox=dict(facecolor=facecolor, alpha=0.2))


def plot_rolling_auto_home(df_attack=None,df_defence=None, window=5, nstd=1,
                      detected_events_home=None,
                     detected_events_away=None, sky_events=None):
    """Plot the home team's attack tweet counts with a rolling band, maxima and events.

    The attack series is drawn with a centred rolling-mean +/- ``nstd`` std
    band. Its local maxima are split into "selected" maxima, which lie
    outside mean +/- 1 std, and "unselected" maxima, which lie inside it.
    The defence series is drawn mirrored below zero. Detected events are
    written as vertical labels. The figure is saved as
    ``<HOME>attack_<AWAY>_plain.pdf``. Team names come from the global
    ``all_matches[0]``.

    :param df_attack: attack series (presumably a time-indexed pandas Series — confirm).
    :param df_defence: defence series on the same index.
    :param window: rolling-window length.
    :param nstd: width of the shaded band in standard deviations.
    :param detected_events_home: records with ``event`` and ``category`` fields, or None.
    :param detected_events_away: as above, for the away team.
    :param sky_events: currently unused (the Sky Sports overlay is disabled).
    :returns: the local maxima of ``df_attack``.
    """
    sns.set_context("notebook", font_scale=1.8, rc={"lines.linewidth": 3.5, "figure.figsize": (18, 12)})
    plt.subplots_adjust(bottom=0.85)
    # pd.rolling_mean/rolling_std were removed from pandas; .rolling() is the replacement.
    rolling = df_attack.rolling(window=window, center=True)
    mean = rolling.mean()
    std = rolling.std()

    # argrelextrema returns positions, so select positionally (.ix was removed in pandas 1.0).
    detected_plot_extrema = df_attack.iloc[argrelextrema(df_attack.values, np.greater)[0]]

    within_band = (df_attack > mean - std) & (df_attack < mean + std)
    # reindex + dropna mirrors the old `.ix[labels]`: maxima that were filtered out
    # become NaN and are dropped instead of raising KeyError as `.loc` would.
    df_filt_noise = df_attack[within_band].reindex(detected_plot_extrema.index).dropna()
    df_filt_keep = df_attack[~within_band].reindex(detected_plot_extrema.index).dropna()

    home = all_matches[0]['home_team']
    away = all_matches[0]['away_team']
    plt.plot(df_attack, color='#4CA64C', label='{} Attack'.format(home.title()))
    # Raw string: '\m' and '\s' are invalid escapes; the label text is unchanged.
    plt.fill_between(df_attack.index, (mean - nstd * std), (mean + nstd * std), interpolate=False,
                     alpha=0.4, color='#B2B2B2', label=r'$\mu + {} \times \sigma$'.format(nstd))
    plt.scatter(df_filt_keep.index, df_filt_keep.values, marker='*', s=120, color='#000000', zorder=10,
                label='Selected maxima post-filtering')
    plt.scatter(df_filt_noise.index, df_filt_noise.values, marker='x', s=120, color='#000000', zorder=10,
                label='Unselected maxima post-filtering')

    # Defence is mirrored below the x axis.
    (-df_defence).plot(color='#000000', label='{} Defence'.format(home.title()))

    # None or empty event lists used to raise NameError/TypeError; they now add no labels.
    _annotate_events(_events_frame(detected_events_home), home.upper(), df_attack.max(), 'green')
    _annotate_events(_events_frame(detected_events_away), away.upper(), df_attack.max(), 'red')

    plt.legend(loc=4)
    plt.gca().set_xlabel('time')
    plt.ylabel('Tweet frequency')
    plt.title('{} vs. {} (WK {}) - rolling averages window={} mins'.format(
        home.title(), away.title(), all_matches[0]['dbname'], window))
    plt.savefig('{}attack_{}_plain.pdf'.format(home.upper(), away.upper()))
    return detected_plot_extrema
def pltsns(style='ticks', context='talk'):
    """Set the global seaborn style (plus a legend-frame override) and plotting context.

    :param style: seaborn style name passed to ``sns.set_style``.
    :param context: seaborn context name passed to ``sns.set_context``.
    """
    sns.set_style(style)
    # A second call layers a single override on top of the named style.
    legend_override = {'legend.frameon': True}
    sns.set_style(legend_override)
    sns.set_context(context)
def plot_diadem_score(neuron_distance_csv, outputDir, algorithms=None):
    """Save a horizontal bar plot of the diadem score per reconstruction algorithm.

    :param neuron_distance_csv: CSV with at least 'algorithm' and 'diadem_score' columns.
    :param outputDir: directory for ``Diadem_score.png``; created (one level) if missing.
    :param algorithms: algorithm names in plotting order (any sequence, e.g. a list or
        NumPy array); defaults to ``order_algorithms_by_size(df)``.
    """
    df_nd = pd.read_csv(neuron_distance_csv)

    if not path.exists(outputDir):
        os.mkdir(outputDir)

    if algorithms is None:
        algorithms = order_algorithms_by_size(df_nd)

    # Row count per algorithm, shown as "n" in each tick label.
    dfg = df_nd.groupby('algorithm')
    sample_size_per_algorithm = [dfg.get_group(alg).shape[0] for alg in algorithms]

    plt.figure()
    sb.set_context("talk", font_scale=0.7)
    a = sb.barplot(y='algorithm', x='diadem_score', data=df_nd, order=algorithms)
    algorithm_name_mapping = rp.get_algorithm_name_dict()
    algorithm_names = [algorithm_name_mapping[x] for x in algorithms]
    # zip instead of `range(algorithms.size)`, so a plain list of names also works.
    a.set_yticklabels(['%s ($n$=%d )' % (name, n)
                       for name, n in zip(algorithm_names, sample_size_per_algorithm)])
    plt.xlabel('Diadem Scores')
    plt.subplots_adjust(left=0.5, right=0.95, bottom=0.1, top=0.9)
    plt.savefig(outputDir + '/Diadem_score.png', format='png')
    plt.close()
Пример #12
0
def plot_lai_comparison_thin(classification):
    """Compare LAI from the ceptometer, hemispherical photos and litter traps.

    ``ld.find_lai_lp80`` is called with ``ld.lp80``, ``ld.hemi`` and ``ld.trap``.
    The ceptometer values are drawn as a line over their observation index.
    Photo and trap values are drawn as markers at the index given by
    ``searchsorted`` of their plot ids in the ceptometer plot ids, and are
    NaN elsewhere. A dashed line marks each method's third returned value
    (presumably its mean LAI — confirm in ``ld``).

    :returns: ``(ax, fig)`` — note the axes come first.
    """
    sns.set_context(rc={'lines.linewidth': 0.8, 'lines.markersize': 6})
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(16, 8))
    palette = sns.color_palette("colorblind", 11)

    # The fourth returned value of each call is not used here.
    cept_lai, cept_plots, cept_mean, _ = ld.find_lai_lp80(classification, ld.lp80)
    hemi_lai, hemi_plots, hemi_mean, _ = ld.find_lai_lp80(classification, ld.hemi)
    trap_lai, trap_plots, trap_mean, _ = ld.find_lai_lp80(classification, ld.trap)

    n_obs = len(cept_lai)
    hemi_on_grid = np.full(n_obs, np.nan)
    hemi_on_grid[cept_plots.searchsorted(hemi_plots)] = hemi_lai
    trap_on_grid = np.full(n_obs, np.nan)
    trap_on_grid[cept_plots.searchsorted(trap_plots)] = trap_lai

    xlist = np.arange(n_obs)
    ax.plot(xlist, cept_lai, color=palette[0], label='Ceptometer')
    ax.plot(xlist, hemi_on_grid, 'o', color=palette[1], label='Hemi. photos')
    ax.plot(xlist, trap_on_grid, 'o', color=palette[2], label='Litter traps')
    for colour, level in zip(palette[:3], (cept_mean, hemi_mean, trap_mean)):
        ax.axhline(level, color=colour, linestyle='--')

    ax.set_xlabel('Number of LAI observations')
    ax.set_ylabel('LAI (m2/ m2)')
    plt.legend(loc='upper left')
    plt.ylim(0, 8)
    plt.xlim(0, xlist[-1] + 1)
    return ax, fig
Пример #13
0
def showResults(challenger_data, model):
    '''Show the Challenger O-ring data together with the fitted logistic curve.

    :param challenger_data: 2-D array; column 0 is plotted as the outside
        temperature (Fahrenheit) and column 1 as the damage indicator (0/1).
    :param model: fitted model; ``params[0]`` is passed to ``logistic`` as
        alpha and ``params[1]`` as beta.
    '''
    # First plot the original data
    plt.figure()
    sns.set_context('poster')
    sns.set_style('whitegrid')
    np.set_printoptions(precision=3, suppress=True)

    plt.scatter(challenger_data[:, 0], challenger_data[:, 1], s=75, color="k",
                alpha=0.5)
    plt.yticks([0, 1])
    plt.ylabel("Damage Incident?")
    plt.xlabel("Outside temperature (Fahrenheit)")
    plt.title("Defects of the Space Shuttle O-Rings vs temperature")
    plt.xlim(50, 85)

    # Plot the fit. plt.hold() was removed in Matplotlib 3.0; new lines are always
    # added to the existing axes, which is what hold(True) requested.
    x = np.arange(50, 85)
    alpha = model.params[0]
    beta = model.params[1]
    y = logistic(x, beta, alpha)
    plt.plot(x, y, 'r')

    outFile = 'ChallengerPlain.png'
    # Raw string: '\I' is an invalid escape sequence; the path text is unchanged.
    C2_8_mystyle.printout_plain(outFile, outDir=r'..\Images')
    plt.show()
Пример #14
0
def show_binomial():
    """Plot the PMFs of three binomial distributions, save the figure, and show it.

    Distributions: (n=20, p=0.5), (n=20, p=0.7) and (n=40, p=0.5). The figure
    is saved with ``mystyle.printout_plain`` as 'Binomial_distribution_pmf.png'.
    """
    bd1 = stats.binom(20, 0.5)
    bd2 = stats.binom(20, 0.7)
    bd3 = stats.binom(40, 0.5)

    # The support of Binomial(40, p) is 0..40 inclusive, so include k = 40.
    k = np.arange(41)

    sns.set_context('paper')
    sns.set_style('ticks')
    mystyle.set(14)

    markersize = 8
    # plt.hold() was removed in Matplotlib 3.0; successive plot calls overlay by default.
    plt.plot(k, bd1.pmf(k), 'o-b', ms=markersize)
    plt.plot(k, bd2.pmf(k), 'd-r', ms=markersize)
    plt.plot(k, bd3.pmf(k), 's-g', ms=markersize)
    plt.title('Binomial distribution')
    plt.legend(['p=0.5 and n=20', 'p=0.7 and n=20', 'p=0.5 and n=40'])
    plt.xlabel('X')
    plt.ylabel('P(X)')
    sns.despine()

    mystyle.printout_plain('Binomial_distribution_pmf.png')

    plt.show()
	def PlotSTResponseEst(self, stResponseDF, label, duration=250, firstFreq=1):
		"""Plot a response-rate estimate over time and frequency as a colour map.

		:param stResponseDF: results of Bayesian response analysis for multiple tone
			stimulus intensities; rows are indexed by frequency, columns span the
			recording window
		:type stResponseDF: pandas DataFrame
		:param label: figure title
		:type label: str
		:param duration: duration of the recording window (x extent of the image)
		:type duration: float
		:param firstFreq: number of leading (spurious) rows to skip
		:type firstFreq: int
		:returns: handle to the image
		"""
		stResponseE = np.array(stResponseDF)
		# Skip the same leading rows in the frequency axis as in the data (this was
		# hard-coded to 1); plain `float` replaces the np.float alias removed in NumPy 1.24.
		freqs = np.array(stResponseDF.index.tolist())[firstFreq:].astype(float)
		sns.set_context(rc={"figure.figsize": (8, 4)})
		data = stResponseE[firstFreq:, :]
		maxRes = np.max(abs(data))
		# Mean of the last column, used as the baseline ("spontaneous") rate.
		spontRate = np.average(data[:, -1])
		# Colour limits are symmetric about the baseline so the diverging bwr map centres on it.
		ax = plt.imshow(data, vmax=maxRes + spontRate, vmin=-maxRes + spontRate,
			extent=[0, duration, min(freqs), max(freqs)], aspect='auto',
			interpolation='nearest', origin='lower', cmap=cm.bwr)
		sns.despine()
		plt.grid(False)
		plt.title(label)
		plt.xlabel('Time (ms)')
		plt.ylabel('Frequency (kHz)')
		plt.colorbar()
		return ax
	def PlotBBNResponseCurve(self, bbnResponseProb, measure, unit=None, filePath=None, attn=False):
		"""Plot one Bayesian response measure against sound level (SPL).

		:param bbnResponseProb: results of Bayesian response analysis for multiple stimulus intensities
		:type bbnResponseProb: pandas DataFrame
		:param measure: index 0-9; used both as the column label selected from
			bbnResponseProb and as the index into ['resProb', 'vocalResMag',
			'vocalResMag_MLE', 'effectSize', 'effectSize_MLE', 'spontRate',
			'spontRateSTD', 'responseLatency', 'responseLatencySTD', 'responseDuration']
		:type measure: int
		:param unit: unique identifier for the cell; used as the title and in the file name
		:type unit: str
		:param filePath: path appended to self.dirPath where the PDF is saved; when empty
			or None the plot is shown instead
		:type filePath: str
		:param attn: if True, plot the rows in reverse order
		:type attn: bool
		:returns: handle to the plot axes
		"""
		# None replaces the shared mutable [] defaults (which also titled the plot "[]").
		if unit is None:
			unit = ''
		if filePath is None:
			filePath = ''
		measureName = ['resProb', 'vocalResMag', 'vocalResMag_MLE', 'effectSize', 'effectSize_MLE', 'spontRate', 'spontRateSTD', 'responseLatency', 'responseLatencySTD', 'responseDuration']
		sns.set_palette(sns.color_palette("bright", 8))
		sns.set_context(rc={"figure.figsize": (5, 3)})
		# The named "ticks" style sets every style key, so a preceding "white" call had no effect.
		sns.set_style("ticks")
		rows = slice(None, None, -1) if attn else slice(None)
		ax = bbnResponseProb.loc[rows, measure].fillna(0).plot(figsize=(6, 4))
		sns.despine()
		plt.grid(False)
		plt.title(unit, fontsize=14)
		plt.xlabel('SPL (dB)', fontsize=12)
		plt.ylabel(measureName[measure], fontsize=12)
		plt.ylim(0.5, 1.0)
		if filePath:
			plt.savefig(self.dirPath + filePath + 'bbn_' + measureName[measure] + '_' + unit + '.pdf')
			plt.close()
		else:
			plt.show()
		return ax
Пример #17
0
def fishtype_barplot(data):
    """ Create a bar plot of how many wave-type vs. pulse-type fishes were found.

    The figure is saved as ``figures/fishtype_barplot.pdf``.

    :param data: dictionary with fish-type as keys and array of EODfs as values. The
        first key containing 'wave' and the first key containing 'puls' are counted.
    :raises IndexError: if no key contains 'wave' or no key contains 'puls'.
    """
    # list(): on Python 3, np.array(dict.keys()) yields a 0-d object array that cannot
    # be boolean-indexed, so the old lookup failed there.
    keys = list(data.keys())
    wave_key = [k for k in keys if 'wave' in k][0]
    pulse_key = [k for k in keys if 'puls' in k][0]

    count_wave = len(data[wave_key])
    count_pulse = len(data[pulse_key])

    inch_factor = 2.54
    sns.set_context("poster")
    sns.set_style("ticks")
    fig, ax = plt.subplots(figsize=(10./inch_factor, 10./inch_factor))
    width = 0.5
    ax.bar(1-width/2., count_wave, width=width, facecolor='cornflowerblue', alpha=0.8)
    ax.bar(2-width/2., count_pulse, width=width, facecolor='salmon', alpha=0.8)
    ax.set_xticks([1, 2])
    ax.set_xticklabels(['Wave-type', 'Pulse-type'])
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.set_ylabel('Number of Fishes', fontsize=14)
    ax.set_title('Distribution of Fish-types', fontsize=16)
    sns.despine(fig=fig, ax=ax, offset=10)
    fig.tight_layout()
    fig.savefig('figures/fishtype_barplot.pdf')
    plt.close(fig)
    def plot_and_savefig(self, out_path=None):
        """Draw one subplot per data column of ``self.data`` and optionally save the figure.

        The 'CATEGORY' and 'sample' columns are not plotted; the layout assumes
        both are present. Subplots are laid out at most three per row. Each one
        is drawn by ``self.draw_ax``, and only the first drawn subplot keeps
        its legend.

        :param out_path: if given, the figure is saved there at 300 dpi.
        :returns: the last subplot drawn (None if no column was plotted).
        """
        sns.set_context('notebook')
        sns.set_style('white')

        plot_w = 3 + len(self.data['sample'].unique())
        plot_h = 3.5
        plots_per_row = 3

        n_plots = len(self.data.columns) - 2
        n_rows = ceil(n_plots / plots_per_row)
        n_cols = ceil(n_plots / n_rows)

        fig = plt.figure()
        fig.set_figheight(plot_h * n_rows)
        fig.set_figwidth(plot_w * n_cols)

        ax = None
        plotted = 0
        for category in self.data.columns:
            if category in ['CATEGORY', 'sample']:
                continue

            plotted += 1
            ax = fig.add_subplot(n_rows, n_cols, plotted)
            self.draw_ax(ax, category)
            # Keep the legend on the first subplot actually drawn. This used to be keyed to
            # column position 0, so no legend appeared when column 0 was a skipped column.
            if plotted == 1:
                ax.legend()
            elif ax.legend_ is not None:
                ax.legend_.set_visible(False)

        plt.tight_layout()
        if out_path:
            plt.savefig(out_path, dpi=300, bbox_inches='tight')

        return ax
Пример #19
0
def plot_total_eui(theme):
    """Plot the yearly EUI for all buildings and for buildings with/without ECM.

    The three tables from ``read_eui_cnt`` are inner-joined on 'Fiscal Year'
    and restricted to years before 2016. One line is drawn per remaining
    column, and fiscal years 2005-2006 are shaded. The figure is saved to
    ``plot_FY_annual/ave_eui.png`` under the current working directory.
    """
    sns.set_style("whitegrid")
    sns.set_palette("Set2")
    sns.set_context("talk", font_scale=1.5)
    frames = [
        read_eui_cnt('', theme, 'All Building'),
        read_eui_cnt('_wecm', theme, 'Building with ECM'),
        read_eui_cnt('_woutecm', theme, 'Building without ECM'),
    ]
    # Left fold of inner joins on the fiscal year.
    merged = frames[0]
    for other in frames[1:]:
        merged = pd.merge(merged, other, on='Fiscal Year', how='inner')
    merged = merged[merged['Fiscal Year'] < 2016]

    series_names = list(merged)
    series_names.remove('Fiscal Year')
    plt.axes()
    handles = []
    for name in series_names:
        handle, = plt.plot(merged['Fiscal Year'], merged[name], ls='-', lw=2, marker='o')
        handles.append(handle)
    plt.legend(handles, series_names, loc='center left',
               bbox_to_anchor=(1, 0.5), prop={'size': 13})

    y_top = 90
    plt.ylim((0, y_top))
    # Shaded span 2004.5-2006.5 covers fiscal years 2005 and 2006.
    plt.fill_between([2004.5, 2006.5], 0, y_top, facecolor='gray', alpha=0.2)
    plt.title('GSA Portfolio (A + I) Average EUI Trend')
    plt.xlabel('Fiscal Year')
    plt.ylabel(lb.ylabel_dict[theme])
    P.savefig(os.getcwd() + '/plot_FY_annual/ave_eui.png', dpi=300,
              bbox_inches='tight')
    plt.close()
Пример #20
0
def create_histo(data):
    """
    Create a histogram of the EOD frequencies of each fish type and save it as a PDF.

    Fish types with fewer than 4 values are skipped. The others get one bin per 4
    values. The figure is saved as ``figures/histo_of_eod_freqs.pdf``.

    :param data: dictionary with fishtype as keys and np.array with EOD-Frequencies as values.
    """
    # Function-call form works on Python 2 and 3 (the print statement is a SyntaxError on 3).
    print('creating histogramm ...')

    inch_factor = 2.54
    sns.set_context("poster")
    sns.set_style("ticks")
    fig, ax = plt.subplots(figsize=(15./inch_factor, 10./inch_factor))
    colors = ['salmon', 'cornflowerblue']

    for enu, curr_fishtype in enumerate(data.keys()):
        freqs = data[curr_fishtype]
        if len(freqs) >= 4:
            hist, bins = np.histogram(freqs, bins=len(freqs)//4)
            width = 0.7 * (bins[1] - bins[0])
            center = (bins[:-1] + bins[1:]) / 2
            # Cycle the colours so more than two fish types do not raise IndexError.
            ax.bar(center, hist, align='center', width=width, alpha=0.8,
                   facecolor=colors[enu % len(colors)], label=curr_fishtype)

    ax.set_ylabel('Counts', fontsize=14)
    ax.set_xlabel('Frequency [Hz]', fontsize=14)
    # list(): np.hstack needs a real sequence, not a Python 3 dict view.
    ax.set_xticks(np.arange(0, max(np.hstack(list(data.values())))+100, 250))
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_title('Distribution of EOD-Frequencies', fontsize=16)
    ax.legend(frameon=False, loc='best', fontsize=12)
    sns.despine(fig=fig, ax=ax, offset=10)
    fig.tight_layout()
    fig.savefig('figures/histo_of_eod_freqs.pdf')
    plt.close(fig)
Пример #21
0
def plotLcurveall(alphaarr, datadif, constdif):
    """
    Plot the L-curve for every lag, four lags per subplot in a two-column grid.

    :param alphaarr: not used by the plot (presumably the regularisation parameters).
    :param datadif: array whose last axis indexes the lag; ``datadif[:, lag]`` is
        drawn on the x axis (labelled ||Ax-b||_2), log scale.
    :param constdif: ``constdif[:, lag]`` is drawn on the y axis (labelled f(x)), log scale.
    :returns: ``(fig, axlist)`` where ``axlist`` is the flattened array of axes.
    """
    import math

    sns.set_style('whitegrid')
    sns.set_context('notebook')
    Nlag = datadif.shape[-1]

    nlagplot = 4
    # math.ceil instead of sp.ceil: NumPy aliases in the top-level scipy namespace
    # are deprecated/removed in recent SciPy.
    nrows = int(math.ceil(float(Nlag) / (2 * nlagplot)))

    fig, axmat = plt.subplots(nrows=nrows, ncols=2, facecolor='w', figsize=(8, 4*nrows), sharey=True)
    axlist = axmat.flatten()

    for iaxn, iax in enumerate(axlist):
        strlist = []
        handlist = []
        for ilag in range(nlagplot):
            curlag = iaxn * nlagplot + ilag
            if curlag >= Nlag:
                break
            handlist.append(iax.plot(datadif[:, curlag], constdif[:, curlag])[0])
            strlist.append('Lag {0}'.format(curlag))
        iax.set_xscale('log')
        iax.set_yscale('log')
        # NOTE(review): `fs` (font size) is not defined in this function; presumably a
        # module-level constant — confirm.
        iax.set_title('L Curve', fontsize=fs)
        iax.set_xlabel(r'$\|Ax-b\|_2$', fontsize=fs)
        iax.set_ylabel(r'$f(x)$', fontsize=fs)
        iax.legend(handlist, strlist, loc='upper right', fontsize='large')
    plt.tight_layout()
    return (fig, axlist)
Пример #22
0
def program_eui():
    """Plot the EUI trend of the buildings in each ECM program except 'Energy Star'.

    One line per program is drawn with ``plot_eui_trend`` and labelled with the
    number of distinct buildings. Fiscal years 2005-2006 are shaded. The figure
    is saved to ``plot_FY_annual/program_trend.png`` under the current working
    directory.
    """
    df_eng = read_energy('good_energy')
    df_pro = pd.read_csv(master_dir + 'ecm_program_tidy.csv')
    # Sorted so line colours and legend order are reproducible between runs (set
    # iteration order of strings is not). Set difference tolerates a missing 'Energy Star'.
    programs = sorted(set(df_pro['ECM program']) - {'Energy Star'}, key=str)
    sns.set_style("whitegrid")
    sns.set_palette("Set2", 8)
    sns.set_context("talk", font_scale=1.5)
    bx = plt.axes()
    lines = []
    labels = []
    for p in programs:
        buildings = df_pro[df_pro['ECM program'] == p]['Building Number'].unique()
        df_temp = df_eng[df_eng['Building Number'].isin(buildings)]
        df_temp = df_temp[['Building Number', 'Fiscal Year', 'Gross Sq.Ft', 'Total Electric + Gas']]
        lines.append(plot_eui_trend(df_temp, bx))
        labels.append('{0} (n={1})'.format(p, len(df_temp['Building Number'].unique())))
    plt.title('Energy Program EUI Trend')
    plt.ylabel(lb.ylabel_dict['eui'])
    plt.xlabel('Fiscal Year')
    plt.gca().set_ylim(bottom=0)
    # get_ylim() returns (bottom, top). Passing the whole tuple made fill_between use
    # y2 = 0 at x=2004.5 and y2 = top at x=2006.5, i.e. a triangle instead of a band.
    ylimit = bx.get_ylim()[1]
    plt.fill_between([2004.5, 2006.5], 0, ylimit, facecolor='gray',
                     alpha=0.2)
    plt.legend(lines, labels, loc='center left',
               bbox_to_anchor=(1, 0.5), prop={'size': 13})
    P.savefig(os.getcwd() + '/plot_FY_annual/program_trend.png', dpi=300,
              bbox_inches='tight')
    plt.close()
Пример #23
0
def deviance_curve(classifier, features, labels, metaparameter_name, param_range, metric='Accuracy',
                   n_folds=4, njobs=-1, fig_size=(16, 9)):
    """Plot training and cross-validation scores of *classifier* over a hyper-parameter range.

    Each curve is the mean over folds with a +/- one standard deviation band.

    :param classifier: scikit-learn estimator.
    :param features: training samples (X).
    :param labels: targets (y).
    :param metaparameter_name: name of the estimator parameter to vary.
    :param param_range: values tried for that parameter.
    :param metric: scikit-learn scoring name. It is lower-cased before being passed to
        ``validation_curve`` (scorer names are case-sensitive lower-case, so the default
        'Accuracy' used to be rejected); the original spelling is used in the labels.
    :param n_folds: number of cross-validation folds.
    :param njobs: number of parallel jobs for ``validation_curve``.
    :param fig_size: figure size in inches.
    """
    scoring = metric.lower() if isinstance(metric, str) else metric
    # param_name/param_range are keyword-only in current scikit-learn releases.
    training_scores, validation_scores = validation_curve(classifier,
                                                          features, labels,
                                                          param_name=metaparameter_name,
                                                          param_range=param_range,
                                                          n_jobs=njobs,
                                                          cv=n_folds, scoring=scoring)

    training_scores_mean = np.mean(training_scores, axis=1)
    training_scores_std = np.std(training_scores, axis=1)
    validation_scores_mean = np.mean(validation_scores, axis=1)
    validation_scores_std = np.std(validation_scores, axis=1)
    sns.set_style("darkgrid")
    sns.set_context("notebook", font_scale=1.5, rc={"lines.linewidth": 2.5})
    plt.figure(num=None, figsize=fig_size, dpi=600, facecolor='w', edgecolor='k')
    plt.title("Validation Curve")
    plt.xlabel(metaparameter_name)
    plt.ylabel(metric)
    plt.xlim(np.min(param_range), np.max(param_range))
    plt.plot(param_range, training_scores_mean, label="Training " + metric, color="mediumblue")
    plt.fill_between(param_range, training_scores_mean - training_scores_std,
                     training_scores_mean + training_scores_std, alpha=0.2, color="lightskyblue")
    plt.plot(param_range, validation_scores_mean, label="validation " + metric,
             color="coral")
    plt.fill_between(param_range, validation_scores_mean - validation_scores_std,
                     validation_scores_mean + validation_scores_std, alpha=0.2, color="lightcoral")
    plt.legend(loc="best")
    plt.show()
Пример #24
0
def plot_building_temp():
    """Scatter the monthly gas EUI of building AZ0000FF against the 'KTUS' mean temperature.

    The building's CSV files are read from ``csv_FY/testWeather``. Temperatures
    come from ``csv_FY/weather/weatherData_meanTemp.csv``. The two are joined
    on (year, month) and the join is written to ``test_temp.csv``. Two scatter
    plots are saved: one panel per year, and years 2013-2014 combined.
    """
    sns.set_context("paper", font_scale=1.5)
    b = 'AZ0000FF'
    s = 'KTUS'
    filelist = glob.glob(os.getcwd() + '/csv_FY/testWeather/{0}*.csv'.format(b))
    dfs = [pd.read_csv(csv) for csv in filelist]
    col = 'eui_gas'
    dfs2 = [df[[col, 'month', 'year']] for df in dfs]
    df3 = pd.concat(dfs2)

    temp = pd.read_csv(os.getcwd() + '/csv_FY/weather/weatherData_meanTemp.csv')
    # 'Unnamed: 0' holds the date string: characters 0-3 are the year, 5-6 the month.
    temp['year'] = temp['Unnamed: 0'].map(lambda x: float(x[:4]))
    temp['month'] = temp['Unnamed: 0'].map(lambda x: float(x[5:7]))
    temp.set_index(pd.DatetimeIndex(temp['Unnamed: 0']), inplace=True)
    temp = temp[[s, 'month', 'year']]
    joint2 = pd.merge(df3, temp, on=['year', 'month'], how='inner')
    joint2.to_csv(os.getcwd() + '/csv_FY/testWeather/test_temp.csv', index=False)

    # x/y as keywords: seaborn >= 0.12 accepts only `data` positionally.
    sns.lmplot(x=s, y=col, data=joint2, col='year', fit_reg=False)
    plt.xlim((joint2[s].min() - 10, joint2[s].max() + 10))
    plt.ylim((0, joint2[col].max() + 0.1))
    P.savefig(os.getcwd() + '/csv_FY/testWeather/plot/scatter_temp_byyear.png', dpi=150)
    plt.close()

    joint2 = joint2[(2012 < joint2['year']) & (joint2['year'] < 2015)]
    sns.regplot(x=s, y=col, data=joint2, fit_reg=False)
    plt.xlim((joint2[s].min() - 10, joint2[s].max() + 10))
    plt.ylim((0, joint2[col].max() + 0.1))
    P.savefig(os.getcwd() + '/csv_FY/testWeather/plot/scatter_temp_1314.png', dpi=150)
    plt.close()
Пример #25
0
def plot_dfs_histogram(dfs_array, binwidth='FD'):
    """ Plots a histogram of the difference frequencies and saves it as a PDF.

    :param binwidth: 'FD' (default) chooses the bins with the Freedman-Diaconis rule;
        any other value is passed to ``hist`` as ``bins`` (bin count or bin edges).
    :param dfs_array: array-like. list of difference frequencies.
    """
    inch_factor = 2.54
    sns.set_context("poster")
    sns.set_style("ticks")
    fig, ax = plt.subplots(figsize=(15./inch_factor, 10./inch_factor))

    # isinstance guard: comparing an array of bin edges with 'FD' is ambiguous.
    if isinstance(binwidth, str) and binwidth == 'FD':
        # The Freedman-Diaconis rule gives a bin *width* h = 2*IQR/n**(1/3), but the old
        # code passed h as the bin *count*. NumPy implements the rule directly as 'fd'.
        ax.hist(dfs_array, bins='fd', facecolor='cornflowerblue', alpha=0.8)
    else:
        ax.hist(dfs_array, bins=binwidth, color='cornflowerblue', alpha=0.8)

    # Plot Cosmetics
    ax.set_ylabel('Counts', fontsize=16)
    ax.set_xlabel('Possible Beat-Frequencies [Hz]', fontsize=14)
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_title('Distribution of Beat-Frequencies', fontsize=16)
    sns.despine(fig=fig, ax=ax, offset=10)
    fig.tight_layout()
    fig.savefig('figures/histo_of_dfs.pdf')
    plt.close(fig)
Пример #26
0
def quickPlot(filename, path, datalist, xlabel="x", ylabel="y",
              xrange=("auto", "auto"), yrange=("auto", "auto"),
              yscale="linear", xscale="linear", col=("r", "b")):
    """Plot data series to ``<path>/plots/<filename>.pdf`` using matplotlib.

    :param filename: base name of the PDF (without extension); also the plot title.
    :param path: directory in which a ``plots`` sub-folder is created if missing.
    :param datalist: sequence whose first element is the shared x data and whose
        remaining elements are y series plotted against it.
    :param xlabel: x-axis label.
    :param ylabel: y-axis label.
    :param xrange: ``(min, max)`` x limits; an entry equal to "auto" is left to matplotlib.
    :param yrange: ``(min, max)`` y limits; an entry equal to "auto" is left to matplotlib.
    :param yscale: y-axis scale name passed to ``set_yscale``.
    :param xscale: x-axis scale name passed to ``set_xscale``.
    :param col: line colours, cycled through when there are more series than colours.
    """
    os.makedirs(os.path.join(path, "plots"), exist_ok=True)
    seaborn.set_context("notebook", rc={"lines.linewidth": 1.0})

    def _sci_formatter():
        # Math-text scientific notation when the exponent is <= -2 or >= 3.
        # A fresh instance per axis: a formatter binds to the axis it is set on,
        # so sharing one between x and y makes y offsets follow the x range.
        fmt = ScalarFormatter(useMathText=True)
        fmt.set_scientific(True)
        fmt.set_powerlimits((-2, 3))
        return fmt

    fig = Figure(figsize=(6, 6))
    ax = fig.add_subplot(111)
    xdata = datalist[0]
    for i, ydata in enumerate(datalist[1:]):
        # Cycle colours instead of running out after len(col) * 10 series.
        ax.plot(xdata, ydata, c=col[i % len(col)])
    ax.set_title(filename)
    ax.set_yscale(yscale)
    ax.set_xscale(xscale)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if xrange[0] != "auto":
        ax.set_xlim(left=xrange[0])
    if xrange[1] != "auto":
        ax.set_xlim(right=xrange[1])
    if yrange[0] != "auto":
        ax.set_ylim(bottom=yrange[0])
    if yrange[1] != "auto":
        ax.set_ylim(top=yrange[1])
    # Only override the default tick labels on linear axes; log axes keep
    # their log formatter.
    if yscale == "linear":
        ax.yaxis.set_major_formatter(_sci_formatter())
    if xscale == "linear":
        ax.xaxis.set_major_formatter(_sci_formatter())
    canvas = FigureCanvasPdf(fig)
    canvas.print_figure(os.path.join(path, "plots", filename + ".pdf"))
Пример #27
0
def plot_building_temp():
    """Scatter-plot one building's monthly gas EUI against mean temperature.

    Loads every ``csv_FY/testWeather/<b>*.csv`` file, keeps its
    ``eui_gas``/``month``/``year`` columns, and inner-joins them on
    ``(year, month)`` with column ``s`` of
    ``csv_FY/weather/weatherData_meanTemp.csv``.

    Writes the joined table to ``csv_FY/testWeather/test_temp.csv`` and saves
    two 150-dpi PNGs to ``csv_FY/testWeather/plot/``:

    * ``scatter_temp_byyear.png`` -- one scatter panel per year;
    * ``scatter_temp_1314.png`` -- a single scatter of the 2013-2014 rows.

    All paths are relative to the current working directory.
    """
    sns.set_context("paper", font_scale=1.5)
    b = "AZ0000FF"  # file-name prefix of the building's CSV files
    s = "KTUS"  # temperature column (x axis) in the weather table
    col = "eui_gas"  # energy column (y axis)
    base = os.path.join(os.getcwd(), "csv_FY")
    test_dir = os.path.join(base, "testWeather")

    filelist = glob.glob(os.path.join(test_dir, "{0}*.csv".format(b)))
    df3 = pd.concat([pd.read_csv(csv)[[col, "month", "year"]] for csv in filelist])

    temp = pd.read_csv(os.path.join(base, "weather", "weatherData_meanTemp.csv"))
    # The first (unnamed) column holds dates; slice out year and month as floats.
    temp["year"] = temp["Unnamed: 0"].map(lambda x: float(x[:4]))
    temp["month"] = temp["Unnamed: 0"].map(lambda x: float(x[5:7]))
    temp.set_index(pd.DatetimeIndex(temp["Unnamed: 0"]), inplace=True)
    temp = temp[[s, "month", "year"]]
    joint2 = pd.merge(df3, temp, on=["year", "month"], how="inner")
    joint2.to_csv(os.path.join(test_dir, "test_temp.csv"), index=False)

    def _set_limits(frame):
        # Pad the temperature axis by 10 on each side; start the EUI axis at 0.
        plt.xlim((frame[s].min() - 10, frame[s].max() + 10))
        plt.ylim((0, frame[col].max() + 0.1))

    # x/y are passed as keywords: seaborn >= 0.12 rejects them positionally.
    # lmplot facets share x/y by default, so limits on the current axes apply
    # to every panel.
    sns.lmplot(x=s, y=col, data=joint2, col="year", fit_reg=False)
    _set_limits(joint2)
    plt.savefig(os.path.join(test_dir, "plot", "scatter_temp_byyear.png"), dpi=150)
    plt.close()

    recent = joint2[(2012 < joint2["year"]) & (joint2["year"] < 2015)]
    sns.regplot(x=s, y=col, data=recent, fit_reg=False)
    _set_limits(recent)
    plt.savefig(os.path.join(test_dir, "plot", "scatter_temp_1314.png"), dpi=150)
    plt.close()
Пример #28
0
def best_model_accuracy_bars(df, fig_path, metric, context):
    '''
    Draw a horizontal bar chart of the best score of each model and save it.

    df: data frame of results, handed to identify_best_of_each_model
        (presumably one row per model/configuration -- confirm with that helper).
    fig_path: output path for the figure (saved with bbox_inches='tight').
    metric: score used to pick each model's best result; 'test_accuracy' and
        'best_score' get the titles 'Test' / 'CV', anything else is shown as-is.
    context: paper, talk, notebook, poster (seaborn context; also sets the
        size of the value labels).
    '''
    font_size = {'paper': 8, 'poster': 16, 'notebook': 10, 'talk': 13}
    pretty_metric = {'test_accuracy': 'Test', 'best_score': 'CV'}

    sns.set_context(context)
    model_list, score_list = identify_best_of_each_model(df, metric)

    # One bar per model, drawn top-down in the order returned by the helper.
    bar_pos = np.arange(len(model_list))

    fig = plt.figure()
    plt.barh(bar_pos, score_list, align='center', alpha=0.4)
    plt.yticks(bar_pos, model_list)
    plt.xticks([], [])  # no x-axis; each value is printed next to its bar

    # Data labels just right of the bar end.
    for score, pos in zip(score_list, bar_pos):
        plt.text(score + 0.01, pos, '%.2f' % score, ha='left', va='center',
                 fontsize=font_size[context])

    plt.title('Optimized %s Accuracy of Each Model' % pretty_metric.get(metric, metric))
    fig.savefig(fig_path, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
Пример #29
0
	def graphMetricDistn(self,metric,normType,plotType,resiType,save):
		"""Plot the per-atom distribution of a density metric for every dataset.

		metric, normType: keys selecting atm.densMetric[metric][normType]['values'],
			which is indexed by dataset number.
		plotType: 'hist' or 'kde', forwarded to self.plotHist; any other value
			returns an error message without plotting.
		resiType: 'all' to pool every atom, or a list of residue types matched
			against atm.basetype, giving one curve per (dataset, residue type).
		save: falsy shows the figure interactively; truthy saves it to
			'<outputDir><normType>_D<metric>_<resiType>.png' and closes it.

		Returns the error message for an unknown plotType, otherwise None
		(also when checkMetricPresent returns False).
		"""
		if plotType not in ('hist','kde'):
			return 'Unknown plotting type selected.. cannot plot..'
		if self.checkMetricPresent(self.atomList[0],metric,normType) is False:
			return # metric not available (checked on the first atom); nothing to plot

		sns.set_palette("deep", desat=.6)
		# Size the figure directly: seaborn's set_context(rc=...) only applies
		# plotting-context keys, so a "figure.figsize" entry there is ignored.
		fig = plt.figure(figsize=(10, 6))

		for i in range(self.getNumDatasets()):
			if resiType == 'all':
				datax = [atm.densMetric[metric][normType]['values'][i] for atm in self.atomList]
				self.plotHist(plotType,datax,'Dataset {}'.format(i))
			else:
				for res in resiType:
					datax = [atm.densMetric[metric][normType]['values'][i] for atm in self.atomList if atm.basetype == res]
					self.plotHist(plotType,datax,'Dataset {},{}'.format(i,res))

		plt.legend()
		plt.xlabel('{} D{} per atom'.format(normType,metric))
		plt.ylabel('Frequency')
		plt.title('{} D{} per atom, residues: {}'.format(normType,metric,resiType))
		if not save:
			plt.show()
		else:
			fig.savefig('{}{}_D{}_{}.png'.format(self.outputDir,normType,metric,resiType))
			# Release the figure; repeated saves would otherwise accumulate open figures.
			plt.close(fig)
Пример #30
0
def white():
    """
    Set plot aesthetics for 1D plots with white backgrounds.

    Applies seaborn's 'talk' context and 'whitegrid' style, overriding:
    * a near-black ('0.1' grey) outer border, 0.4 wide
    * a framed legend
    * outward-pointing major x and y ticks of length 5, coloured '.15';
      minor ticks are hidden (length 0)
    * grid lines and ticks drawn below other plot elements (axes.axisbelow)
    * a sans-serif font, preferring Helvetica, then Arial, then Verdana

    Minor gridlines are not configured here.
    """
    sns.set_context('talk')
    sns.set_style('whitegrid', {'axes.edgecolor': '0.1',
                                'legend.frameon': True,
                                'xtick.color': '.15',
                                'xtick.major.size': 5,
                                'xtick.minor.size': 0.0,
                                'xtick.direction': 'out',
                                'ytick.color': '.15',
                                'ytick.major.size': 5,
                                'ytick.minor.size': 0.0,
                                'ytick.direction': 'out',
                                'axes.axisbelow': True,
                                'axes.linewidth': 0.4,
                                'font.family': 'sans-serif',
                                'font.sans-serif': ['Helvetica', 'Arial',
                                                    'Verdana', 'sans-serif']
                                })
Пример #31
0
# Prevent pandas from truncating printed output
pd.set_option('display.max_columns', 1000)
pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', 1000)

df = pd.read_csv("./data/HR_.csv")
df = df.dropna(axis=0, how="any")
# Keep evaluations <= 1 and drop rows whose salary is "nme" or department is
# "sale" (presumably bad values in this dataset -- confirm). All masks are built
# on the same frame and combined in one selection; chaining df[m1][m2][m3]
# relied on index re-alignment and made pandas warn
# "Boolean Series key will be reindexed to match DataFrame index".
df = df[(df["last_evaluation"] <= 1)
        & (df["salary"] != "nme")
        & (df["department"] != "sale")]

import seaborn as sns
import matplotlib.pyplot as plt

# seaborn styling
sns.set_style(style="whitegrid")
sns.set_context(context="poster", font_scale=0.5)
sns.set_palette("summer")

f = plt.figure()
# add_subplot(numRows, numCols, plotNum): the figure is split into a
# numRows x numCols grid whose cells are numbered left-to-right, top-to-bottom.
# One distribution plot per column (in distplot, kde=False would omit the
# density curve and hist=False would omit the histogram).
for plot_num, column in enumerate(
        ["satisfaction_level", "last_evaluation", "average_monthly_hours"],
        start=1):
    f.add_subplot(1, 3, plot_num)
    sns.distplot(df[column], bins=10)
plt.show()
import numpy as np
import matplotlib.pyplot as pyplot
import pandas as pd
import seaborn as sns

data = pd.read_csv('browsertime.csv')
result = data  # alias, not a copy: the edits below also apply to `data`
# Assign the replaced column back instead of calling
# result['protocol'].replace(..., inplace=True): under pandas copy-on-write that
# in-place call only modifies a temporary and leaves `result` unchanged.
result['protocol'] = result['protocol'].replace(["HTTP1.1/TLS", "HTTP2"],
                                                ["H1s", "H2"])
result["pageLoadTime"] = result["pageLoadTime"] / 1000

# sns.set() resets the context to "notebook", so the sns.set_context('paper')
# call previously made before it was always overridden. Apply the effective
# settings once rather than on every loop iteration.
sns.set(font_scale=2.2)

operator_order = [
    "Telia (SE)", "Telenor (SE)", "Tre (SE)",
    "Telenor (NO)", "Telia (NO)", "ICE (NO)", "TIM (IT)",
    "Vodafone (IT)", "Wind (IT)", "Orange (ES)",
    "Yoigo (ES)"
]

brs = result.browser.unique()
for br in brs:
    # Close the previous browser's figure (clf() only emptied it and left it open).
    pyplot.close()
    pyplot.figure(figsize=(18, 10))
    sns.boxplot(x="Ops",
                y="pageLoadTime",
                hue="protocol",
                data=result[result.browser == br],
                order=operator_order,
                hue_order=["H1s", "H2"],
                palette="muted",
                showfliers=False)
    pyplot.xticks(rotation=10)
Пример #33
0
import re
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.width', 200)
import glob
from matplotlib import pyplot as plt
from matplotlib import ticker as ticker
import seaborn as sns
sns.set_style('white', {
    'axes.grid': True,
    'xtick.bottom': True,
    'ytick.left': True
})
sns.set_context(rc={'patch.linewidth': '0.0'})
sns.set_palette(
    sns.color_palette(['#ce0e2d', '#005cb9', '#f5a800', '#45c2b1', '#035c67']))

files = sorted(glob.glob('*_coverage_hist.tsv'))

# Sample names: strip the longer '_dmarked_coverage_hist.tsv' suffix first --
# once '_coverage_hist.tsv' has been removed it can no longer match.
names = [i.replace('_dmarked_coverage_hist.tsv', '') for i in files]
names = [i.replace('_coverage_hist.tsv', '') for i in names]
print(names)
print('')

# Group names: drop everything from the first '_L' (presumably the lane tag --
# confirm against the file naming) through the file suffix. The pattern must
# include '_hist': '_coverage.tsv' never matches the globbed file names.
groups = [re.sub(r'_L.*_coverage_hist\.tsv$', '', i) for i in files]
print(groups)
print('')

# colnames = ['metric1', 'TOTAL_READS', 'metric3', 'metric4', 'metric5', 'metric6', 'PCT_PF_READS_ALIGNED', 'metric8', 'metric9', 'metric10', 'metric11', 'metric12', 'PF_MISMATCH_RATE', 'PF_HQ_ERROR_RATE', 'PF_INDEL_RATE', 'metric16', 'metric17', 'metric18', 'metric19', 'metric20', 'metric21', 'metric22', 'PCT_CHIMERAS', 'PCT_ADAPTER', 'metric25', 'metric26', 'metric27']
Пример #34
0
def main():
    sns.set_style('white')
    sns.set_context('poster')
    parser = argparse.ArgumentParser(
        description='%s Parameters' % __tool_name__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-m",
                        "--matrix",
                        dest="input_filename",
                        default=None,
                        help="input file name",
                        metavar="FILE")
    parser.add_argument("-l",
                        "--cell_labels",
                        dest="cell_label_filename",
                        default=None,
                        help="filename of cell labels")
    parser.add_argument("-c",
                        "--cell_labels_colors",
                        dest="cell_label_color_filename",
                        default=None,
                        help="filename of cell label colors")
    parser.add_argument(
        "-s",
        "--select_features",
        dest="s_method",
        default='LOESS',
        help=
        "LOESS,PCA or all: Select variable genes using LOESS or principal components using PCA or all the genes are kept"
    )
    parser.add_argument("--TG",
                        "--detect_TG_genes",
                        dest="flag_gene_TG_detection",
                        action="store_true",
                        help="detect transition genes automatically")
    parser.add_argument("--DE",
                        "--detect_DE_genes",
                        dest="flag_gene_DE_detection",
                        action="store_true",
                        help="detect DE genes automatically")
    parser.add_argument("--LG",
                        "--detect_LG_genes",
                        dest="flag_gene_LG_detection",
                        action="store_true",
                        help="detect leaf genes automatically")
    parser.add_argument(
        "-g",
        "--genes",
        dest="genes",
        default=None,
        help=
        "genes to visualize, it can either be filename which contains all the genes in one column or a set of gene names separated by comma"
    )
    parser.add_argument(
        "-p",
        "--use_precomputed",
        dest="use_precomputed",
        action="store_true",
        help=
        "use precomputed data files without re-computing structure learning part"
    )
    parser.add_argument("--new",
                        dest="new_filename",
                        default=None,
                        help="file name of data to be mapped")
    parser.add_argument("--new_l",
                        dest="new_label_filename",
                        default=None,
                        help="filename of new cell labels")
    parser.add_argument("--new_c",
                        dest="new_label_color_filename",
                        default=None,
                        help="filename of new cell label colors")
    parser.add_argument("--log2",
                        dest="flag_log2",
                        action="store_true",
                        help="perform log2 transformation")
    parser.add_argument("--norm",
                        dest="flag_norm",
                        action="store_true",
                        help="normalize data based on library size")
    parser.add_argument("--atac",
                        dest="flag_atac",
                        action="store_true",
                        help="indicate scATAC-seq data")
    parser.add_argument(
        "--n_processes",
        dest="n_processes",
        type=int,
        default=multiprocessing.cpu_count(),
        help=
        "Specify the number of processes to use. (default, all the available cores)"
    )
    parser.add_argument(
        "--loess_frac",
        dest="loess_frac",
        type=float,
        default=0.1,
        help="The fraction of the data used in LOESS regression")
    parser.add_argument(
        "--loess_cutoff",
        dest="loess_cutoff",
        type=int,
        default=95,
        help=
        "the percentile used in variable gene selection based on LOESS regression"
    )
    parser.add_argument("--pca_first_PC",
                        dest="flag_first_PC",
                        action="store_true",
                        help="keep first PC")
    parser.add_argument("--pca_n_PC",
                        dest="pca_n_PC",
                        type=int,
                        default=15,
                        help="The number of selected PCs,it's 15 by default")
    parser.add_argument("--lle_neighbours",
                        dest="lle_n_nb_percent",
                        type=float,
                        default=0.1,
                        help="LLE neighbour percent ")
    parser.add_argument("--lle_components",
                        dest="lle_n_component",
                        type=int,
                        default=3,
                        help="number of components for LLE space ")
    parser.add_argument(
        "--clustering",
        dest="clustering",
        default='kmeans',
        help=
        "Clustering method used for seeding the intial structure, choose from 'ap','kmeans','sc'"
    )
    parser.add_argument("--damping",
                        dest="damping",
                        type=float,
                        default=0.75,
                        help="Affinity Propagation: damping factor")
    parser.add_argument(
        "--n_clusters",
        dest="n_clusters",
        type=int,
        default=10,
        help="Number of clusters for spectral clustering or kmeans")
    parser.add_argument("--EPG_n_nodes",
                        dest="EPG_n_nodes",
                        type=int,
                        default=50,
                        help=" Number of nodes for elastic principal graph")
    parser.add_argument(
        "--EPG_lambda",
        dest="EPG_lambda",
        type=float,
        default=0.02,
        help="lambda parameter used to compute the elastic energy")
    parser.add_argument("--EPG_mu",
                        dest="EPG_mu",
                        type=float,
                        default=0.1,
                        help="mu parameter used to compute the elastic energy")
    parser.add_argument(
        "--EPG_trimmingradius",
        dest="EPG_trimmingradius",
        type=float,
        default=np.inf,
        help="maximal distance of point from a node to affect its embedment")
    parser.add_argument(
        "--EPG_alpha",
        dest="EPG_alpha",
        type=float,
        default=0.02,
        help=
        "positive numeric, the value of the alpha parameter of the penalized elastic energy"
    )
    parser.add_argument("--EPG_collapse",
                        dest="flag_EPG_collapse",
                        action="store_true",
                        help="collapsing small branches")
    parser.add_argument(
        "--EPG_collapse_mode",
        dest="EPG_collapse_mode",
        default="PointNumber",
        help=
        "the mode used to collapse branches. PointNumber,PointNumber_Extrema, PointNumber_Leaves,EdgesNumber or EdgesLength"
    )
    parser.add_argument(
        "--EPG_collapse_par",
        dest="EPG_collapse_par",
        type=float,
        default=5,
        help=
        "positive numeric, the cotrol paramter used for collapsing small branches"
    )
    parser.add_argument("--disable_EPG_optimize",
                        dest="flag_disable_EPG_optimize",
                        action="store_true",
                        help="disable optimizing branching")
    parser.add_argument("--EPG_shift",
                        dest="flag_EPG_shift",
                        action="store_true",
                        help="shift branching point ")
    parser.add_argument(
        "--EPG_shift_mode",
        dest="EPG_shift_mode",
        default='NodeDensity',
        help=
        "the mode to use to shift the branching points NodePoints or NodeDensity"
    )
    parser.add_argument(
        "--EPG_shift_DR",
        dest="EPG_shift_DR",
        type=float,
        default=0.05,
        help=
        "positive numeric, the radius to be used when computing point density if EPG_shift_mode is NodeDensity"
    )
    parser.add_argument(
        "--EPG_shift_maxshift",
        dest="EPG_shift_maxshift",
        type=int,
        default=5,
        help=
        "positive integer, the maxium distance (as number of edges) to consider when exploring the branching point neighborhood"
    )
    parser.add_argument("--disable_EPG_ext",
                        dest="flag_disable_EPG_ext",
                        action="store_true",
                        help="disable extending leaves with additional nodes")
    parser.add_argument(
        "--EPG_ext_mode",
        dest="EPG_ext_mode",
        default='QuantDists',
        help=
        " the mode used to extend the graph,QuantDists, QuantCentroid or WeigthedCentroid"
    )
    parser.add_argument(
        "--EPG_ext_par",
        dest="EPG_ext_par",
        type=float,
        default=0.5,
        help=
        "the control parameter used for contribution of the different data points when extending leaves with nodes"
    )
    parser.add_argument("--DE_zscore_cutoff",
                        dest="DE_zscore_cutoff",
                        default=2,
                        help="Differentially Expressed Genes z-score cutoff")
    parser.add_argument(
        "--DE_logfc_cutoff",
        dest="DE_logfc_cutoff",
        default=0.25,
        help="Differentially Expressed Genes log fold change cutoff")
    parser.add_argument("--TG_spearman_cutoff",
                        dest="TG_spearman_cutoff",
                        default=0.4,
                        help="Transition Genes Spearman correlation cutoff")
    parser.add_argument("--TG_logfc_cutoff",
                        dest="TG_logfc_cutoff",
                        default=0.25,
                        help="Transition Genes log fold change cutoff")
    parser.add_argument("--LG_zscore_cutoff",
                        dest="LG_zscore_cutoff",
                        default=1.5,
                        help="Leaf Genes z-score cutoff")
    parser.add_argument("--LG_pvalue_cutoff",
                        dest="LG_pvalue_cutoff",
                        default=1e-2,
                        help="Leaf Genes p value cutoff")
    parser.add_argument(
        "--umap",
        dest="flag_umap",
        action="store_true",
        help="whether to use UMAP for visualization (default: No)")
    parser.add_argument("-r",
                        dest="root",
                        default=None,
                        help="root node for subwaymap_plot and stream_plot")
    parser.add_argument("--stream_log_view",
                        dest="flag_stream_log_view",
                        action="store_true",
                        help="use log2 scale for y axis of stream_plot")
    parser.add_argument("-o",
                        "--output_folder",
                        dest="output_folder",
                        default=None,
                        help="Output folder")
    parser.add_argument("--for_web",
                        dest="flag_web",
                        action="store_true",
                        help="Output files for website")
    parser.add_argument(
        "--n_genes",
        dest="n_genes",
        type=int,
        default=5,
        help=
        "Number of top genes selected from each output marker gene file for website gene visualization"
    )

    args = parser.parse_args()
    if (args.input_filename is None) and (args.new_filename is None):
        parser.error("at least one of -m, --new required")

    new_filename = args.new_filename
    new_label_filename = args.new_label_filename
    new_label_color_filename = args.new_label_color_filename
    flag_stream_log_view = args.flag_stream_log_view
    flag_gene_TG_detection = args.flag_gene_TG_detection
    flag_gene_DE_detection = args.flag_gene_DE_detection
    flag_gene_LG_detection = args.flag_gene_LG_detection
    flag_web = args.flag_web
    flag_first_PC = args.flag_first_PC
    flag_umap = args.flag_umap
    genes = args.genes
    DE_zscore_cutoff = args.DE_zscore_cutoff
    DE_logfc_cutoff = args.DE_logfc_cutoff
    TG_spearman_cutoff = args.TG_spearman_cutoff
    TG_logfc_cutoff = args.TG_logfc_cutoff
    LG_zscore_cutoff = args.LG_zscore_cutoff
    LG_pvalue_cutoff = args.LG_pvalue_cutoff
    root = args.root

    input_filename = args.input_filename
    cell_label_filename = args.cell_label_filename
    cell_label_color_filename = args.cell_label_color_filename
    s_method = args.s_method
    use_precomputed = args.use_precomputed
    n_processes = args.n_processes
    loess_frac = args.loess_frac
    loess_cutoff = args.loess_cutoff
    pca_n_PC = args.pca_n_PC
    flag_log2 = args.flag_log2
    flag_norm = args.flag_norm
    flag_atac = args.flag_atac
    lle_n_nb_percent = args.lle_n_nb_percent  #LLE neighbour percent
    lle_n_component = args.lle_n_component  #LLE dimension reduction
    clustering = args.clustering
    damping = args.damping
    n_clusters = args.n_clusters
    EPG_n_nodes = args.EPG_n_nodes
    EPG_lambda = args.EPG_lambda
    EPG_mu = args.EPG_mu
    EPG_trimmingradius = args.EPG_trimmingradius
    EPG_alpha = args.EPG_alpha
    flag_EPG_collapse = args.flag_EPG_collapse
    EPG_collapse_mode = args.EPG_collapse_mode
    EPG_collapse_par = args.EPG_collapse_par
    flag_EPG_shift = args.flag_EPG_shift
    EPG_shift_mode = args.EPG_shift_mode
    EPG_shift_DR = args.EPG_shift_DR
    EPG_shift_maxshift = args.EPG_shift_maxshift
    flag_disable_EPG_optimize = args.flag_disable_EPG_optimize
    flag_disable_EPG_ext = args.flag_disable_EPG_ext
    EPG_ext_mode = args.EPG_ext_mode
    EPG_ext_par = args.EPG_ext_par
    output_folder = args.output_folder  #work directory
    n_genes = args.n_genes

    if (flag_web):
        flag_savefig = False
    else:
        flag_savefig = True
    gene_list = []
    if (genes != None):
        if (os.path.exists(genes)):
            gene_list = pd.read_csv(genes,
                                    sep='\t',
                                    header=None,
                                    index_col=None,
                                    compression='gzip' if genes.split('.')[-1]
                                    == 'gz' else None).iloc[:, 0].tolist()
            gene_list = list(set(gene_list))
        else:
            gene_list = genes.split(',')
        print('Genes to visualize: ')
        print(gene_list)
    if (new_filename is None):
        if (output_folder == None):
            workdir = os.path.join(os.getcwd(), 'stream_result')
        else:
            workdir = output_folder
        if (use_precomputed):
            print('Importing the precomputed pkl file...')
            adata = st.read(file_name='stream_result.pkl',
                            file_format='pkl',
                            file_path=workdir,
                            workdir=workdir)
        else:
            if (flag_atac):
                print('Reading in atac zscore matrix...')
                adata = st.read(file_name=input_filename,
                                workdir=workdir,
                                experiment='atac-seq')
            else:
                adata = st.read(file_name=input_filename, workdir=workdir)
                print('Input: ' + str(adata.obs.shape[0]) + ' cells, ' +
                      str(adata.var.shape[0]) + ' genes')
            adata.var_names_make_unique()
            adata.obs_names_make_unique()
            if (cell_label_filename != None):
                st.add_cell_labels(adata, file_name=cell_label_filename)
            else:
                st.add_cell_labels(adata)
            if (cell_label_color_filename != None):
                st.add_cell_colors(adata, file_name=cell_label_color_filename)
            else:
                st.add_cell_colors(adata)
            if (flag_atac):
                print('Selecting top principal components...')
                st.select_top_principal_components(adata,
                                                   n_pc=pca_n_PC,
                                                   first_pc=flag_first_PC,
                                                   save_fig=True)
                st.dimension_reduction(adata,
                                       n_components=lle_n_component,
                                       nb_pct=lle_n_nb_percent,
                                       n_jobs=n_processes,
                                       feature='top_pcs')
            else:
                if (flag_norm):
                    st.normalize_per_cell(adata)
                if (flag_log2):
                    st.log_transform(adata)
                if (s_method != 'all'):
                    print('Filtering genes...')
                    st.filter_genes(adata, min_num_cells=5)
                    print('Removing mitochondrial genes...')
                    st.remove_mt_genes(adata)
                    if (s_method == 'LOESS'):
                        print('Selecting most variable genes...')
                        st.select_variable_genes(adata,
                                                 loess_frac=loess_frac,
                                                 percentile=loess_cutoff,
                                                 save_fig=True)
                        pd.DataFrame(adata.uns['var_genes']).to_csv(
                            os.path.join(workdir,
                                         'selected_variable_genes.tsv'),
                            sep='\t',
                            index=None,
                            header=False)
                        st.dimension_reduction(adata,
                                               n_components=lle_n_component,
                                               nb_pct=lle_n_nb_percent,
                                               n_jobs=n_processes,
                                               feature='var_genes')
                    if (s_method == 'PCA'):
                        print('Selecting top principal components...')
                        st.select_top_principal_components(
                            adata,
                            n_pc=pca_n_PC,
                            first_pc=flag_first_PC,
                            save_fig=True)
                        st.dimension_reduction(adata,
                                               n_components=lle_n_component,
                                               nb_pct=lle_n_nb_percent,
                                               n_jobs=n_processes,
                                               feature='top_pcs')
                else:
                    print('Keep all the genes...')
                    st.dimension_reduction(adata,
                                           n_components=lle_n_component,
                                           nb_pct=lle_n_nb_percent,
                                           n_jobs=n_processes,
                                           feature='all')
            st.plot_dimension_reduction(adata, save_fig=flag_savefig)
            st.seed_elastic_principal_graph(adata,
                                            clustering=clustering,
                                            damping=damping,
                                            n_clusters=n_clusters)
            st.plot_branches(
                adata,
                save_fig=flag_savefig,
                fig_name='seed_elastic_principal_graph_skeleton.pdf')
            st.plot_branches_with_cells(
                adata,
                save_fig=flag_savefig,
                fig_name='seed_elastic_principal_graph.pdf')

            st.elastic_principal_graph(adata,
                                       epg_n_nodes=EPG_n_nodes,
                                       epg_lambda=EPG_lambda,
                                       epg_mu=EPG_mu,
                                       epg_trimmingradius=EPG_trimmingradius,
                                       epg_alpha=EPG_alpha)
            st.plot_branches(adata,
                             save_fig=flag_savefig,
                             fig_name='elastic_principal_graph_skeleton.pdf')
            st.plot_branches_with_cells(adata,
                                        save_fig=flag_savefig,
                                        fig_name='elastic_principal_graph.pdf')
            if (not flag_disable_EPG_optimize):
                st.optimize_branching(adata,
                                      epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='optimizing_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='optimizing_elastic_principal_graph.pdf')
            if (flag_EPG_shift):
                st.shift_branching(adata,
                                   epg_shift_mode=EPG_shift_mode,
                                   epg_shift_radius=EPG_shift_DR,
                                   epg_shift_max=EPG_shift_maxshift,
                                   epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='shifting_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='shifting_elastic_principal_graph.pdf')
            if (flag_EPG_collapse):
                st.prune_elastic_principal_graph(
                    adata,
                    epg_collapse_mode=EPG_collapse_mode,
                    epg_collapse_par=EPG_collapse_par,
                    epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='pruning_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='pruning_elastic_principal_graph.pdf')
            if (not flag_disable_EPG_ext):
                st.extend_elastic_principal_graph(
                    adata,
                    epg_ext_mode=EPG_ext_mode,
                    epg_ext_par=EPG_ext_par,
                    epg_trimmingradius=EPG_trimmingradius)
                st.plot_branches(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='extending_elastic_principal_graph_skeleton.pdf')
                st.plot_branches_with_cells(
                    adata,
                    save_fig=flag_savefig,
                    fig_name='extending_elastic_principal_graph.pdf')
            st.plot_branches(
                adata,
                save_fig=flag_savefig,
                fig_name='finalized_elastic_principal_graph_skeleton.pdf')
            st.plot_branches_with_cells(
                adata,
                save_fig=flag_savefig,
                fig_name='finalized_elastic_principal_graph.pdf')
            st.plot_flat_tree(adata, save_fig=flag_savefig)
            if (flag_umap):
                print('UMAP visualization based on top MLLE components...')
                st.plot_visualization_2D(adata,
                                         save_fig=flag_savefig,
                                         fig_name='umap_cells')
                st.plot_visualization_2D(adata,
                                         color_by='branch',
                                         save_fig=flag_savefig,
                                         fig_name='umap_branches')
            if (root is None):
                print('Visualization of subwaymap and stream plots...')
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                for ns in list_node_start:
                    if (flag_web):
                        st.subwaymap_plot(adata,
                                          percentile_dist=100,
                                          root=ns,
                                          save_fig=flag_savefig)
                        st.stream_plot(adata,
                                       root=ns,
                                       fig_size=(8, 8),
                                       save_fig=True,
                                       flag_log_view=flag_stream_log_view,
                                       fig_legend=False,
                                       fig_name='stream_plot.png')
                    else:
                        st.subwaymap_plot(adata,
                                          percentile_dist=100,
                                          root=ns,
                                          save_fig=flag_savefig)
                        st.stream_plot(adata,
                                       root=ns,
                                       fig_size=(8, 8),
                                       save_fig=flag_savefig,
                                       flag_log_view=flag_stream_log_view)
            else:
                st.subwaymap_plot(adata,
                                  percentile_dist=100,
                                  root=root,
                                  save_fig=flag_savefig)
                st.stream_plot(adata,
                               root=root,
                               fig_size=(8, 8),
                               save_fig=flag_savefig,
                               flag_log_view=flag_stream_log_view)
            output_cell_info(adata)
            if (flag_web):
                output_for_website(adata)
            st.write(adata)

        if (flag_gene_TG_detection):
            print('Identifying transition genes...')
            st.detect_transistion_genes(adata,
                                        cutoff_spearman=TG_spearman_cutoff,
                                        cutoff_logfc=TG_logfc_cutoff,
                                        n_jobs=n_processes)
            if (flag_web):
                ## Plot top5 genes
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                gene_list = []
                for x in adata.uns['transition_genes'].keys():
                    gene_list = gene_list + adata.uns['transition_genes'][
                        x].index[:n_genes].tolist()
                gene_list = np.unique(gene_list)
                for ns in list_node_start:
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=ns,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')
            else:
                st.plot_transition_genes(adata, save_fig=flag_savefig)

        if (flag_gene_DE_detection):
            print('Identifying differentially expressed genes...')
            st.detect_de_genes(adata,
                               cutoff_zscore=DE_logfc_cutoff,
                               cutoff_logfc=DE_logfc_cutoff,
                               n_jobs=n_processes)
            if (flag_web):
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                gene_list = []
                for x in adata.uns['de_genes_greater'].keys():
                    gene_list = gene_list + adata.uns['de_genes_greater'][
                        x].index[:n_genes].tolist()
                for x in adata.uns['de_genes_less'].keys():
                    gene_list = gene_list + adata.uns['de_genes_less'][
                        x].index[:n_genes].tolist()
                gene_list = np.unique(gene_list)
                for ns in list_node_start:
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=ns,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')
            else:
                st.plot_de_genes(adata, save_fig=flag_savefig)

        if (flag_gene_LG_detection):
            print('Identifying leaf genes...')
            st.detect_leaf_genes(adata,
                                 cutoff_zscore=LG_zscore_cutoff,
                                 cutoff_pvalue=LG_pvalue_cutoff,
                                 n_jobs=n_processes)
            if (flag_web):
                ## Plot top5 genes
                flat_tree = adata.uns['flat_tree']
                list_node_start = [
                    value for key, value in nx.get_node_attributes(
                        flat_tree, 'label').items()
                ]
                gene_list = []
                for x in adata.uns['leaf_genes'].keys():
                    gene_list = gene_list + adata.uns['leaf_genes'][
                        x].index[:n_genes].tolist()
                gene_list = np.unique(gene_list)
                for ns in list_node_start:
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=ns,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')

        if ((genes != None) and (len(gene_list) > 0)):
            print('Visualizing genes...')
            flat_tree = adata.uns['flat_tree']
            list_node_start = [
                value for key, value in nx.get_node_attributes(
                    flat_tree, 'label').items()
            ]
            if (root is None):
                for ns in list_node_start:
                    if (flag_web):
                        output_for_website_subwaymap_gene(adata, gene_list)
                        st.stream_plot_gene(adata,
                                            root=ns,
                                            fig_size=(8, 8),
                                            genes=gene_list,
                                            save_fig=True,
                                            flag_log_view=flag_stream_log_view,
                                            fig_format='png')
                    else:
                        st.subwaymap_plot_gene(adata,
                                               percentile_dist=100,
                                               root=ns,
                                               genes=gene_list,
                                               save_fig=flag_savefig)
                        st.stream_plot_gene(adata,
                                            root=ns,
                                            fig_size=(8, 8),
                                            genes=gene_list,
                                            save_fig=flag_savefig,
                                            flag_log_view=flag_stream_log_view)
            else:
                if (flag_web):
                    output_for_website_subwaymap_gene(adata, gene_list)
                    st.stream_plot_gene(adata,
                                        root=root,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=True,
                                        flag_log_view=flag_stream_log_view,
                                        fig_format='png')
                else:
                    st.subwaymap_plot_gene(adata,
                                           percentile_dist=100,
                                           root=root,
                                           genes=gene_list,
                                           save_fig=flag_savefig)
                    st.stream_plot_gene(adata,
                                        root=root,
                                        fig_size=(8, 8),
                                        genes=gene_list,
                                        save_fig=flag_savefig,
                                        flag_log_view=flag_stream_log_view)

    else:
        print('Starting mapping procedure...')
        if (output_folder == None):
            workdir_ref = os.path.join(os.getcwd(), 'stream_result')
        else:
            workdir_ref = output_folder
        adata = st.read(file_name='stream_result.pkl',
                        file_format='pkl',
                        file_path=workdir_ref,
                        workdir=workdir_ref)
        workdir = os.path.join(workdir_ref, os.pardir, 'mapping_result')
        adata_new = st.read(file_name=new_filename, workdir=workdir)
        st.add_cell_labels(adata_new, file_name=new_label_filename)
        st.add_cell_colors(adata_new, file_name=new_label_color_filename)
        if (s_method == 'LOESS'):
            st.map_new_data(adata, adata_new, feature='var_genes')
        if (s_method == 'all'):
            st.map_new_data(adata, adata_new, feature='all')
        if (flag_umap):
            st.plot_visualization_2D(adata,
                                     adata_new=adata_new,
                                     use_precomputed=False,
                                     save_fig=flag_savefig,
                                     fig_name='umap_new_cells')
            st.plot_visualization_2D(adata,
                                     adata_new=adata_new,
                                     show_all_colors=True,
                                     save_fig=flag_savefig,
                                     fig_name='umap_all_cells')
            st.plot_visualization_2D(adata,
                                     adata_new=adata_new,
                                     color_by='branch',
                                     save_fig=flag_savefig,
                                     fig_name='umap_branches')
        if (root is None):
            flat_tree = adata.uns['flat_tree']
            list_node_start = [
                value for key, value in nx.get_node_attributes(
                    flat_tree, 'label').items()
            ]
            for ns in list_node_start:
                st.subwaymap_plot(adata,
                                  adata_new=adata_new,
                                  percentile_dist=100,
                                  show_all_cells=False,
                                  root=ns,
                                  save_fig=flag_savefig)
                st.stream_plot(adata,
                               adata_new=adata_new,
                               show_all_colors=False,
                               root=ns,
                               fig_size=(8, 8),
                               save_fig=flag_savefig,
                               flag_log_view=flag_stream_log_view)
        else:
            st.subwaymap_plot(adata,
                              adata_new=adata_new,
                              percentile_dist=100,
                              show_all_cells=False,
                              root=root,
                              save_fig=flag_savefig)
            st.stream_plot(adata,
                           adata_new=adata_new,
                           show_all_colors=False,
                           root=root,
                           fig_size=(8, 8),
                           save_fig=flag_savefig,
                           flag_log_view=flag_stream_log_view)
        if ((genes != None) and (len(gene_list) > 0)):
            if (root is None):
                for ns in list_node_start:
                    st.subwaymap_plot_gene(adata,
                                           adata_new=adata_new,
                                           percentile_dist=100,
                                           root=ns,
                                           save_fig=flag_savefig,
                                           flag_log_view=flag_stream_log_view)
            else:
                st.subwaymap_plot_gene(adata,
                                       adata_new=adata_new,
                                       percentile_dist=100,
                                       root=root,
                                       save_fig=flag_savefig,
                                       flag_log_view=flag_stream_log_view)
        st.write(adata_new, file_name='stream_mapping_result.pkl')
    print('Finished computation.')
Пример #35
0
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn

# Apply seaborn's "talk" plotting context globally, so every later matplotlib
# figure uses its larger font and line sizes.
seaborn.set_context(context="talk")


class Batch:
    """Object for holding a batch of data with masks during training.

    Attributes set for every batch:
        src: The source token tensor as given.
        src_mask: ``(src != pad)`` with a singleton dimension inserted at -2,
            i.e. True where the source token is not padding.

    Attributes set only when ``trg`` is given:
        trg: Decoder input, ``trg[:, :-1]`` (every token except the last).
        trg_y: Decoder labels, ``trg[:, 1:]`` (every token except the first).
        trg_mask: Combined padding + look-ahead mask from ``make_std_mask``.
        ntokens: 0-dim tensor counting non-padding tokens in ``trg_y``.

    Args:
        src: Source token ids; presumably shaped (batch, src_len) - TODO confirm.
        trg: Optional target token ids, sliced along dim 1 as described above.
        pad: Token id used for padding.
    """

    def __init__(self, src, trg=None, pad=0):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            # Teacher forcing: the decoder sees tokens [0, n-1) and predicts [1, n).
            self.trg = trg[:, :-1]
            self.trg_y = trg[:, 1:]
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        """Create a mask to hide padding and future words.

        Args:
            tgt: Target token ids; the sequence length is ``tgt.size(-1)``.
            pad: Token id used for padding.

        Returns:
            The padding mask ``(tgt != pad).unsqueeze(-2)`` AND-ed with
            ``subsequent_mask(tgt.size(-1))``, cast to the padding mask's dtype.
        """
        tgt_mask = (tgt != pad).unsqueeze(-2)
        # `subsequent_mask` is defined elsewhere in the file; presumably it
        # returns a lower-triangular (1, size, size) mask - TODO confirm.
        # Plain tensors replace the deprecated autograd.Variable / .data usage.
        return tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask)
Пример #36
0
"""
Created on Fri May  8 16:30:26 2020

@author: kevin
"""

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import seaborn as sns
# Named colours from the xkcd colour survey, turned into a seaborn palette.
color_names = ["windows blue", "red", "amber", "faded green"]
colors = sns.xkcd_palette(color_names)
# Global seaborn styling for all subsequent figures: white background,
# "talk" plotting context.
sns.set_style("white")
sns.set_context("talk")


# %% functions
def Threshold(u, D, mu, vv):
    """
    Spiking threshold: return a 0/1 spike vector for N neurons.

    Neuron i spikes when ``u[i] > 0.5 * (||D[:, i]||**2 + mu + vv)``.

    Args:
        u: Per-neuron input; assumed 1-D of length N or a scalar
            (anything that broadcasts against the (N,) thresholds).
        D: 2-D matrix of shape (output_dim, N); column i belongs to neuron i.
        mu: Added to every squared column norm (scalar or length-N array).
        vv: Added to every squared column norm (scalar or length-N array).

    Returns:
        Float array of shape (N,) holding 1.0 for neurons above threshold
        and 0.0 otherwise.
    """
    # The unpack also enforces that D is 2-D; only the column count is needed.
    _, n_neurons = D.shape
    col_norms_sq = np.linalg.norm(D, axis=0) ** 2
    thresholds = 0.5 * (col_norms_sq + mu + vv)
    spikes = np.zeros(n_neurons)
    spikes[u > thresholds] = 1
    return spikes

Пример #37
0
import os
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
from cycler import cycler

# Apply seaborn's default theme to matplotlib's global settings for plot output
sns.set()
sns.set_style("white")
# The four preset contexts, in order of relative size, are paper, notebook, talk, and poster.
# The notebook context is the default; here "poster" is used with fonts scaled a further 1.1x.
sns.set_context("poster", font_scale=1.1)

def importfromjson(path, name):
    """Load ``<path>/<name>.json`` into a pandas DataFrame.

    Args:
        path: Directory that contains the file.
        name: File name without the ``.json`` extension.

    Returns:
        The DataFrame produced by ``pandas.read_json`` for that file.
    """
    return pd.read_json(os.path.join(path, name + '.json'))


def plotkde(df, directory, name, plotextension, grouparg, plotarg):
    """Prepare the output location for a KDE plot of ``plotarg``.

    Prints a progress message, ensures ``<directory>/Plots`` exists, and
    builds the save path ``<directory>/Plots/<name><plotarg><plotextension>``.

    NOTE(review): as visible here the function ends after computing the save
    path. ``df`` and ``grouparg`` are unused and nothing is plotted or saved,
    so the plotting code appears to be missing.

    Args:
        df: Data to plot (unused in the visible body).
        directory: Base output directory.
        name: Prefix of the output file name.
        plotextension: File extension including the dot, e.g. ``'.png'``.
        grouparg: Grouping column (unused in the visible body).
        plotarg: Name of the quantity being plotted.
    """
    # Python 3 print() call; the Python 2 print statement was a SyntaxError.
    print('Plotting kde of %s' % plotarg)

    # Create a saving name format/directory. exist_ok avoids the race between
    # an existence check and the creation.
    savedir = os.path.join(directory, 'Plots')
    os.makedirs(savedir, exist_ok=True)

    # Plot and save figures
    savename = os.path.join(savedir, name + plotarg + plotextension)
Пример #38
0
def visualize_perf_against_hparam(
    hparam_list: List[float],
    hparam_name: str,
    args_or_args_list,
    num_total_runs: int,
    tasks=None,
    plot_individual=True,
    plot_ablation=False,
    xlabel=None,
    ylabel=None,
    use_log_x=True,
):
    """Plot test performance against a swept hyper-parameter.

    For each args object, per-task mean/std lists over ``hparam_list`` come
    from ``get_task_to_mean_and_std_per_against_hparam``. What gets plotted
    depends on the flags:

    * ``plot_individual and not plot_ablation``: one PNG plot per args.
    * ``plot_individual and plot_ablation``: one PDF ablation plot per args
      (sets the seaborn "poster" context, no legend).
    * ``not plot_individual and not plot_ablation``: two aggregated PDF plots
      over all args (sets the "poster" context with font_scale=1.2 before
      the second one).
    * ``not plot_individual and plot_ablation``: nothing is plotted.

    Args:
        hparam_list: Hyper-parameter values of the sweep (the x axis).
        hparam_name: Name of the hyper-parameter; used in the output keys.
        args_or_args_list: A single args object or a list of them. Each must
            expose ``m`` and ``dataset_name``.
        num_total_runs: Forwarded to the statistics helper.
        tasks: Tasks to include; defaults to ``["node", "link"]``.
        plot_individual: Plot each args object separately.
        plot_ablation: Produce ablation-style plots (individual mode only).
        xlabel: Optional x-axis label override.
        ylabel: Optional y-axis label override.
        use_log_x: Plot ``log10`` of the hyper-parameter values on the x axis.
    """
    if isinstance(args_or_args_list, list):
        args_list = args_or_args_list
    else:
        args_list = [args_or_args_list]

    tasks = tasks or ["node", "link"]

    # The x-axis values are identical for every plot below; compute them once.
    x_list = ([float(np.log10(al)) for al in hparam_list]
              if use_log_x else hparam_list)

    # task: node or link
    # dataset: cora, citeseer, pubmed, ppi
    # model: go, dp
    # Both dicts are keyed by (args.m, args.dataset_name, task) and accumulate
    # results over every args object.
    model_data_task_to_mean_list = dict()
    model_data_task_to_std_list = dict()
    for args in args_list:
        custom_key_prefix = "perf_against_{}".format(hparam_name)
        args_key = get_args_key(args)
        custom_key = "{}_{}".format(custom_key_prefix, args_key)

        task_to_mean_list, task_to_std_list = get_task_to_mean_and_std_per_against_hparam(
            hparam_list=hparam_list,
            hparam_name=hparam_name,
            args=args,
            num_total_runs=num_total_runs,
            tasks=tasks,
        )
        for task_tuple, mean_list in task_to_mean_list.items():
            # Only the first element of the task tuple (the task name) is used.
            key = (args.m, args.dataset_name, task_tuple[0])
            model_data_task_to_mean_list[key] = mean_list
            model_data_task_to_std_list[key] = task_to_std_list[task_tuple]

        if plot_individual and not plot_ablation:
            plot_line_with_std(
                tuple_to_mean_list=task_to_mean_list,
                tuple_to_std_list=task_to_std_list,
                x_label=xlabel or "Mixing Coefficient (Log)",
                # Use the args being plotted: args_or_args_list may be a list,
                # which has no `dataset_name` attribute.
                y_label=ylabel or "Test Perf. ({}., AUC)".format(
                    "Acc" if args.dataset_name != "PPI" else "F1"
                ),
                name_label_list=["Task"],
                x_list=x_list,
                hue="Task",
                style="Task",
                hue_order=[t.capitalize() for t in tasks],
                x_lim=(None, None),
                err_style="band",
                custom_key=custom_key,
                extension="png",
            )
        elif plot_individual and plot_ablation:
            sns.set_context("poster")
            # Suffix a caller-supplied label with the dataset name; with no
            # label, use the default. (`xlabel + ... or default` raised
            # TypeError on None and its fallback was unreachable.)
            if xlabel is not None:
                ablation_x_label = xlabel + f" ({args.dataset_name})"
            else:
                ablation_x_label = "Mixing Coefficient (Log)"
            plot_line_with_std(
                tuple_to_mean_list=task_to_mean_list,
                tuple_to_std_list=task_to_std_list,
                x_label=ablation_x_label,
                y_label=ylabel or "Test Perf.",
                name_label_list=["Task"],
                x_list=x_list,
                hue="Task",
                style="Task",
                aspect=1.5,
                legend=False,
                err_style="band",
                custom_key="ablation_against_{}_{}".format(
                    hparam_name, custom_key),
                extension="pdf",
            )

    if not plot_individual and not plot_ablation:
        plot_line_with_std(
            tuple_to_mean_list=model_data_task_to_mean_list,
            tuple_to_std_list=model_data_task_to_std_list,
            x_label=xlabel or "Mixing Coefficient (Log)",
            y_label=ylabel or "Test Perf.",
            name_label_list=["GAT", "Dataset", "Task"],
            x_list=x_list,
            hue="Dataset",
            style="Dataset",
            row="Task",
            col="GAT",
            hue_order=["Cora", "CiteSeer", "PubMed", "PPI"],
            aspect=1.6,
            err_style="band",
            custom_key="perf_against_{}_real_world_datasets".format(
                hparam_name),
            extension="pdf",
        )

        # Merge model and task into one facet label, e.g. "GO & Link".
        mt_data_to_mean_list = {
            (s_join(" & ", [m, t]), d): ml
            for (m, d, t), ml in model_data_task_to_mean_list.items()
        }
        mt_data_to_std_list = {
            (s_join(" & ", [m, t]), d): sl
            for (m, d, t), sl in model_data_task_to_std_list.items()
        }

        sns.set_context("poster", font_scale=1.2)
        # NOTE(review): this uses the same custom_key and extension as the
        # plot above. If plot_line_with_std derives the output file name from
        # them, this plot overwrites the previous one - confirm.
        plot_line_with_std(
            tuple_to_mean_list=mt_data_to_mean_list,
            tuple_to_std_list=mt_data_to_std_list,
            x_label=xlabel or "Mixing Coeff. (Log)",
            y_label=ylabel or "Test Perf.",
            name_label_list=["GAT & Task", "Dataset"],
            x_list=x_list,
            hue="GAT & Task",
            palette=["darkred", "red", "dimgrey", "silver"],
            style="GAT & Task",
            col="Dataset",
            hue_order=["GO & Link", "DP & Link", "GO & Node", "DP & Node"],
            aspect=1.0,
            err_style="band",
            custom_key="perf_against_{}_real_world_datasets".format(
                hparam_name),
            extension="pdf",
        )
# Fix TensorFlow's random seed so runs are reproducible.
# NOTE(review): `tf` must be imported earlier (not visible in this chunk), and
# tf.set_random_seed is the TF 1.x API (TF 2.x uses tf.random.set_seed) - confirm version.
tf.set_random_seed(20130810)


# In[8]:


import matplotlib.pyplot as plt
import seaborn as sns


# In[9]:


# Render matplotlib figures inline in the notebook. InteractiveShell.magic()
# is deprecated; run_line_magic(name, line) is the supported equivalent.
get_ipython().run_line_magic('matplotlib', 'inline')
# Global seaborn styling: "ticks" style with light-grey grid lines. seaborn
# only honours real rcParams keys here ('grid.color'); the former
# 'grid_color' key was silently discarded.
sns.set_style('ticks', {'grid.color': '0.9'})
sns.set_context('talk', font_scale=1.2)
sns.set_palette('gray')


# In[10]:


from keras.models import Sequential, load_model

from keras.layers import Dense, Activation, Dropout

from keras.losses import binary_crossentropy

from keras.optimizers import RMSprop, Adam

from keras.metrics import binary_accuracy
# Flatten the pixel array into one RGB row per pixel and rescale 0-255 to 0-1.
# Presumably `colors` is an (H, W, 3) uint8 image array defined earlier - TODO confirm.
colors = colors.reshape(colors.shape[0] * colors.shape[1], 3) / 255.0

# Generate a distance matrix
# pdist returns condensed pairwise distances (Euclidean by default);
# squareform expands them into a full symmetric (n_pixels, n_pixels) matrix.
D = squareform(pdist(colors))

# Compute classical MDS via pymathtoolbox (embedding into a 2-dimensional space)
X = pymathtoolbox.compute_classical_mds(D=D, dim=2)

# Define constants for plot
# NOTE(review): IMAGE_FORMAT is not used in this chunk; presumably used when saving later.
FIG_SIZE = (8, 3)
IMAGE_FORMAT = "png"
DPI = 200

# Set style
sns.set()
# NOTE(review): called without arguments right after sns.set(), which already
# sets the default context; this looks redundant - confirm.
sns.set_context()
plt.rcParams['font.sans-serif'] = ["Linux Biolinum O", "Linux Biolinum"]

# Draw plot
fig = plt.figure(figsize=FIG_SIZE, dpi=DPI)

# Left panel: the target image with all ticks and tick labels removed.
ax = fig.add_subplot(1, 2, 1)
ax.set_title("Target Image")
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(image)

# Right panel: presumably the 2-D embedding `X` is drawn here after this chunk.
ax = fig.add_subplot(1, 2, 2)
ax.set_title("Pixel Colors Embedded into a 2D Space")
Пример #41
0
def plot_logs(experiments: List[Summary],
              smooth_factor: float = 0,
              share_legend: bool = True,
              ignore_metrics: Optional[Set[str]] = None,
              pretty_names: bool = False,
              include_metrics: Optional[Set[str]] = None) -> plt.Figure:
    """A function which will plot experiment histories for comparison viewing / analysis.

    Each metric gets its own subplot on a near-square grid.
    * Metrics whose group reports ``ndim() == 1`` are placed after all others
      and rendered as text.
    * ``ndim() == 2`` groups are drawn as lines, or as a single marker when a
      mode has only one data point.
    * Other groups are skipped.

    Modes are drawn in the order train, eval, test, infer, then anything else.
    When anything is plotted, the global seaborn context is set to ``'paper'``.

    Args:
        experiments: Experiment(s) to plot.
        smooth_factor: A non-negative float representing the magnitude of gaussian smoothing to apply (zero for none).
        share_legend: Whether to have one legend across all graphs (True) or one legend per graph (False).
        pretty_names: Whether to modify the metric names in graph titles (True) or leave them alone (False).
        ignore_metrics: Any keys to ignore during plotting ('epoch' is always ignored). The given set is not modified.
        include_metrics: A whitelist of keys to include during plotting. If None then all will be included.

    Returns:
        The handle of the pyplot figure. When there are no experiments, or no
        metric survives filtering, this is a figure with a single empty axes.
    """
    experiments = to_list(experiments)
    n_experiments = len(experiments)
    if n_experiments == 0:
        # One empty axes. (plt.subplots(111) built a 111-row grid: "111" is an
        # add_subplot spec, not a row count.)
        return plt.subplots()[0]

    # Build a fresh set instead of using `|=`, so the caller's ignore_metrics
    # is never mutated (to_set may return its argument unchanged).
    ignore_keys = to_set(ignore_metrics or set()) | {'epoch'}
    include_keys = to_set(include_metrics)
    # TODO: epoch should be indicated on the axis (top x axis?). Problem - different epochs per experiment.
    # TODO: figure out how ignore_metrics should interact with mode

    # Draw order of modes; unknown modes go last.
    mode_order = {'train': 0, 'eval': 1, 'test': 2, 'infer': 3}
    metric_histories = defaultdict(_MetricGroup)  # metric: MetricGroup
    for idx, experiment in enumerate(experiments):
        history = experiment.history
        # Since python dicts remember insertion order, sort the history so that train mode is always plotted on bottom
        for mode, metrics in sorted(history.items(),
                                    key=lambda item: mode_order.get(item[0], 4)):
            for metric, step_val in metrics.items():
                if len(step_val) == 0:
                    continue  # Ignore empty metrics
                if metric in ignore_keys:
                    continue
                if include_keys and metric not in include_keys:
                    continue
                metric_histories[metric].add(idx, mode, step_val)

    metric_list = sorted(metric_histories)
    if not metric_list:
        return plt.subplots()[0]

    # If sharing legend and there is more than 1 plot, then dedicate 1 subplot for the legend
    share_legend = share_legend and (len(metric_list) > 1)
    n_plots = len(metric_list) + share_legend

    # map the metrics into an n x n grid, then remove any extra columns. Final grid will be n x m with m <= n
    n_rows = math.ceil(math.sqrt(n_plots))
    n_cols = math.ceil(n_plots / n_rows)
    metric_grid_location = {}
    nd1_metrics = []
    idx = 0
    for metric in metric_list:
        if metric_histories[metric].ndim() == 1:
            # Delay placement of the 1D plots until the end
            nd1_metrics.append(metric)
        else:
            metric_grid_location[metric] = (idx // n_cols, idx % n_cols)
            idx += 1
    for metric in nd1_metrics:
        metric_grid_location[metric] = (idx // n_cols, idx % n_cols)
        idx += 1

    sns.set_context('paper')
    # squeeze=False always returns a 2-D array of axes, so axs[row][col] works
    # for every grid shape. This includes 2 rows x 1 column (two metrics with
    # share_legend=False), which previously raised TypeError.
    fig, axs = plt.subplots(n_rows,
                            n_cols,
                            sharex='all',
                            figsize=(4 * n_cols, 2.8 * n_rows),
                            squeeze=False)

    for metric, (row, col) in metric_grid_location.items():
        axis = axs[row][col]
        if metric_histories[metric].ndim() == 1:
            axis.grid(linestyle='')
        else:
            axis.grid(linestyle='--')
            axis.ticklabel_format(axis='y', style='sci', scilimits=(-2, 3))
        axis.set_title(
            metric if not pretty_names else prettify_metric_name(metric),
            fontweight='bold')
        axis.spines['top'].set_visible(False)
        axis.spines['right'].set_visible(False)
        axis.spines['bottom'].set_visible(False)
        axis.spines['left'].set_visible(False)
        axis.tick_params(bottom=False, left=False)

    for i in range(n_cols):
        axs[n_rows - 1][i].set_xlabel('Steps')

    # some of the columns in the last row might be unused, so disable them
    last_column_idx = n_cols - (n_rows * n_cols - (n_plots - share_legend)) - 1
    for i in range(last_column_idx + 1, n_cols):
        axs[n_rows - 1][i].axis('off')
        axs[n_rows - 2][i].set_xlabel('Steps')
        axs[n_rows - 2][i].xaxis.set_tick_params(which='both',
                                                 labelbottom=True)

    # the 1D metrics don't need x axis, so move them up, starting with the last in case multiple rows of them
    for metric in reversed(nd1_metrics):
        row, col = metric_grid_location[metric]
        axs[row][col].axis('off')
        if row > 0:
            axs[row - 1][col].set_xlabel('Steps')
            axs[row - 1][col].xaxis.set_tick_params(which='both',
                                                    labelbottom=True)

    colors = sns.hls_palette(
        n_colors=n_experiments,
        s=0.95) if n_experiments > 10 else sns.color_palette("colorblind")
    color_offset = defaultdict(lambda: 0)
    # If there is only 1 experiment, we will use alternate colors based on mode
    if n_experiments == 1:
        color_offset['eval'] = 1
        color_offset['test'] = 2
        color_offset['infer'] = 3

    handles = []
    labels = []
    has_label = defaultdict(lambda: defaultdict(lambda: defaultdict(
        lambda: False)))  # exp_id : {mode: {type: True}}
    ax_text = defaultdict(lambda:
                          (0.0, 0.9))  # Where to put the text on a given axis
    for exp_idx, experiment in enumerate(experiments):
        for metric, group in metric_histories.items():
            row, col = metric_grid_location[metric]
            axis = axs[row][col]
            if group.ndim() == 1:
                # Single value: write it as text, stepping down the axis and
                # wrapping to a second text column when it runs off the bottom.
                for mode in group.modes(exp_idx):
                    ax_id = id(axis)
                    prefix = f"{experiment.name} ({mode})" if n_experiments > 1 else f"{mode}"
                    axis.text(ax_text[ax_id][0],
                              ax_text[ax_id][1],
                              f"{prefix}: {group.get_val(exp_idx, mode)}",
                              color=colors[exp_idx + color_offset[mode]],
                              transform=axis.transAxes)
                    ax_text[ax_id] = (ax_text[ax_id][0],
                                      ax_text[ax_id][1] - 0.1)
                    if ax_text[ax_id][1] < 0:
                        ax_text[ax_id] = (ax_text[ax_id][0] + 0.5, 0.9)
            elif group.ndim() == 2:
                for mode, data in group[exp_idx].items():
                    title = f"{experiment.name} ({mode})" if n_experiments > 1 else f"{mode}"
                    if data.shape[0] < 2:
                        # This particular mode only has a single data point, so need to draw a shape instead of a line
                        xy = (data[0][0], data[0][1])
                        if mode == 'train':
                            style = MarkerStyle(marker='o', fillstyle='full')
                        elif mode == 'eval':
                            style = MarkerStyle(marker='v', fillstyle='full')
                        elif mode == 'test':
                            style = MarkerStyle(marker='*', fillstyle='full')
                        else:
                            style = MarkerStyle(marker='s', fillstyle='full')
                        s = axis.scatter(
                            xy[0],
                            xy[1],
                            s=40,
                            c=[colors[exp_idx + color_offset[mode]]],
                            marker=style,
                            linewidth=1.0,
                            edgecolors='black',
                            zorder=3
                        )  # zorder to put markers on top of line segments
                        if not has_label[exp_idx][mode]['patch']:
                            labels.append(title)
                            handles.append(s)
                            has_label[exp_idx][mode]['patch'] = True
                    else:
                        # We can draw a line
                        y = data[:,
                                 1] if smooth_factor == 0 else gaussian_filter1d(
                                     data[:, 1], sigma=smooth_factor)
                        ln = axis.plot(
                            data[:, 0],
                            y,
                            color=colors[exp_idx + color_offset[mode]],
                            label=title,
                            linewidth=1.5,
                            linestyle='solid'
                            if mode == 'train' else 'dashed' if mode == 'eval'
                            else 'dotted' if mode == 'test' else 'dashdot')
                        if not has_label[exp_idx][mode]['line']:
                            labels.append(title)
                            handles.append(ln[0])
                            has_label[exp_idx][mode]['line'] = True
            else:
                # Some kind of image or matrix. Not implemented yet.
                pass

    plt.tight_layout()

    if labels:
        if share_legend:
            # The legend lives in the first unused slot of the last row.
            axs[n_rows - 1][last_column_idx + 1].legend(
                handles,
                labels,
                loc='center',
                fontsize='large' if len(handles) <= 6 else
                'medium' if len(handles) <= 8 else 'small')
        else:
            for i in range(n_rows):
                for j in range(n_cols):
                    if i == n_rows - 1 and j > last_column_idx:
                        break
                    axs[i][j].legend(loc='best', fontsize='small')
    return fig
Пример #42
0
    N_iter = 1000000
    lr_l = list(range(N_iter))
    for i in range(N_iter):
        scheduler.step()
        current_lr = optimizer.param_groups[0]["lr"]
        lr_l[i] = current_lr

    import matplotlib as mpl  # type: ignore
    from matplotlib import pyplot as plt
    import matplotlib.ticker as mtick  # type: ignore

    mpl.style.use("default")
    import seaborn  # type: ignore

    seaborn.set(style="whitegrid")
    seaborn.set_context("paper")

    plt.figure(1)
    plt.subplot(111)
    plt.ticklabel_format(style="sci", axis="x", scilimits=(0, 0))
    plt.title("Title", fontsize=16, color="k")
    plt.plot(list(range(N_iter)),
             lr_l,
             linewidth=1.5,
             label="learning rate scheme")
    legend = plt.legend(loc="upper right", shadow=False)
    ax = plt.gca()
    labels = ax.get_xticks().tolist()
    for k, v in enumerate(labels):
        labels[k] = str(int(v / 1000)) + "K"
    ax.set_xticklabels(labels)
Пример #43
0
#!/usr/bin/env python
import argparse
import logging
import sys
from subprocess import DEVNULL, run

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
from scipy.stats import ttest_ind

# Global plot styling for this script: seaborn "whitegrid" theme, then the
# "paper" context with fonts scaled 2x and a default line width of 2.25.
sns.set(style="whitegrid")
sns.set_context("paper", font_scale=2.0, rc={"lines.linewidth": 2.25})


# Configure logging: records at INFO and above go to stdout, prefixed with a
# timestamp, the level, the logger name and the source line number.
logging.basicConfig(
    stream=sys.stdout,
    format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Default figure size; presumably passed as matplotlib `figsize` (inches) by
# plotting code not visible here — TODO confirm.
FIGSIZE = (7, 6)

def ttests(a, b, c):
    """Log an independent two-sample t-test of sample *a* against sample *b*.

    The result of ``scipy.stats.ttest_ind(a, b)`` is written to the module
    logger at INFO level. *c* is accepted but not used by this body.
    """
    outcome = ttest_ind(a, b)
    # Lazy %-formatting: the message text is identical to an f-string version.
    logger.info("A vs. B: %s", outcome)
Пример #44
0
    data = pd.DataFrame({
        'Time': time,
        'Direction': direction,
        'NumAtoms': length
    })
    data.to_pickle(output_filename)


if __name__ == '__main__':

    # Run the timing measurements (test_length is defined elsewhere in this
    # script; presumably it writes a DataFrame to output_filename) — TODO
    # confirm what interval=(100, 100, 3000) means to test_length.
    test_length(interval=(100, 100, 3000),
                output_filename='Data/RMSDTime.pkl',
                num_measurements=10)

    # Reload the measurements from the pickle just written.
    data = pd.read_pickle('Data/RMSDTime.pkl')
    # `sea` is presumably a seaborn alias imported outside this view.
    sea.set_style("whitegrid")
    sea.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1.5})
    # One line (with markers) per Direction: Time against NumAtoms.
    g1 = sea.relplot(x="NumAtoms",
                     y="Time",
                     hue='Direction',
                     kind="line",
                     style="Direction",
                     height=6,
                     aspect=1.3,
                     markers=True,
                     data=data)
    # These pyplot calls act on the current axes, i.e. the one relplot made.
    plt.ylabel("Time, ms")
    plt.xlabel("Number of atoms")
    sea.despine()
    # plt.show()
    plt.savefig("Fig/RMSDTime.png")
#!/usr/bin/env python
# coding: utf-8
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import cm
import matplotlib.artist as martist
from matplotlib.offsetbox import AnchoredText
import seaborn as sns

# Global seaborn styling: white grid background, "paper" plotting context.
sns.set_style('whitegrid')
sns.set_context("paper")

# NOTE(review): hard-coded, machine-specific working directory; the CSV
# below is resolved relative to it.
os.chdir('/Users/pauline/Documents/Python')
df = pd.read_csv("Tab-Bathy.csv")

# Six vertically stacked panels sharing the x axis (independent y axes).
fig, axes = plt.subplots(6, 1, figsize=(12,14), sharex=True, sharey=False, dpi=300)
plt.suptitle('Geomorphology of the Mariana Trench: cross-section of the 25 bathymetric profiles',
             x=0.54, y=.95, fontsize=12)

def add_at(axes, t, loc=2):
    """Attach an AnchoredText box showing *t* to *axes* and return it.

    Args:
        axes: Axes the text box is added to. (This parameter shadows the
            module-level ``axes`` array inside the function.)
        t: Text to display.
        loc: Anchor location forwarded to ``AnchoredText``.

    Returns:
        The ``AnchoredText`` artist that was added.
    """
    anchored = AnchoredText(t, loc=loc, prop={"size": 11})
    axes.add_artist(anchored)
    return anchored

# subplot 1
fig = df.plot(x='observ', y=['profile1', 'profile2', 'profile3', 'profile4', 'profile5'],
              linestyle='-', linewidth='.8', cmap=cm.Set1, ax=axes[0])
axes[0].legend(loc='upper right', bbox_to_anchor=(1.12, 1.1),
from __future__ import absolute_import, division, print_function

from copy import deepcopy
import json
import glob
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Global seaborn styling: plain white background, large "poster" context.
sns.set_style('white')
sns.set_context('poster')

# 20 colours in dark/light pairs; these hex values appear to be the
# Tableau/D3 "category20" set (matplotlib's 'tab20'). Usage is not visible
# in this part of the file.
COLORS = [
    '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a',
    '#d62728', '#ff9896', '#9467bd', '#c5b0d5', '#8c564b', '#c49c94',
    '#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7', '#bcbd22', '#dbdb8d',
    '#17becf', '#9edae5'
]


def calc_iou_individual(pred_box, gt_box):
    """Calculate IoU of single predicted and ground truth box

    Args:
        pred_box (list of floats): location of predicted object as
            [xmin, ymin, xmax, ymax]
Пример #47
0
from sklearn.manifold import TSNE
from bokeh.plotting import figure, output_file, show, save
from bokeh.models import HoverTool
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
import matplotlib.patheffects as PathEffects
import pymysql

# Global seaborn styling: dark grid, muted palette, "notebook" context with
# fonts scaled 1.5x and default line width 2.5.
sns.set_style('darkgrid')

sns.set_palette('muted')

sns.set_context("notebook", font_scale=1.5, rc={"lines.linewidth": 2.5})

# Module-level sizes. The names suggest: vocabulary size for the vectorizer,
# number of samples, number of topics/components, and words listed per
# topic — their use is not visible here, TODO confirm.
n_features = 1000
n_samples = 2000
n_topics = 2
n_top_words = 20


def loadCsv(filename):
    """Read a CSV file and return its rows as a list of lists of strings.

    Args:
        filename: Path of the CSV file to read (decoded as UTF-8).

    Returns:
        One list per row, holding that row's fields as strings; no type
        conversion is performed. An empty file yields ``[]``.
    """
    import csv  # `csv` is not imported at module level in this snippet

    # Python 3's csv module needs a text-mode file opened with newline="";
    # the previous binary ("rb") mode made csv.reader raise csv.Error.
    # The `with` block also closes the file, which the old code leaked.
    with open(filename, newline="", encoding="utf-8") as handle:
        return list(csv.reader(handle))

Пример #48
0
def plotModelOutput(df, inputs, eqTime, eqTemp, popStats, save, saveName,
                    dimVar, exp=1):
    """Plot one model run as a phase plot and as stacked time-series panels.

    Two figures are produced. Each one is optionally saved and/or shown,
    then all figures are closed.

    1. Phase plot: population (billions) vs temperature (K), coloured by
       pCO2 (ppm), with equilibrium-temperature and population-statistic
       reference lines.
    2. Time series: population vs time (top) and temperature vs time
       (bottom). The panels share the time axis and one pCO2 colour bar.

    Args:
        df: DataFrame with 'time' (seconds), 'temp', 'pop' and 'pco2'
            columns. 'pop' is divided by 1000 and 'pco2' multiplied by 1e6
            for display in billions and ppm respectively.
        inputs: Run inputs.
            - inputs[0] is shown as the distance (AU).
            - inputs[1]/1000 is shown as the carrying capacity ("billion ppl").
            - inputs[3] is the Delta T added to eqTemp.
        eqTime: Not used by this function.
        eqTemp: Equilibrium temperature (K).
        popStats: Dict providing 'initPop', 'maxPop', 'halfPop', 'finalPop',
            'anthroPop', 'maxTime' and 'maxPopPlot'.
        save: Pair of flags: save[0] writes PNG files, save[1] calls
            plt.show().
        saveName: Base file name (without extension) of the saved PNGs.
        dimVar: Value shown as gamma in the figure titles.
        exp: Experiment number; 1 or 2 selects the output directory, any
            other value skips saving.
    """
    # (popStats key, line colour, subscript used in the legend label)
    pop_lines = (('maxPop', 'b', 'peak'),
                 ('halfPop', 'springgreen', '1/2'),
                 ('finalPop', 'orangered', 'final'))

    sns.set_style('darkgrid')
    sns.set_context('poster', rc={'font.size': 30.0,
                                  'axes.labelsize': 26.0,
                                  'axes.titlesize': 24.0,
                                  'xtick.labelsize': 26.0,
                                  'ytick.labelsize': 26.0,
                                  'legend.fontsize': 22.0})
    locale.setlocale(locale.LC_ALL, '')
    timer = np.asarray(df['time'])
    temp = np.asarray(df['temp'])
    pop = np.asarray(df['pop'])
    pco2 = np.asarray(df['pco2'])
    finalTemp = df['temp'][df.index[-1]]
    timer = timer / 60 / 60 / 24 / 365.25  # convert seconds to years
    timer = timer + 1820  # shift so the time axis reads as calendar years
    pop = pop / 1000  # displayed in billions
    pco2 = pco2 * 10**6  # displayed in ppm

    # Lower temperature limit shared by both figures: the lower of the data
    # minimum and eqTemp, padded by 5% of the temperature range.
    temp_pad = (5 / 100) * (max(temp) - min(temp))
    temp_lo = min(min(temp), eqTemp) - temp_pad

    # ------------------------------------------------------------ Phase plot
    sns.set_context('poster', rc={'font.size': 18.0,
                                  'axes.labelsize': 26.0,
                                  'axes.titlesize': 24.0,
                                  'xtick.labelsize': 25.0,
                                  'ytick.labelsize': 25.0,
                                  'legend.fontsize': 14.0})
    fig, ax = plt.subplots(figsize=(10, 5), dpi=200)
    # Raw strings keep the LaTeX backslashes without invalid-escape warnings.
    fig.suptitle("Distance: " + str(round(inputs[0], 4))
                 + r" AU,  $Carrying\ Capacity$: "
                 + '{:,}'.format(round(inputs[1] / 1000))
                 + " billion ppl" + r",    $\gamma:$ " + str(round(dimVar, 3)),
                 x=.46, fontsize=19)
    ax.set_ylim(temp_lo, finalTemp)
    line = ax.scatter(pop, temp, c=pco2, cmap='jet')
    cbar = fig.colorbar(line)
    cbar.set_label(r'$pCO_{2}\ (ppm)$', size=19)
    ax.set_xlabel('Population (billion)', fontsize=17)
    # Horizontal reference lines.
    ax.axhline(y=eqTemp, c='b', label='$T_{eq}=$' + str(round(eqTemp)) + " K")
    ax.axhline(y=eqTemp + inputs[3], c='springgreen',
               label=r'$T_{eq}+\Delta T=$' + str(round(eqTemp + inputs[3])) + " K")
    # NOTE(review): this places initPop/1000 (a population in billions) as a
    # *horizontal* line on the temperature axis. With the y-limits fixed
    # above it is likely off-screen and only appears in the legend. An
    # axvline at x=popStats['initPop']/1000 was probably intended — confirm.
    ax.axhline(y=popStats['initPop'] / 1000, c='orangered', linestyle="--",
               label='$N_{0}=$' + str(round(popStats['initPop'], 1)) + " million")
    # Vertical reference lines at the population statistics.
    for key, colour, sub in pop_lines:
        ax.axvline(x=popStats[key], ms=8, c=colour, linestyle='--',
                   label='$N_{' + sub + '}=$' + str(round(popStats[key], 1))
                   + " billion")
    ax.set_ylabel('Temperature (K)', fontsize=17)
    ax.legend(loc='best')
    plt.gcf().subplots_adjust(bottom=0.175)
    if save[0]:
        if exp == 1:
            plt.savefig("../plotsPhase_exp1/" + str(saveName) + ".png")
        elif exp == 2:
            plt.savefig("../plotsPhase_exp2/" + str(saveName) + ".png")
    if save[1]:
        plt.show()
    plt.close('all')

    # ------------------------------------------------------ Time-series plots
    # Two panels sharing the x (time) axis: population on top, temperature below.
    fig, (ax2, ax1) = plt.subplots(2, sharex=True, figsize=(20, 10), dpi=200)
    fig.suptitle("       Distance: " + str(round(inputs[0], 4))
                 + r" AU,  $Carrying\ Capacity$: "
                 + '{:,}'.format(round(inputs[1] / 1000))
                 + r" billion ppl,  $\gamma:$ " + str(round(dimVar, 3)),
                 x=.40, fontsize=32)

    # Bottom panel: temperature vs time.
    ax1.scatter(timer, temp, c=pco2, cmap='jet')
    ax1.set_title('Temperature vs Time')
    ax1.set(ylabel='Temperature (K)', xlabel='Time (years)')
    ax1.set_xlim(min(timer), max(timer))
    temp_hi = max(temp) + temp_pad
    ax1.set_ylim(temp_lo, temp_hi)
    ax1.set_yticks(np.linspace(temp_lo, temp_hi, 4))

    sns.set_style('darkgrid')

    # Top panel: population vs time.
    line2 = ax2.scatter(timer, pop, c=pco2, cmap='jet')
    ax2.set(ylabel='Population (billions)')
    ax2.set_title("Population vs Time")
    ax2.set_yticks(np.linspace(min(pop), popStats['maxPopPlot'], 4))
    sns.set_palette('colorblind')
    # Horizontal reference lines.
    for key, colour, sub in pop_lines:
        ax2.axhline(y=popStats[key], c=colour,
                    label='$N_{' + sub + '}=$' + str(round(popStats[key], 1))
                    + " billion")
    ax2.axhline(y=popStats['anthroPop'] / 1000, c='b', linestyle="--",
                label='$N_{A}=$' + str(round(popStats['anthroPop'] / 1000, 1))
                + " billion")
    ax1.axhline(y=eqTemp, c='b', label='$T_{eq}=$' + str(round(eqTemp)) + " K")
    ax1.axhline(y=eqTemp + inputs[3], c="springgreen",
                label=r'$T_{eq}+\Delta T=$' + str(round(eqTemp + inputs[3])) + " K")
    # Vertical line at popStats['maxTime'] (labelled t_peak) on both panels.
    peak_year = popStats['maxTime'] + 1820
    ax2.axvline(x=peak_year, c='b', linestyle='--')
    ax1.axvline(x=peak_year, c='b', linestyle='--',
                label="$t_{peak}=$" + str(int(peak_year)))

    ax2.set_xlim(min(timer), max(timer))
    ax2.set_ylim(min(pop) - min(pop) * (2 / 100), popStats['maxPopPlot'])
    cbar = fig.colorbar(line2, label='pCO2 (ppm)', ax=[ax1, ax2])
    cbar.set_label(r'$pCO_{2}\ (ppm)$', size=30)
    ax2.legend(loc='best', prop={'size': 23})
    ax1.legend(loc='lower right', prop={'size': 25})
    if save[0]:
        if exp == 1:
            plt.savefig("../plots_exp1/" + str(saveName) + ".png")
        elif exp == 2:
            plt.savefig("../plots_exp2/" + str(saveName) + ".png")
    if save[1]:
        plt.show()
    plt.close('all')
Пример #49
0
from pathlib import Path
import pandas as pd
import salem

import matplotlib.pyplot as plt
from matplotlib import ticker
import seaborn as sns

from _const_.default import *

sns.set_context('talk')
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names, so plt.style.use('seaborn-whitegrid') raises
# OSError there. Use the new name when available; fall back on older versions.
_WHITEGRID_STYLE = ('seaborn-v0_8-whitegrid'
                    if 'seaborn-v0_8-whitegrid' in plt.style.available
                    else 'seaborn-whitegrid')
plt.style.use(_WHITEGRID_STYLE)

# Variables analysed and the date ranges built from BL_PERIOD and
# PROJ_PERIODS (star-imported from _const_.default). The names suggest a
# baseline/reference period and a mid-century projection period — confirm.
VARS2 = ['tmean', 'precip']
RF_DATE_RANGE = slice(*BL_PERIOD)
PROJ_MID_DATE_RANGE = slice(*PROJ_PERIODS['mid'])

# Input locations: basin shapefiles and RCM spreadsheets.
in_shp_dir = Path('input/shp/basins')
in_xls_dir = Path('output/xls/rcm')

# Output locations (created if missing): box-plot images and statistics.
out_img_dir = Path('output/img/boxplot')
out_img_dir.mkdir(parents=True, exist_ok=True)

out_stat_dir = Path('output/stat/boxplot')
out_stat_dir.mkdir(parents=True, exist_ok=True)

# Shapefiles located one directory level below in_shp_dir.
in_shps = list(in_shp_dir.glob('*/*.shp'))
# NOTE(review): Path.glob yields entries in filesystem order, so which
# shapefile is dropped here is not guaranteed to be the same on every
# machine; sort first if a specific basin is meant to be excluded.
in_shps = in_shps[1:]


# Accumulator for per-basin statistics, presumably filled further below.
stat_df = []
Пример #50
0
    def draw_swarmplot(self, ax, inlier, kws):
        """Plot the data as one beeswarm per category (and per hue level).

        In vertical, non-split mode each category is drawn with two scatter
        calls:
        - the ``~inlier[i]`` points, as a bold "x" at 5x the marker size;
        - the ``inlier[i]`` points, as regular markers.

        Horizontal and split-hue modes draw every point with one scatter
        call. After plotting, each stored swarm is spread along the
        categorical axis by ``self.swarm_points``.

        Args:
            ax: Axes to draw on.
            inlier: Per-category boolean masks. ``inlier[i]`` is applied to
                the hue-filtered, sorted data of category ``i``. Used only in
                vertical, non-split mode.
            kws: Keyword arguments forwarded to ``ax.scatter``. Must contain
                ``"s"`` (marker size). This dict is modified in place:
                ``"s"`` is popped and ``"c"`` is overwritten for every swarm.
        """
        s = kws.pop("s")

        centers = []
        swarms = []

        # Set the categorical axes limits here for the swarm math
        if self.orient == "v":
            ax.set_xlim(-.5, len(self.plot_data) - .5)
        else:
            ax.set_ylim(-.5, len(self.plot_data) - .5)

        # Plot each swarm
        for i, group_data in enumerate(self.plot_data):

            if self.plot_hues is None or not self.split:

                width = self.width

                # Use the builtin ``bool`` dtype: the ``np.bool`` alias was
                # deprecated in NumPy 1.20 and removed in 1.24.
                if self.hue_names is None:
                    hue_mask = np.ones(group_data.size, bool)
                else:
                    hue_mask = np.array(
                        [h in self.hue_names for h in self.plot_hues[i]],
                        bool)
                    # Broken on older numpys
                    # hue_mask = np.in1d(self.plot_hues[i], self.hue_names)

                swarm_data = group_data[hue_mask]

                # Sort the points for the beeswarm algorithm
                sorter = np.argsort(swarm_data)
                swarm_data = swarm_data[sorter]
                point_colors = self.point_colors[i][hue_mask][sorter]

                # Plot the points in centered positions
                cat_pos = np.ones(swarm_data.size) * i
                kws.update(c=point_colors)
                if self.orient == "v":
                    # NOTE(review): ``c`` holds colours for *all* points of
                    # this swarm, but each scatter below only receives the
                    # ``~inlier[i]`` / ``inlier[i]`` subset. Matplotlib may
                    # reject a colour array whose length differs from the
                    # point count — confirm. Also, only the inlier scatter is
                    # stored in ``swarms``, so the "x" markers are not spread
                    # out by ``swarm_points`` below.
                    # The global marker edge width is raised to 1 for the "x"
                    # markers, then set to 0 (not restored to its prior value).
                    sns.set_context(rc={'lines.markeredgewidth': 1})
                    ax.scatter(cat_pos[~inlier[i]],
                               swarm_data[~inlier[i]],
                               marker=r'$\mathbf{\times}$',
                               s=5 * s,
                               **kws)
                    sns.set_context(rc={'lines.markeredgewidth': 0})
                    points = ax.scatter(cat_pos[inlier[i]],
                                        swarm_data[inlier[i]],
                                        s=s,
                                        **kws)

                else:
                    points = ax.scatter(swarm_data, cat_pos, s=s, **kws)

                centers.append(i)
                swarms.append(points)

            else:
                offsets = self.hue_offsets
                width = self.nested_width

                for j, hue_level in enumerate(self.hue_names):
                    hue_mask = self.plot_hues[i] == hue_level
                    swarm_data = group_data[hue_mask]

                    # Sort the points for the beeswarm algorithm
                    sorter = np.argsort(swarm_data)
                    swarm_data = swarm_data[sorter]
                    point_colors = self.point_colors[i][hue_mask][sorter]

                    # Plot the points in centered positions
                    center = i + offsets[j]
                    cat_pos = np.ones(swarm_data.size) * center
                    kws.update(c=point_colors)
                    if self.orient == "v":
                        points = ax.scatter(cat_pos, swarm_data, s=s, **kws)
                    else:
                        points = ax.scatter(swarm_data, cat_pos, s=s, **kws)

                    centers.append(center)
                    swarms.append(points)

        # Update the position of each point on the categorical axis
        # Do this after plotting so that the numerical axis limits are correct
        for center, swarm in zip(centers, swarms):
            if swarm.get_offsets().size:
                self.swarm_points(ax, swarm, center, width, s, **kws)
Пример #51
0
def load_and_plot_res(testName,
                      title,
                      ax,
                      fontscale=1.0,
                      folderName=None,
                      methods=None,
                      leg=False,
                      linewidth=2,
                      ylab=None,
                      xlab=None):
    """Load saved test results and plot the mean HD (+/- SEM) per method.

    Args:
        testName: Results file name (without ".p") under ``tests/``, or a
            ``(is_linear, suffix)`` tuple. The tuple is turned into
            ``"Linear" + suffix.lower()`` when ``is_linear == True``,
            otherwise ``"Nonlinear" + suffix.lower()``.
        title: Axes title, or None for no title.
        ax: Axes to draw on.
        fontscale: Font scale passed to ``sns.set``.
        folderName: Optional sub-folder of ``tests/`` holding the file.
        methods: Method names (keys of the loaded results) to plot. Must be
            a sequence in practice; None fails.
        leg: If exactly True, shrink the axes and put a legend to its right.
        linewidth: Line width of the error-bar lines.
        ylab: Optional y-axis label.
        xlab: Optional x-axis label.
    """
    # Convert a (is_linear, suffix) tuple into the test-name string.
    if isinstance(testName, tuple):
        prefix = "Linear" if testName[0] == True else "Nonlinear"  # noqa: E712
        testName = prefix + testName[1].lower()

    if folderName is None:
        filee = "tests/" + testName + ".p"
    else:
        filee = "tests/" + folderName + "/" + testName + ".p"

    res, parameters = loadRes(filee)

    ns = parameters["ns"]
    print("k = ", parameters.get('k'))

    # Kernel-method results are stored in a separate folder. The previous
    # condition `"KCIT_AND" or "RCIT_AND" in methods` was always true (a
    # non-empty string is truthy), so this file was loaded and merged even
    # when no kernel method was requested.
    if "KCIT_AND" in methods or "RCIT_AND" in methods:
        kernel_res, _kernel_par = loadRes("tests/kernel_tests/" + testName +
                                          ".p")
        # NOTE(review): hard-coded sample size 2000 and method RCIT_AND.
        print("ntests (kernel) = ", len(kernel_res['RCIT_AND'][2000]['HD']))
        res.update(kernel_res)

    HDs = np.zeros((len(methods), len(ns)))
    SDs = np.zeros((len(methods), len(ns)))

    for ii, method in enumerate(methods):
        for jj, n in enumerate(ns):
            hds = res[method][n]["HD"]
            HDs[ii, jj] = np.mean(hds)
            # standard error of the mean
            SDs[ii, jj] = np.std(hds) / np.sqrt(len(hds))

    # Reports the count of the last (method, n) pair visited.
    print("ntests = ", len(hds))

    x = range(0, len(ns))

    if methods == knn_comp_methods:
        linestyles = 4 * ["-"] + 4 * ["--"]
        sns.set(style='ticks',
                palette=sns.color_palette("Set2", 4),
                font_scale=fontscale)

    else:
        # Alternate solid/dashed, with at least one style per method.
        linestyles = (len(methods) // 2 + 1) * ['-', '--']
        palette = sns.color_palette("Set2", 10)
        sns.set(style='ticks', palette=palette, font_scale=fontscale)
    # NOTE(review): without a context name seaborn appears to ignore
    # font_scale here; sns.set above already applied it.
    sns.set_context(font_scale=fontscale)

    for i, lab in enumerate(methods):
        ax.errorbar(x,
                    HDs[i, :],
                    SDs[i, :],
                    ls=linestyles[i],
                    label=lab,
                    marker='o',
                    linewidth=linewidth)

    ax.set_xticks(range(0, len(ns)))
    ax.set_xticklabels(ns)
    ax.grid()

    if xlab is not None:
        ax.set_xlabel(xlab)

    if ylab is not None:
        ax.set_ylabel(ylab)

    if title is not None:
        ax.set_title(title)

    if leg is True:
        # Shrink the current axis to 75% of its width to make room.
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])

        # Put a legend to the right of the current axis
        ax.legend(loc='center left',
                  bbox_to_anchor=(1, 0.5),
                  prop={'size': 12})

    plt.tight_layout()
Пример #52
0
def main():
    """Compare an ambiguity-weighted ISR spectrum with one from simulated data.

    Steps:
    1. Read the PFISR example configuration.
    2. Compute a spectrum with ISRSpectrum for fixed ion/electron
       parameters.
    3. Simulate pulse data from it (MakePulseDataRep) and form lag
       products (CenteredLagProduct).
    4. Sum the lag products per the configured SUMRULE.
    5. Plot the input spectrum ("In") against the spectrum of range gate 40
       of the summed lags ("Out").
    """
    # NumPy is used directly: the NumPy aliases formerly re-exported from
    # the top-level ``scipy`` namespace (sp.array, sp.zeros, ...) have been
    # deprecated and removed in recent SciPy releases.
    import numpy as np

    sns.set_style("whitegrid")
    sns.set_context("notebook")
    inifile = "/Users/Bodangles/Documents/Python/RadarDataSim/Testdata/PFISRExample.pickle"
    (sensdict, simparams) = readconfigfile(inifile)
    simdtype = simparams['dtype']
    sumrule = simparams['SUMRULE']
    npts = simparams['numpoints']
    amb_dict = simparams['amb_dict']
    # for spectrum
    ISS2 = ISRSpectrum(centerFrequency=440.2 * 1e6, bMag=0.4e-4, nspec=npts,
                       sampfreq=sensdict['fs'], dFlag=True)
    ti = 2e3
    te = 2e3
    Ne = 1e11
    Ni = 1e11

    datablock90 = np.array([[Ni, ti], [Ne, te]])
    species = simparams['species']

    (omega, specorig, rcs) = ISS2.getspecsep(datablock90, species, rcsflag=True)

    # ``scipy.sqrt`` was ``numpy.lib.scimath.sqrt`` (complex result for
    # negative input), i.e. ``np.emath.sqrt``.
    cur_filt = np.emath.sqrt(
        scfft.ifftshift(specorig * npts * npts * rcs / specorig.sum()))
    # for data
    Nrep = 10000
    pulse = np.ones(14)
    lp_pnts = len(pulse)
    N_samps = 100
    minrg = -sumrule[0].min()
    maxrg = N_samps + lp_pnts - sumrule[1].max()
    Nrng2 = maxrg - minrg
    out_data = np.zeros((Nrep, N_samps + lp_pnts), dtype=simdtype)
    samp_num = np.arange(lp_pnts)
    for isamp in range(N_samps):
        cur_pnts = samp_num + isamp
        cur_pulse_data = MakePulseDataRep(pulse, cur_filt, rep=Nrep)
        out_data[:, cur_pnts] = cur_pulse_data + out_data[:, cur_pnts]

    lagsData = CenteredLagProduct(out_data, numtype=simdtype, pulse=pulse)
    lagsData = lagsData / Nrep  # divide out the number of pulses summed
    Nlags = lagsData.shape[-1]
    lagsDatasum = np.zeros((Nrng2, Nlags), dtype=lagsData.dtype)
    for irngnew, irng in enumerate(np.arange(minrg, maxrg)):
        for ilag in range(Nlags):
            lagsDatasum[irngnew, ilag] = lagsData[
                irng + sumrule[0, ilag]:irng + sumrule[1, ilag] + 1, ilag
            ].mean(axis=0)
    lagsDatasum = lagsDatasum / lp_pnts  # divide out the pulse length
    (tau, acf) = spect2acf(omega, specorig)

    # apply ambiguity function
    tauint = amb_dict['Delay']
    acfinterp = np.zeros(len(tauint), dtype=simdtype)

    acfinterp.real = spinterp.interp1d(tau, acf.real, bounds_error=0)(tauint)
    acfinterp.imag = spinterp.interp1d(tau, acf.imag, bounds_error=0)(tauint)
    # Apply the lag ambiguity function to the data
    guess_acf = np.zeros(amb_dict['Wlag'].shape[0], dtype=np.complex128)
    for i in range(amb_dict['Wlag'].shape[0]):
        guess_acf[i] = np.sum(acfinterp * amb_dict['Wlag'][i])

    guess_acf = guess_acf * rcs / guess_acf[0].real

    # fit to spectrums
    spec_interm = scfft.fftshift(scfft.fft(guess_acf, n=npts))
    spec_final = spec_interm.real
    allspecs = scfft.fftshift(
        scfft.fft(lagsDatasum, n=len(spec_final), axis=-1), axes=-1)
    plt.figure()
    plt.plot(omega, spec_final.real, label='In', linewidth=5)
    # plt.hold() was removed in Matplotlib 3.0; successive plot() calls
    # already draw onto the same axes.
    plt.plot(omega, allspecs[40].real, label='Out', linewidth=5)
    plt.axis((omega.min(), omega.max(), 0.0, 2e11))
    # Pass ``block`` by keyword; the positional form fails on current Matplotlib.
    plt.show(block=False)
Пример #53
0
import matplotlib.pyplot as plt
import seaborn as sns

from qsforex.settings import OUTPUT_RESULTS_DIR

if __name__ == "__main__":
    # A simple script to plot the balance of the portfolio, or
    # "equity curve", as a function of time.
    #
    # It requires OUTPUT_RESULTS_DIR to be set in the project
    # settings.
    import os  # not imported in this script's visible header

    import pandas as pd

    sns.set_palette("deep", desat=.6)
    # seaborn's set_context() only applies its own plotting-context keys and
    # silently drops "figure.figsize", so set the default figure size on
    # matplotlib directly.
    plt.rcParams["figure.figsize"] = (8, 4)

    equity_file = os.path.join(OUTPUT_RESULTS_DIR, "equity.csv")
    # First column is the (date-parsed) index, first row the header.
    equity = pd.read_csv(equity_file,
                         parse_dates=True,
                         header=0,
                         index_col=0)

    # Plot three charts: Equity curve, period returns, drawdowns
    fig = plt.figure()
    fig.patch.set_facecolor('white')  # Set the outer colour to white

    # Plot the equity curve
    ax1 = fig.add_subplot(311, ylabel='Portfolio value')
    equity["Equity"].plot(ax=ax1, color=sns.color_palette()[0])
Пример #54
0
a function of frequency.

NOTE: Also see 10_fit_alpha2.py

"""
import matplotlib as mpl
import seaborn as sns
import tables as tb
from leda_cal.leda_cal import *
from leda_cal.skymodel import *
from leda_cal.dpflgr import *
from leda_cal.useful import fit_poly, closest, trim, fourier_fit, poly_fit, rebin

from lmfit import minimize, Parameters, report_fit
# Global seaborn styling for this script: "ticks" axes style and the "paper"
# plotting context with fonts scaled by 1.5.
sns.set_style('ticks')
sns.set_context("paper", font_scale=1.5)


def residual(params, model, x, data, errs=1.0):
    """Return the error-weighted residual ``(data - model(params, x)) / errs``.

    The (params, *args) argument order presumably matches what
    ``lmfit.minimize`` passes to its objective function.

    Args:
        params: Parameters forwarded unchanged to *model*.
        model: Callable ``model(params, x)`` returning predicted values.
        x: Independent variable forwarded to *model*.
        data: Observed values.
        errs: Uncertainty used to scale the residual (scalar or array).
    """
    predicted = model(params, x)
    return (data - predicted) / errs


def model(params, x):
    """Power-law model ``amp * x**alpha``.

    Args:
        params: Mapping whose 'amp' and 'alpha' entries expose ``.value``
            (presumably an lmfit ``Parameters`` object — confirm).
        x: Input values passed to ``np.power``.

    NOTE(review): as visible here, the power law is assigned to the local
    ``model`` but there is no ``return`` statement, so the function returns
    None. The snippet looks truncated — confirm against the full source.
    """
    amp = params['amp'].value
    alpha = params['alpha'].value

    model = amp * np.power(
        x, alpha)  #+ off #+ a_s * np.sin(theta * x) #+ a_c * np.cos(theta * x)
    #print mo
Пример #55
0
Email: [email protected]
Web: www.mnguenther.com
"""

from __future__ import print_function, division, absolute_import

#::: plotting settings
import seaborn as sns
# Global seaborn theme: "paper" context, "ticks" style, "deep" palette,
# sans-serif font scaled 1.5x. color_codes=True maps matplotlib's shorthand
# colour codes ("b", "g", ...) onto the palette.
sns.set(context='paper',
        style='ticks',
        palette='deep',
        font='sans-serif',
        font_scale=1.5,
        color_codes=True)
# Point x and y tick marks inward.
sns.set_style({"xtick.direction": "in", "ytick.direction": "in"})
# Marker edge width of 1 for line markers.
sns.set_context(rc={'lines.markeredgewidth': 1})

#::: modules
import numpy as np
import matplotlib.pyplot as plt
import os, sys
import gzip
try:
    # Python 2: use the C-accelerated pickle implementation.
    import cPickle as pickle
except ImportError:
    # Python 3: `pickle` uses its C implementation automatically. Only a
    # missing module is caught here, unlike the previous bare `except:`.
    import pickle
from dynesty import utils as dyutils
from tqdm import tqdm


def ns_plot_bayes_factors(run_names, labels=None, return_dlogZ=False):
Пример #56
0
# python 3.5
# -*- coding: utf-8 -*-
""" Bioluminescence data from TMZ-treated mice, fitting the growth rate of the
post-treatment slope in order to determine whether growth there is faster than
during the primary growth phase. """

import numpy as np
import matplotlib.pyplot as plt
import pylab
import seaborn
from lmfit import *

# Render all figure text with LaTeX (requires a working LaTeX installation).
plt.rc('text', usetex=True)
seaborn.set_context('talk')
seaborn.set_style('ticks')
# Interactive mode: figures are drawn without blocking the script.
pylab.ion()
""" starting values for fitting """
# exponential model
l0s = 0.3  # growth rate

# start value (presumably the initial BLI signal N0 — confirm)
N0s = 3 * 10000

# relative error
error_s = 0.33
""" data
structure: (animal identifier, [timepoints], [raw BLI measurement data])
"""

# Showing here only data after the lowest point has been passed as the goal is
# to fit the growth rates of the relapse
======================================================

This example shows the effect of the shrinkage factor used to generate the
smoothed bootstrap using the
:class:`~imblearn.over_sampling.RandomOverSampler`.
"""

# Authors: Guillaume Lemaitre <*****@*****.**>
# License: MIT

# %%
print(__doc__)

import seaborn as sns

# Large fonts and lines for all figures in this example.
sns.set_context("poster")

# %%
# First, we will generate a toy classification dataset with only a few samples.
# The ratio between the classes will be imbalanced.
from collections import Counter
from sklearn.datasets import make_classification

# 100 two-feature samples, roughly 10% / 90% between the two classes;
# random_state makes the dataset reproducible.
X, y = make_classification(
    n_samples=100,
    n_features=2,
    n_redundant=0,
    weights=[0.1, 0.9],
    random_state=0,
)
# Class counts; the value is only displayed in notebook / gallery-cell
# contexts (as a plain script the expression result is discarded).
Counter(y)
Пример #58
0
import seaborn as sns

# favorite Seaborn settings for notebooks
# Shared rcParams for both seaborn calls below. The dict mixes sizes and
# line widths with colours and grid style. Seaborn is expected to keep only
# the keys relevant to each call (context vs style) — verify.
rc = {
    'lines.linewidth': 2,
    'axes.labelsize': 16,
    'axes.titlesize': 18,
    'axes.facecolor': 'F4F3F6',
    'axes.edgecolor': '000000',
    'axes.linewidth': 1.2,
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    'grid.linestyle': ':',
    'grid.color': 'a6a6a6'
}
sns.set_context('notebook', rc=rc)
sns.set_style('darkgrid', rc=rc)
sns.set_palette("deep", color_codes=True)

# Import the project utils
import sys
# NOTE: this relative path is resolved against the current working
# directory, not this file's location.
sys.path.insert(0, '../../analysis/')

import mwc_induction_utils_processing as mwc
#===============================================================================
# define variables to use over the script
date = 20170622  # presumably the experiment date as YYYYMMDD — confirm
username = '******'
#run = 'r1'

# list the directory with the data
Пример #59
0
import matplotlib.pyplot as plt
import igraph as ig
import seaborn as sns
# import networkx as nx
from matplotlib.ticker import ScalarFormatter

import matplotlib
# Default figure size for all new figures (matplotlib inches).
matplotlib.rcParams['figure.figsize'] = (6.0, 4.0)

# Threshold constant; its use is not visible in this part of the file.
DEL_THRESHOLD = 0

# NOTE(review): seaborn's set_context only applies rc keys that belong to its
# plotting-context set (e.g. lines.linewidth, font.size). "text.usetex",
# "font.family" and "font.serif" are likely dropped silently; set them via
# matplotlib.rcParams if they are meant to take effect — verify.
sns.set_context(
    "talk",
    font_scale=1,
    rc={
        "lines.linewidth": 2.5,
        "text.usetex": False,
        "font.family": 'serif',
        "font.serif": ['Palatino'],
        "font.size": 16
    })

sns.set_style('white')

################
# MISC
################

def get_full_path(p):
    """Return *p* joined onto the directory containing this source file."""
    import os  # `os` is not imported in this snippet's visible header

    return os.path.join(os.path.dirname(os.path.abspath(__file__)), p)

def ig_to_nx(ig_graph, directed=False, nodes=None):
    g = nx.DiGraph() if directed else nx.Graph()
Пример #60
0
    def plot_results(self, filename=None):
        """
        Plot the Tearsheet and optionally save it to a file.

        Lays out on one figure:
        - the equity curve;
        - the rolling Sharpe ratio (only when ``self.rolling_sharpe`` is
          truthy);
        - the drawdown;
        - monthly and yearly returns;
        - three text panels (curve, trade and time statistics).

        Args:
            filename: If not None, the figure is written there (dpi=150,
                tight bounding box) before the window is shown.
        """
        rc = {
            'lines.linewidth': 1.0,
            'axes.facecolor': '0.995',
            'figure.facecolor': '0.97',
            'font.family': 'serif',
            'font.serif': 'Ubuntu',
            'font.monospace': 'Ubuntu Mono',
            'font.size': 10,
            'axes.labelsize': 10,
            'axes.labelweight': 'bold',
            'axes.titlesize': 10,
            'xtick.labelsize': 8,
            'ytick.labelsize': 8,
            'legend.fontsize': 10,
            'figure.titlesize': 12
        }
        sns.set_context(rc)
        sns.set_style("whitegrid")
        sns.set_palette("deep", desat=.6)

        # One extra grid row is reserved for the rolling Sharpe panel.
        offset_index = 1 if self.rolling_sharpe else 0
        vertical_sections = 6 + offset_index
        fig = plt.figure(figsize=(10, vertical_sections * 5.5))

        fig.suptitle(self.title, y=0.94, weight='bold')
        gs = gridspec.GridSpec(vertical_sections, 3, wspace=0.25, hspace=1)

        stats = self.get_results()
        ax_equity = plt.subplot(gs[:2, :])
        if self.rolling_sharpe:
            ax_sharpe = plt.subplot(gs[2, :])
        ax_drawdown = plt.subplot(gs[2 + offset_index, :])
        ax_monthly_returns = plt.subplot(gs[3 + offset_index, :2])
        ax_yearly_returns = plt.subplot(gs[3 + offset_index, 2])
        ax_txt_curve = plt.subplot(gs[4 + offset_index:, 0])
        ax_txt_trade = plt.subplot(gs[4 + offset_index:, 1])
        ax_txt_time = plt.subplot(gs[4 + offset_index:, 2])

        self._plot_equity(stats, ax=ax_equity)
        if self.rolling_sharpe:
            self._plot_rolling_sharpe(stats, ax=ax_sharpe)
        self._plot_drawdown(stats, ax=ax_drawdown)
        self._plot_monthly_returns(stats, ax=ax_monthly_returns)
        self._plot_yearly_returns(stats, ax=ax_yearly_returns)
        self._plot_txt_curve(stats, ax=ax_txt_curve)
        self._plot_txt_trade(stats, ax=ax_txt_trade)
        self._plot_txt_time(stats, ax=ax_txt_time)

        plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)

        # Save before showing. A blocking plt.show() returns only when the
        # window is closed; saving afterwards delayed the file until then,
        # and an interrupted session never wrote it.
        if filename is not None:
            fig.savefig(filename, dpi=150, bbox_inches='tight')

        # Plot the figure
        plt.show()