示例#1
0
def make_movie():
    """Render movie frames in parallel, encode them with ffmpeg and clean up.

    Frames for outputs 0-100 are drawn by ``make_movie_plots`` on a process
    pool with one worker per CPU.  Inside ``plot_folder`` every ``*.png`` is
    then encoded into ``<field>.mov``, which is moved to the parent directory
    under a name built by ``pt.get_fig_name``; finally all ``*.png`` files in
    ``plot_folder`` are deleted.

    Relies on module-level names defined elsewhere in the script:
    ``make_movie_plots``, ``plot_folder``, ``field``, ``projection``,
    ``profile``, ``compare``, ``tctf``, ``beta``, ``cr``, ``diff``,
    ``stream`` and ``heat``.

    The original working directory is always restored, even when encoding or
    renaming fails; in that case the frames are left in place.
    """
    output_list = np.arange(0, 101, 1)
    n_proc = mp.cpu_count()
    # The context manager shuts the workers down even if a frame raises;
    # map() blocks until every frame is done, so no work is cut short.
    with mp.Pool(n_proc) as pool:
        print("Number of processors: ", n_proc)
        pool.map(make_movie_plots, list(output_list))

    cwd = os.getcwd()
    os.chdir(plot_folder)
    try:
        os.system(
            'ffmpeg -framerate 9 -pattern_type glob -i "*.png" -c:v mpeg4 -pix_fmt yuv420p -q:v 3 %s.mov'
            % field)
        if projection:
            movie_name_base = 'multipanel_projection_%s_only' % field
        else:
            movie_name_base = 'multipanel_slice_%s_only' % field
        movie_name = pt.get_fig_name(movie_name_base, profile, compare, tctf,
                                     beta=beta, use_tctf=1, cr=cr, crdiff=diff,
                                     crstream=stream, crheat=heat,
                                     sim_fam='', loc='../.')
        # NOTE(review): [6:-4] presumably strips the location prefix and a
        # 4-character extension from the returned path - confirm against
        # pt.get_fig_name.
        movie_stem = movie_name[6:-4]
        print(movie_stem)
        os.rename('%s.mov' % field, '../%s.mov' % movie_stem)
        print(cwd)

        for pic in glob.glob('*.png'):
            os.remove(pic)
    finally:
        # Never leave the process stranded inside plot_folder.
        os.chdir(cwd)
def plot_power_spectrum(sim,
                        compare,
                        tctf,
                        beta,
                        cr,
                        diff=0,
                        stream=0,
                        heat=0,
                        output=50,
                        field='drho',
                        work_dir='../../simulations/production'):
    """Overlay the power spectrum of ``field`` for every run in a comparison set.

    Runs come from ``pt.generate_lists(compare, tctf, crdiff=diff, cr=cr)``.
    For each run, dataset ``DD<output>`` is loaded with yt (missing datasets
    are skipped with a message).  ``pt.make_power_spectrum`` supplies
    ``(k, P(k))``, which is drawn on log-log axes.  The figure is saved under
    ``../../plots/production`` and then closed.

    NOTE(review): ``beta``, ``stream`` and ``heat`` are accepted but not
    forwarded to ``pt.generate_lists`` - confirm whether they should be.
    """
    tctf_list, beta_list, cr_list, diff_list, stream_list, heat_list \
        = pt.generate_lists(compare, tctf, crdiff=diff, cr=cr)
    print(tctf_list, beta_list, cr_list, diff_list, stream_list, heat_list)

    color_list = palettable.scientific.sequential.Batlow_6.mpl_colors
    fig, ax = plt.subplots(figsize=(6, 6))
    for i, run_tctf in enumerate(tctf_list):
        # Batlow_6 has only six colours; wrap around instead of raising IndexError.
        color = color_list[i % len(color_list)]
        sim_loc = pt.get_sim_location(sim, run_tctf, beta_list[i], cr_list[i],
                                      diff=diff_list[i], stream=stream_list[i],
                                      heat=heat_list[i], work_dir=work_dir)
        ds_path = '%s/DD%04d/DD%04d' % (sim_loc, output, output)
        if not os.path.isfile(ds_path):
            print('nope')
            continue
        ds = yt.load(ds_path)

        label = pt.get_label_name(compare, run_tctf, beta_list[i], cr_list[i],
                                  crdiff=diff_list[i], crstream=stream_list[i],
                                  crheat=heat_list[i])
        k, P_k = pt.make_power_spectrum(ds, field)
        ax.loglog(k, P_k, label=label, linewidth=3, color=color)

    ax.set_xlabel('k')
    ax.set_ylabel('P(k)')
    ax.set_xlim(1, 200)
    ax.set_ylim(1e-6, 1e2)
    ax.legend(loc=1)
    fig.tight_layout()
    figname = pt.get_fig_name('power_spectrum', sim, compare, tctf_list[0], beta_list[0],
                              cr_list[0], diff_list[0], loc='../../plots/production')
    # Save this figure explicitly (not whatever is "current") and release it.
    fig.savefig(figname, dpi=300)
    plt.close(fig)
def plot_multipanel_slices(field, output, sim, compare, tctf, beta=100, cr=0,
                           crdiff=0, crstream=0, crheat=0, fixed_time=0,
                           weight_field='density', projection=False, work_dir='.'):
    """Draw one x-axis slice (or projection) panel of ``field`` per simulation and save it.

    Parameters
    ----------
    field : str
        Gas field to plot.  'density', 'temperature', 'cr_eta' and
        'cr_pressure' get tuned colour limits and labels.  Any other field
        falls back to automatic limits, with the field name as colour-bar label.
    output : int
        Dataset number (``DDxxxx``) to load.  When ``fixed_time`` is set it is
        replaced per panel by ``[100, 33, 10, 3, 1][i]``, so at most 5 panels.
    sim, compare, tctf, beta, cr, crdiff, crstream, crheat, work_dir
        Forwarded to ``pt.get_sim_list`` / ``pt.get_fig_name`` to choose the
        runs and build the output file name.
    weight_field : str
        Weight field used for projections.
    projection : bool
        Project along x instead of slicing.

    Relies on module-level names defined elsewhere in the script:
    ``sim_fam``, ``ytf``, ``_inv_T``, ``rho0``, ``T0`` and ``p0``.

    Raises
    ------
    FileNotFoundError
        If none of the requested datasets exist, so there is nothing to draw.
    """
    ds_loc_list, label_list = pt.get_sim_list(sim, compare, tctf, beta=beta, cr=cr,
                                              crdiff=crdiff, crstream=crstream,
                                              crheat=crheat, work_dir=work_dir,
                                              sim_fam=sim_fam)
    print(ds_loc_list)
    n_panels = len(ds_loc_list)
    # squeeze=False keeps a 2-D axes array even for a single panel, so that
    # ax[i], ax[0] and ax[-1] below always work.
    fig, ax = plt.subplots(ncols=n_panels,
                           nrows=1,
                           figsize=(1.5 * n_panels, 3.8),
                           constrained_layout=True,
                           squeeze=False)
    ax = ax[0]

    H_kpc = 43.85  # scale height in kpc; dashed lines mark 0.8 H and 1.2 H
    fixed_outputs = [100, 33, 10, 3, 1]
    pcm = None
    label = field
    for i, ds_loc in enumerate(ds_loc_list):
        print(ds_loc)
        if fixed_time:
            output = fixed_outputs[i]
        ds_path = '%s/DD%04d/DD%04d' % (ds_loc, output, output)
        if not os.path.isfile(ds_path):
            continue
        ds = ytf.load(ds_path)
        ds.add_field(('gas', 'invT'), function=_inv_T, units='')
        if projection:
            s = yt.ProjectionPlot(ds,
                                  'x', ('gas', field),
                                  center=(0, 0, 1),
                                  width=(1, 1.8),
                                  weight_field=weight_field)
        else:
            s = yt.SlicePlot(ds,
                             'x', ('gas', field),
                             center=(0, 0, 1),
                             width=(1, 1.8))
        # NOTE(review): this also writes yt's own image of every panel to the
        # current directory - kept as-is, confirm it is wanted.
        s.save()
        s.set_buff_size(1024)
        frb = s.frb

        xbins = frb['y'].in_units('kpc')
        ybins = frb['z'].in_units('kpc')
        if field == 'density':
            data_norm = rho0
        elif field == 'temperature':
            data_norm = T0
        else:
            # NOTE(review): every other field (including cr_eta) is divided by
            # p0 - confirm this is intended.
            data_norm = p0
        data = frb[field] / data_norm

        # None lets LogNorm autoscale fields without tuned limits.
        vmin, vmax = None, None
        label = field
        if field == 'density':
            vmin = 1e-1
            vmax = 1.5 if projection else 3
            label = '$\\rho / \\rho_0 $'
        elif field == 'temperature':
            vmin = 5e4 / T0
            vmax = 5e6 / T0
            label = 'T / T$_0$'
        elif field == 'cr_eta':
            cr_eta = pt.get_cr_eta(ds)
            vmin = cr_eta / 100
            vmax = cr_eta * 100
            label = 'P$_c$ / P$_g$'
        elif field == 'cr_pressure':
            label = 'P$_c$ / P$_{c,0}$'
            vmin = 0.2
            vmax = 2

        cmap = pt.get_cmap(field)
        pcm = ax[i].pcolormesh(xbins, ybins, data, cmap=cmap, norm=LogNorm(),
                               vmax=vmax, vmin=vmin, zorder=1)

        ax[i].set_aspect('equal')
        ax[i].set_xticklabels([])
        if i == 0:
            ax[i].set_ylabel('z (kpc)')
        else:
            ax[i].set_yticklabels([])
        for height in (0.8 * H_kpc, 1.2 * H_kpc):
            ax[i].axhline(height, linestyle='dashed', color='black', linewidth=0.7)
        if field == 'temperature':
            ax[i].set_xlabel(label_list[i], fontsize=10)
        else:
            ax[i].set_title(label_list[i], fontsize=10)

    if pcm is None:
        raise FileNotFoundError('no DD%04d dataset found in any of %s' % (output, ds_loc_list))

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.2)
    pos_r = ax[-1].get_position().get_points()

    # Colour bar: a thin axis just right of the last panel, matching its height.
    cbar_width = 0.02
    cbar_height = pos_r[1][1] - pos_r[0][1]
    cbar_y = pos_r[0][1]
    cbar_x = pos_r[1][0] + .02
    print(cbar_x, cbar_y)
    cbax = fig.add_axes([cbar_x, cbar_y, cbar_width, cbar_height])
    if field == 'cr_pressure':
        cbar = fig.colorbar(pcm,
                            cax=cbax,
                            orientation='vertical',
                            ticks=[0.2, 2])
        cbar.ax.set_yticklabels(['0.2', '2'])
    else:
        cbar = fig.colorbar(pcm, cax=cbax, orientation='vertical')
    cbar.set_label(label)

    if projection:
        fig_base = '%s_multipanel_projection' % field
    else:
        fig_base = '%s_multipanel_slice' % field
    figname = pt.get_fig_name(fig_base, sim, compare, tctf, beta=beta, use_tctf=1,
                              cr=cr, crdiff=crdiff, crstream=crstream, crheat=crheat,
                              time=output, sim_fam=sim_fam)
    fig.savefig(figname, dpi=300, bbox_inches='tight', pad_inches=0.1)
示例#4
0
def plot_density_fluctuation_growth(sim,
                                    compare,
                                    tctf,
                                    beta,
                                    cr,
                                    diff=0,
                                    stream=0,
                                    heat=0,
                                    zstart=0.8,
                                    zend=1.2,
                                    T_cold=3.33333e5,
                                    fs=12,
                                    field='density',
                                    work_dir='../../simulations/',
                                    grid_rank=3):
    """Plot density fluctuation, cold mass fraction and cold mass flux versus t / t_cool.

    Draws one curve per run returned by ``pt.generate_lists`` in three
    side-by-side log-scale panels.  For ``compare`` in ('tctf', 'cr') a
    dashed linear-theory reference ``0.02 * exp(t / t_cool)`` is added to
    the first panel.

    Optional overlays are controlled by module-level flags defined elsewhere
    in the script:

    * ``resolution_compare`` adds dotted curves from 'production/high_res'.
    * ``mhd_compare`` adds beta = 3 (dashed) and beta = 'inf' (dotted) runs.
    * For ``compare == 'stream'``, the same streaming run with ``heat=0`` is
      added, dotted.

    Also reads the module-level ``use_mpi``, ``load``, ``save`` and ``sim_fam``.
    ``T_cold`` is accepted for interface compatibility but not used here.
    """
    tctf_list, beta_list, cr_list, diff_list, stream_list, heat_list \
        = pt.generate_lists(compare, tctf, beta=beta, crdiff=diff, cr=cr,
                            crstream=stream, crheat=heat)
    print(tctf_list, beta_list, cr_list, diff_list, stream_list, heat_list)

    ncols = 3
    fig, ax = plt.subplots(nrows=1,
                           ncols=ncols,
                           figsize=(4 * ncols, 3.8),
                           sharex=True,
                           sharey=False)
    for col in range(ncols):
        ax[col].set_yscale('log')
        ax[col].set_xlim(0, 10)
        ax[col].set_xlabel('$t / t_{cool}$', fontsize=fs)

    ax[0].set_ylim(1e-2, 5)
    ax[1].set_ylim(5e-3, 4)
    ax[2].set_ylim(2e-2, 10)
    ax[0].set_ylabel('Density Fluctuation', fontsize=fs)
    ax[1].set_ylabel('Cold Mass Fraction', fontsize=fs)
    ax[2].set_ylabel('Cold Mass Flux', fontsize=fs)

    linecolor = 'white' if pt.dark_mode else 'black'
    if compare == 'tctf' or compare == 'cr':
        # Linear-theory reference with unit growth rate in the plotted t / t_cool.
        theory_time = np.arange(0, 12, 1)
        ax[0].plot(theory_time, 0.02 * np.exp(1.0 * theory_time), color=linecolor,
                   linestyle='dashed', label='Linear Theory', linewidth=3)

    def load_series(plot_type, i, run_tctf, run_beta, run_heat, run_sim_fam):
        # Fetch (time, value) arrays for run i, overriding beta/heat/family where needed.
        return pt.get_time_data(plot_type, sim, run_tctf, run_beta, cr_list[i],
                                use_mpi=use_mpi, diff=diff_list[i],
                                stream=stream_list[i], heat=run_heat,
                                field=field, zstart=zstart, zend=zend,
                                grid_rank=grid_rank, load=load, save=save,
                                work_dir=work_dir, sim_fam=run_sim_fam)

    cpal = pt.get_color_list(compare)
    plot_types = ['density_fluctuation', 'cold_fraction', 'cold_flux']
    for col, plot_type in enumerate(plot_types):
        # run_tctf deliberately does not reuse the name tctf: the parameter is
        # still needed for the figure name below.
        for i, run_tctf in enumerate(tctf_list):
            time_list, data_list = load_series(plot_type, i, run_tctf, beta_list[i],
                                               heat_list[i], sim_fam)
            label = pt.get_label_name(compare, run_tctf, beta_list[i], cr_list[i],
                                      crdiff=diff_list[i], crstream=stream_list[i],
                                      crheat=heat_list[i], counter=i)
            ax[col].plot(time_list / run_tctf,
                         data_list,
                         linewidth=3,
                         linestyle='solid',
                         label=label,
                         color=cpal[i])
            ax[col].tick_params(labelsize=fs)

            if resolution_compare:
                time_list, data_list = load_series(plot_type, i, run_tctf, beta_list[i],
                                                   heat_list[i], 'production/high_res')
                ax[col].plot(time_list / run_tctf,
                             data_list,
                             linewidth=2,
                             linestyle='dotted',
                             label=label,
                             color=cpal[i])
            if mhd_compare:
                for compare_beta, mhd_linestyle in zip([3, 'inf'], ['dashed', 'dotted']):
                    time_list, data_list = load_series(plot_type, i, run_tctf, compare_beta,
                                                       heat_list[i], sim_fam)
                    ax[col].plot(time_list / run_tctf,
                                 data_list,
                                 linewidth=2,
                                 linestyle=mhd_linestyle,
                                 alpha=.8,
                                 label=None,
                                 color=cpal[i])
            if compare == 'stream' and stream_list[i] > 0:
                # Same streaming run without CR heating, for comparison.
                time_list, data_list = load_series(plot_type, i, run_tctf, beta_list[i],
                                                   0, sim_fam)
                ax[col].plot(time_list / run_tctf,
                             data_list,
                             linewidth=2,
                             linestyle='dotted',
                             label=None,
                             color=cpal[i])

    if compare == 'transport' or compare == 'transport_relative':
        ax[0].legend(fontsize=7, loc=3, ncol=2)
    else:
        ax[1].legend(fontsize=8, loc=2)
    fig.tight_layout()
    figname = pt.get_fig_name('dens_cfrac_cflux_growth', sim, compare,
                              tctf, beta, cr, diff, sim_fam=sim_fam)
    fig.savefig(figname, dpi=300)
示例#5
0
def make_plot(field,
              compare,
              tctf=0.3,
              cr=1,
              weighted=True,
              nbins=100,
              work_dir='../../simulations',
              plot_dir='../../plots',
              sim_fam='production'):
    """Plot the PDF of ``field`` for every simulation in a comparison set and save it.

    Parameters
    ----------
    field : {'density', 'temperature', 'cr_eta', 'plasma_beta'}
        Quantity whose distribution is shown.  For 'cr_eta' the first entry of
        the simulation and label lists is dropped.
    compare : str
        Comparison set passed to ``pt.generate_sim_list`` and
        ``pt.generate_label_list``.
    tctf, cr : float
        Fixed parameters of the comparison set.  ``cr`` also centres the
        'cr_eta' x-range.
    weighted : bool
        Forwarded to ``format_data_for_pdf``; adds '_weighted' to the file name.
    nbins : int
        Forwarded to ``format_data_for_pdf`` (presumably the number of bins).
    work_dir, sim_fam : str
        Where ``format_data_for_pdf`` reads the simulations from.
    plot_dir : str
        Output location passed to ``pt.get_fig_name``.

    Relies on module-level ``format_data_for_pdf``, ``create_pdf_plot`` and
    ``log_mumh`` defined elsewhere.

    Raises
    ------
    ValueError
        If ``field`` is not one of the supported values.
    """
    supported_fields = ('density', 'temperature', 'cr_eta', 'plasma_beta')
    if field not in supported_fields:
        raise ValueError('unsupported field %r, expected one of %s'
                         % (field, supported_fields))

    sim_list = pt.generate_sim_list(compare, tctf=tctf, cr=cr)
    label_list = pt.generate_label_list(compare, tctf=tctf, cr=cr)
    if field == 'cr_eta':
        # NOTE(review): the first run is skipped, presumably because it has no
        # cosmic rays - confirm.
        sim_list = sim_list[1:]
        label_list = label_list[1:]

    if field == 'density':
        pal = sns.cubehelix_palette(len(sim_list), rot=-.25, light=.7)
    elif field == 'temperature':
        if compare == 'transport_pdf':
            pal = palettable.scientific.sequential.LaJolla_16.mpl_colors[2:-1]
        else:
            pal = palettable.scientific.sequential.LaJolla_13.mpl_colors[2:-1]
    elif field == 'cr_eta':
        pal = sns.cubehelix_palette(len(sim_list) + 1)[1:]
    else:  # 'plasma_beta'
        pal = palettable.cmocean.sequential.Ice_16.mpl_colors[2:-1]

    x, y = format_data_for_pdf(field,
                               sim_list,
                               label_list,
                               weighted=weighted,
                               nbins=nbins,
                               work_dir=work_dir,
                               sim_fam=sim_fam)

    ylims = (0, 1)
    if field == 'density':
        xlabel = 'Log Number Density (cm$^{-3}$)'
        # Shift the log mass-density range by log_mumh (presumably
        # log10(mu m_H)) to get log number density.  This is done per element:
        # "tuple -= float" only works if log_mumh happens to be a numpy scalar.
        xlims = (-28.7 - log_mumh, -25.6 - log_mumh)
        aspect = 8
        use_label = True
    elif field == 'temperature':
        xlabel = 'Log Temperature (K)'
        xlims = (4.3, 6.8)
        aspect = 8.137
        use_label = False
    elif field == 'cr_eta':
        xlabel = 'Log ($P_c / P_g$)'
        xlims = (np.log10(cr) - 0.8, np.log10(cr) + 1.7)
        aspect = 8
        use_label = False
    else:  # 'plasma_beta'
        xlabel = 'Log ($P_g / P_b$)'
        xlims = (-3, 2)
        aspect = 8
        use_label = True

    g = create_pdf_plot(x,
                        y,
                        len(sim_list),
                        height=1,
                        aspect=aspect,
                        pal=pal,
                        use_label=use_label)
    ax = g.axes

    ax[-1][0].set_xlabel(xlabel, color='black', fontsize=16)
    ax[-1][0].tick_params(axis='x',
                          colors='black',
                          bottom=True,
                          labelsize='large')
    g.set(xlim=xlims)
    g.set(ylim=ylims)

    # File names use 'creta' rather than 'cr_eta'.
    fig_field = 'creta' if field == 'cr_eta' else field
    fig_basename = 'pdf_%s' % fig_field
    if weighted:
        fig_basename += '_weighted'
    figname = pt.get_fig_name(fig_basename, 'isocool', compare,
                              tctf, 100, cr, sim_fam=sim_fam,
                              loc=plot_dir)
    g.savefig(figname, dpi=300, bbox_inches='tight', pad_inches=0)
def plot_density_fluctuation(output, sim, compare, tctf, beta, cr, diff=0, stream=0, heat=0,
                             T_cold=3.3333333e5, zstart=0.8, zend=1.2, relative=0,
                             work_dir='../../simulations/', grid_rank=3):
    """Plot the cold-gas CR pressure ratio ('cold_creta') against t_cool / t_ff.

    Covers every CR run (``cr > 0``) returned by ``pt.generate_lists`` and
    each t_cool / t_ff in [0.1, 0.3, 1, 3, 10].  The series is averaged over
    the samples at indices ``output - 10`` .. ``output + 9`` (NaNs counted as
    0), and the standard deviation over the same window is drawn as the error
    bar.

    With ``relative`` set, each mean is divided by the matching run without
    cosmic rays, and the relative errors of numerator and denominator are
    added in quadrature.  Points whose CR or no-CR mean is zero are skipped.

    Reads the module-level ``load``, ``save`` and ``sim_fam`` defined elsewhere.
    """
    tctf_list, beta_list, cr_list, diff_list, stream_list, heat_list \
        = pt.generate_lists(compare, tctf, crdiff=diff, cr=cr, beta=beta)
    # Keep only runs with cosmic rays (assumes generate_lists returns numpy arrays).
    has_cr = cr_list > 0
    beta_list = beta_list[has_cr]
    cr_list = cr_list[has_cr]
    diff_list = diff_list[has_cr]
    stream_list = stream_list[has_cr]
    heat_list = heat_list[has_cr]

    # The x-axis always uses this fixed set of t_cool / t_ff values.
    tctf_values = [0.1, 0.3, 1, 3, 10]

    print(tctf_values, beta_list, cr_list, diff_list, stream_list, heat_list)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.4, 4), sharex=True, sharey=False)

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(.09, 10)
    ax.set_xlabel('$t_{cool} / t_{ff}$')
    ax.set_ylim(1e-2, 5e2)
    ax.set_ylabel('CR Pressure Ratio')

    plot_type = 'cold_creta'

    def window_stats(series):
        # Mean and std of the NaN-zeroed samples output-10 .. output+9.  The start
        # is clamped: a negative index would silently wrap to the end of the series.
        window = np.nan_to_num(series[max(output - 10, 0):output + 10])
        return np.mean(window), np.std(window)

    color_list = pt.get_color_list(compare)
    for i in range(len(cr_list)):
        x_list, y_list, err_list = [], [], []
        x_rel, y_rel, err_rel_list = [], [], []
        for run_tctf in tctf_values:
            time_list, data_list = pt.get_time_data(plot_type, sim, run_tctf, beta_list[i], cr_list[i],
                                                    diff=diff_list[i], stream=stream_list[i],
                                                    heat=heat_list[i], T_min=T_cold,
                                                    zstart=zstart, zend=zend, grid_rank=grid_rank,
                                                    load=load, save=save, work_dir=work_dir,
                                                    sim_fam=sim_fam)
            has_data = len(data_list) > 0
            if has_data:
                mean_cr, err_cr = window_stats(data_list)
                x_list.append(run_tctf)
                y_list.append(mean_cr)
                err_list.append(err_cr)

            if relative:
                time_nocr, data_nocr = pt.get_time_data(plot_type, sim, run_tctf, beta_list[i], cr=0,
                                                        diff=0, stream=0, heat=0,
                                                        T_min=T_cold, zstart=zstart, zend=zend,
                                                        grid_rank=grid_rank, load=load, save=save,
                                                        work_dir=work_dir, sim_fam=sim_fam)
                if has_data and len(data_nocr) > 0:
                    mean_nocr, err_nocr = window_stats(data_nocr)
                    # A zero mean makes the ratio or its relative error undefined.
                    if mean_nocr != 0 and mean_cr != 0:
                        ratio = mean_cr / mean_nocr
                        # Relative errors of the CR and no-CR means add in quadrature.
                        ratio_err = ratio * np.hypot(err_nocr / mean_nocr, err_cr / mean_cr)
                        x_rel.append(run_tctf)
                        y_rel.append(ratio)
                        err_rel_list.append(ratio_err)

        # Each curve is one CR run across all t_cool / t_ff, so label it with
        # the requested tctf rather than whichever value the sweep ended on.
        label = pt.get_label_name(compare, tctf, beta_list[i], cr_list[i], crdiff=diff_list[i],
                                  crstream=stream_list[i], crheat=heat_list[i], counter=i)
        marker = None if label is None else 'o'
        linestyle = pt.get_linestyle(compare, tctf, beta_list[i], cr_list[i], crdiff=diff_list[i],
                                     crstream=stream_list[i], crheat=heat_list[i], counter=i)
        color = color_list[i]

        if relative == 0:
            x_arr = np.array(x_list)
            y_arr = np.array(y_list)
            err_arr = np.array(err_list)
            # Non-positive means cannot be drawn on the log axis.
            positive = y_arr > 0
            ax.plot(x_arr[positive], y_arr[positive], color=color, label=label,
                    linewidth=2, marker=marker, linestyle=linestyle)
            ax.errorbar(x_arr[positive], y_arr[positive], err_arr[positive], color=color)
        elif relative:
            ax.plot(x_rel, y_rel, color=color, label=label,
                    linewidth=2, marker=marker, linestyle=linestyle)
            ax.errorbar(x_rel, y_rel, err_rel_list, color=color)

    if relative and len(cr_list) > 0:
        ax.axhline(y=1, linestyle='dashed', color='gray', linewidth=1)
    ax.legend(fontsize=8)
    fig.tight_layout()
    fig_basename = 'creta_tctf'
    if relative:
        fig_basename += '_relative'
    figname = pt.get_fig_name(fig_basename, sim, compare,
                              tctf, beta, cr, crdiff=diff, crstream=stream,
                              crheat=heat, time=output, sim_fam=sim_fam,
                              loc='../../plots')
    print(figname)
    fig.savefig(figname, dpi=300)
def plot_density_fluctuation_growth(sim,
                                    compare,
                                    tctf,
                                    beta,
                                    cr,
                                    diff=0,
                                    stream=0,
                                    heat=0,
                                    zstart=0.8,
                                    zend=1.2,
                                    T_cold=3.33333e5,
                                    fs=12,
                                    field='density',
                                    work_dir='../../simulations/',
                                    grid_rank=3):
    """Plot clump size and number of clumps versus t / t_cool for a comparison set.

    For each run from ``pt.generate_lists``, the 'clump' time series is
    loaded with ``use_mpi=True``.  The returned data is unpacked into
    ``(n_clumps, clump_size, clump_std)``; size goes in the left panel and
    count in the right panel.  The figure is saved as 'clump_growth'.

    NOTE(review): despite its name this variant plots clump statistics.  A
    function of the same name that plots density-fluctuation growth appears
    earlier in this listing; if both live in one module, this later
    definition replaces it.

    Reads the module-level ``load``, ``save`` and ``sim_fam`` defined elsewhere.
    ``stream``, ``heat`` and ``T_cold`` are accepted for interface
    compatibility; only ``diff`` is forwarded to ``pt.generate_lists``.
    """
    tctf_list, beta_list, cr_list, diff_list, stream_list, heat_list \
        = pt.generate_lists(compare, tctf, beta=beta, crdiff=diff, cr=cr)

    print(tctf_list, beta_list, cr_list, diff_list, stream_list, heat_list)

    ncols = 2
    fig, ax = plt.subplots(nrows=1,
                           ncols=ncols,
                           figsize=(4 * ncols, 3.8),
                           sharex=True,
                           sharey=False)
    for col in range(ncols):
        ax[col].set_yscale('log')
        ax[col].set_xlim(0, 10)
        ax[col].set_xlabel('$t / t_{cool}$', fontsize=fs)
        ax[col].tick_params(labelsize=fs)

    ax[0].set_ylim(1, 100)
    ax[1].set_ylim(1, 300)
    ax[0].set_ylabel('Clump Size', fontsize=fs)
    ax[1].set_ylabel('Number of Clumps', fontsize=fs)

    cpal = pt.get_color_list(compare)

    # run_tctf deliberately does not reuse the name tctf: the parameter is
    # still needed for the figure name below.
    for i, run_tctf in enumerate(tctf_list):
        time_list, clump_data = pt.get_time_data('clump', sim, run_tctf, beta_list[i], cr_list[i],
                                                 use_mpi=True, diff=diff_list[i],
                                                 stream=stream_list[i], heat=heat_list[i],
                                                 field=field, zstart=zstart, zend=zend,
                                                 grid_rank=grid_rank, load=load, save=save,
                                                 work_dir=work_dir, sim_fam=sim_fam)
        print(len(clump_data), len(clump_data[0]))
        # Assumes clump_data holds exactly three per-output series - confirm in pt.get_time_data.
        n_clumps, clump_size, clump_std = clump_data
        print(n_clumps, clump_size)

        label = pt.get_label_name(compare, run_tctf, beta_list[i], cr_list[i],
                                  crdiff=diff_list[i], crstream=stream_list[i],
                                  crheat=heat_list[i], counter=i)

        scaled_time = time_list / run_tctf
        ax[0].plot(scaled_time,
                   clump_size,
                   linewidth=3,
                   linestyle='solid',
                   label=label,
                   color=cpal[i])
        ax[1].plot(scaled_time,
                   n_clumps,
                   linewidth=3,
                   linestyle='solid',
                   color=cpal[i])

    ax[0].legend()
    fig.tight_layout()
    figname = pt.get_fig_name('clump_growth', sim, compare,
                              tctf, beta, cr, diff, sim_fam=sim_fam)
    fig.savefig(figname, dpi=300)
def plot_density_fluctuation(output,
                             sim,
                             compare,
                             tctf,
                             beta,
                             cr,
                             diff=0,
                             stream=0,
                             heat=0,
                             T_cold=3.3333333e5,
                             zstart=0.8,
                             zend=1.2,
                             relative=0,
                             work_dir='../../simulations/',
                             grid_rank=3):
    """Scatter density fluctuation, cold mass fraction and cold mass flux against cold-gas P_c / P_g.

    The loop covers every CR normalisation in [0.01, 0.1, 1, 10], every run
    from ``pt.generate_lists(..., cr_only=1)`` and every t_cool / t_ff in
    [0.1, 0.3, 1, 3].  For each combination, the three diagnostics and
    'cold_creta' are averaged over the samples at indices ``output - 10`` ..
    ``output + 9`` (NaNs counted as 0).  The standard deviations are drawn as
    x/y error bars, and points are coloured by t_cool / t_ff.

    Runs with no data are drawn at (0, 0), which a log axis cannot show.

    NOTE(review): ``relative`` only changes axis limits and labels - the
    plotted values are never divided by a no-CR reference.  Confirm intent.

    Reads the module-level ``load``, ``save`` and ``sim_fam`` defined elsewhere.
    """
    tctf_list, beta_list, cr_list, diff_list, stream_list, heat_list \
        = pt.generate_lists(compare, tctf, crdiff=diff, cr=cr, beta=beta, cr_only=1)
    # Fixed sweep values take the place of what generate_lists returned.
    tctf_values = [0.1, 0.3, 1, 3]
    all_cr_list = [0.01, 0.1, 1, 10]
    print(tctf_values, beta_list, cr_list, diff_list, stream_list, heat_list)

    fig, ax = plt.subplots(nrows=1,
                           ncols=3,
                           figsize=(4.4 * 3, 4),
                           sharex=True,
                           sharey=False)
    for col in range(3):
        ax[col].set_xscale('log')
        ax[col].set_yscale('log')
        ax[col].set_xlim(5e-3, 5e2)
        ax[col].set_xlabel('$ P_c / P_g $')

    if relative == 0:
        ax[0].set_ylim(1e-2, 5)
        ax[1].set_ylim(5e-3, 4)
        ax[2].set_ylim(5e-3, 4)
        ax[0].set_ylabel('Density Fluctuation')
        ax[1].set_ylabel('Cold Mass Fraction')
        ax[2].set_ylabel('Cold Mass Flux')
    else:
        ax[0].set_ylim(1e-2, 10)
        ax[1].set_ylim(1e-2, 100)
        ax[2].set_ylim(1e-2, 10)
        ax[0].set_ylabel('Relative Density Fluctuation')
        ax[1].set_ylabel('Relative Cold Mass Fraction')
        ax[2].set_ylabel('Relative Cold Mass Flux')

    color_list = pt.get_color_list('tctf')
    marker = 'o'
    plot_types = ['density_fluctuation', 'cold_fraction', 'cold_flux']

    def window(series):
        # NaN-zeroed samples output-10 .. output+9; the start is clamped so that
        # output < 10 does not wrap around to the end of the series.
        return np.nan_to_num(series[max(output - 10, 0):output + 10])

    def load_series(plot_type, run_tctf, i, run_cr):
        # Fetch (time, value) arrays for run i at the given t_cool/t_ff and CR level.
        return pt.get_time_data(plot_type, sim, run_tctf, beta_list[i], run_cr,
                                diff=diff_list[i], stream=stream_list[i],
                                heat=heat_list[i], T_min=T_cold, zstart=zstart,
                                zend=zend, grid_rank=grid_rank, load=load,
                                save=save, work_dir=work_dir, sim_fam=sim_fam)

    for acr in all_cr_list:
        for i in range(len(beta_list)):
            for j, run_tctf in enumerate(tctf_values):
                # P_c / P_g does not depend on the diagnostic, so load it once per run.
                creta_time_list, creta_list = load_series('cold_creta', run_tctf, i, acr)
                for col, plot_type in enumerate(plot_types):
                    time_list, data_list = load_series(plot_type, run_tctf, i, acr)
                    data = 0
                    creta = 0
                    if len(data_list) > 0 and len(creta_list) > 0:
                        data = window(data_list)
                        creta = window(creta_list)

                    # Label each t_cool/t_ff colour exactly once, in the first panel.
                    if acr == 0.01 and i == 0 and col == 0:
                        label = pt.get_label_name('tctf', run_tctf, beta_list[i], acr,
                                                  crdiff=diff_list[i], crstream=stream_list[i],
                                                  crheat=heat_list[i], counter=i)
                    else:
                        label = None
                    ax[col].scatter(np.mean(creta),
                                    np.mean(data),
                                    color=color_list[j],
                                    label=label,
                                    marker=marker)
                    ax[col].errorbar(np.mean(creta),
                                     np.mean(data),
                                     xerr=np.std(creta),
                                     yerr=np.std(data),
                                     color=color_list[j])

    ax[0].legend(fontsize=8)
    fig.tight_layout()
    fig_basename = 'dens_cfrac_cflux_creta'
    figname = pt.get_fig_name(fig_basename, sim, compare,
                              tctf, beta, cr, crdiff=diff, crstream=stream,
                              crheat=heat, time=output, sim_fam=sim_fam,
                              loc='../../plots')
    print(figname)
    fig.savefig(figname, dpi=300)