def make_base_plot(ds, field, width, depth, cm, zlim=(0.1, 10**3),
                   unit="msun/pc**2"):
    """Build an off-axis projection plot of ``field`` for dataset ``ds``.

    The centre, line of sight and "up" direction come from
    ``plot_basics(ds)``. To speed up the projection, only a disk-shaped
    selection is used. That selection is larger than the projected region
    (twice the width as radius, twice the depth as height) so nothing
    inside the requested volume is missed.

    Parameters
    ----------
    ds : yt dataset
        Dataset to project.
    field : str or tuple
        Field to project.
    width, depth
        Image width and projection depth, in any form accepted by
        ``parse_lengths`` (defined elsewhere in this module).
    cm : str
        Colormap name passed to ``set_cmap``.
    zlim : tuple of (float, float), optional
        Colorbar limits, interpreted in ``unit``. Defaults to (0.1, 10**3).
    unit : str, optional
        Display unit of the projected field. Defaults to "msun/pc**2".

    Returns
    -------
    yt.OffAxisProjectionPlot
        The configured (unsaved) plot.
    """
    width = parse_lengths(ds, width)
    depth = parse_lengths(ds, depth)
    normal, north, center = plot_basics(ds)
    # Make a smaller data object to hopefully speed up computation. We make
    # it bigger than the projection region to ensure nothing is missed.
    ds_select = ds.disk(center=center,
                        normal=normal,
                        radius=2 * width,
                        height=2 * depth)

    plot = yt.OffAxisProjectionPlot(ds,
                                    fields=field,
                                    center=center,
                                    normal=normal,
                                    north_vector=north,
                                    width=width,
                                    depth=depth,
                                    data_source=ds_select)
    plot.set_cmap(field, cm)
    # Set the display unit BEFORE the colorbar limits. yt (4.1+) converts
    # bare-float limits using the display unit current at call time. Changing
    # the unit afterwards would therefore rescale the limits, e.g. 0.1 g/cm**2
    # is roughly 480 Msun/pc**2.
    plot.set_unit(field, unit)
    plot.set_zlim(field, zlim[0], zlim[1])
    plot.annotate_timestamp(corner='upper_left',
                            redshift=True,
                            draw_inset_box=True)
    return plot
Esempio n. 2
0
def make_off_axis_projection(ds, vec, north_vec, ion_fields, center, width, data_source, radius, weight_field=None, dir=None):
    """
    Render and save an off-axis projection of every field in ``ion_fields``.

    Built with ``yt.OffAxisProjectionPlot``, so the output resolution cannot
    be chosen here.

    The plot has:

    * hidden axes, a scale bar and a redshift timestamp;
    * a dashed, semi-transparent white circle of ``radius`` around ``center``.

    Each field is drawn with the 'dusk' colormap, styled through
    ``set_image_details`` and given ``set_background_color``'s default
    background.

    Parameters
    ----------
    vec, north_vec
        Line-of-sight vector and north (up) vector of the projection.
    radius
        Quantity that supports ``.in_units('kpc')``; radius of the circle.
    weight_field : optional
        Passed straight through to ``OffAxisProjectionPlot``.
    dir : str, optional
        Path appended to ``'images'`` when saving; ``'face/'`` if omitted.

    Returns
    -------
    The plot's fixed-resolution buffer (``.frb``).
    """
    projection = yt.OffAxisProjectionPlot(
        ds, vec, ion_fields,
        center=center,
        width=width,
        data_source=data_source,
        north_vector=north_vec,
        weight_field=weight_field,
    )
    projection.hide_axes()
    projection.annotate_scale()
    projection.annotate_timestamp(redshift=True)

    radius_kpc = radius.in_units('kpc')
    ring_style = dict(color='white', alpha=0.5, linestyle='dashed', linewidth=5)
    projection.annotate_sphere(center, (radius_kpc, 'kpc'), circle_args=ring_style)

    for ion in ion_fields:
        projection.set_cmap(ion, 'dusk')
        set_image_details(projection, ion, True)
        projection.set_background_color(ion)

    subdir = 'face/' if dir is None else dir
    projection.save(os.path.join('images', subdir))
    return projection.frb
Esempio n. 3
0
def main():
    """Project the requested field along +z around the origin and save the plots.

    Steps:

    1. Read the plot-file path from ``sys.argv[1]`` and derive the matching
       particle file name from it.
    2. Load both files and parse the remaining options with ``parse_inputs()``.
    3. Print the centre/particle diagnostics exposed by ``myf``.
    4. Build an ``OffAxisProjectionPlot`` of ``args.field`` together with the
       mass-weighted velocity/magnetic helper fields, then save it.
    """
    plt_file = sys.argv[1]
    # Assumes FLASH-style names ending in 'plt_cnt_NNNN', e.g.
    # 'X_hdf5_plt_cnt_0042' -> 'X_hdf5_part_0042'.
    part_file = plt_file[:-12] + 'part' + plt_file[-5:]
    ds = yt.load(plt_file, particle_filename=part_file)
    L = [0.0, 0.0, 1.0]
    myf.set_normal(L)

    args = parse_inputs()
    field = args.field
    dd = ds.all_data()
    # NOTE(review): these values are never used directly. The derived fields
    # are presumably evaluated for their side effect of populating the `myf`
    # state printed below -- confirm before removing.
    dd['Center_Velocity']
    dd['Center_Position']
    dd['All_Particle_Positions']
    dd['All_Particle_Masses']
    print("center vel =", myf.get_center_vel())
    print("center pos for field calculations", myf.get_center_pos())
    print("particle positions are:", myf.get_part_pos())
    print("particle masses are:", myf.get_part_mass())

    center_pos = np.array([0.0, 0.0, 0.0])
    # Slab thickness: the projected (integrated) values are divided by it
    # below. It was previously referenced without being defined (NameError).
    thickness = yt.YTArray(args.slice_thickness, 'AU')
    thickness_cm = thickness.in_units('cm')

    # Request exactly the fields that are read back below. The previous list
    # (velx_mw, vely_mw, ...) did not match the keys read from the buffer.
    proj = yt.OffAxisProjectionPlot(
        ds,
        L,
        [field, 'cell_mass', 'velz_mw', 'magz_mw',
         'Projected_Magnetic_Field_mw', 'Projected_Velocity_mw'],
        center=(center_pos, 'AU'),
        width=(args.ax_lim, 'AU'),
        depth=(args.slice_thickness, 'AU'))
    # Item access on the FRB returns cached data or computes it on demand,
    # unlike indexing the raw `.data` dict.
    frb = proj.frb
    # NOTE(review): the arrays below are not used further in this function.
    image = (frb[field] / thickness_cm).value
    velx_full = (frb[('gas', 'Projected_Velocity_mw')].in_units('g*cm**2/s')
                 / thickness_cm).value
    vely_full = (frb[('gas', 'velz_mw')].in_units('g*cm**2/s')
                 / thickness_cm).value
    magx = (frb[('gas', 'Projected_Magnetic_Field_mw')].in_units('g*gauss*cm')
            / thickness_cm).value
    magy = (frb[('gas', 'magz_mw')].in_units('g*gauss*cm')
            / thickness_cm).value
    mass = (frb[('gas', 'cell_mass')].in_units('cm*g')
            / thickness_cm).value
    proj.save()
def xray_3color(proj_angle=30.):
    """Render a three-band (RGB) X-ray emissivity composite of a FLASH snapshot.

    Steps:

    1. Add one X-ray emissivity field per energy band
       (0.3-1.5, 1.5-3.5 and 3.5-7.0 keV) to a hard-coded snapshot.
    2. Project all three bands off-axis in a single plot.
    3. Stack the bands as red/green/blue channels, normalised to the global
       maximum.
    4. Save the composite as 'xray_apec002_3color<angle>_lv<level>.png'.

    Parameters
    ----------
    proj_angle : float
        Line-of-sight tilt in degrees. The projection normal is
        (0, tan(proj_angle), 1), i.e. the z axis tilted towards +y.
    """
    data_dir = '/home/ychen/d9/FLASH4/stampede/0529_L45_M10_b1_h1/'
    fname = data_dir + 'MHD_Jet_hdf5_plt_cnt_0620'

    energy_ranges = [(0.3, 1.5), (1.5, 3.5), (3.5, 7.0)]
    fields = ['xray_emissivity_%.1f_%.1f_keV' % er for er in energy_ranges]

    ds = yt.load(fname)

    for energy_range in energy_ranges:
        # Function-call form of print: the Python 2 statement form was a
        # SyntaxError under Python 3 and behaves the same under Python 2.
        print(add_xray_emissivity_field(
            ds, *energy_range,
            with_metals=False, constant_metallicity=0.02,
            filename='/d/d9/ychen/yt_testdata/apec_emissivity.h5'))

    normal = (0.0, np.tan(proj_angle / 180. * np.pi), 1.0)
    north = (0.0, 0.1, 1.0)
    max_lv = 6

    plot = yt.OffAxisProjectionPlot(ds, normal, fields,
                                    width=((40, 'kpc'), (80, 'kpc')),
                                    depth=(80, 'kpc'), max_level=max_lv,
                                    north_vector=north)

    ext = ds.arr([-20, 20, -40, 40], input_units='kpc')

    # One channel per band, in `fields` order: soft -> red, hard -> blue.
    # get_array() is the public accessor for the image data (was `._A`).
    channels = [plot[f].image.get_array() for f in fields]
    image = np.array(channels).transpose([1, 2, 0])
    image = image / image.max()
    fig = plt.figure(figsize=(10, 20))
    plt.imshow(image, origin='lower', extent=ext)
    plt.xlabel('(kpc)')
    plt.savefig('xray_apec002_3color%i_lv%i.png' %
                (int(proj_angle), int(max_lv)))
    # Release the figure so repeated calls (e.g. over several angles) don't
    # accumulate open figures.
    plt.close(fig)
Esempio n. 5
0
         velx_full = proj_dict[proj_dict_keys[1]]
         vely_full = proj_dict[proj_dict_keys[2]]
         velz_full = proj_dict[proj_dict_keys[3]]
         magx = proj_dict[proj_dict_keys[4]]
         magy = proj_dict[proj_dict_keys[5]]
             
 elif args.use_angular_momentum != 'False':
     proj_root_rank = int(rank/7)*7
     #proj_dict = {simfo['field'][1]:[]}
     proj_dict = {simfo['field'][1]:[], 'Projected_Velocity_x':[], 'Projected_Velocity_y':[], 'Projected_Velocity_z':[], 'Projected_Magnetic_Field_x':[], 'Projected_Magnetic_Field_y':[], 'Projected_Magnetic_Field_z':[]}
     proj_dict_keys = str(proj_dict.keys()).split("['")[1].split("']")[0].split("', '")
     #proj_field_list =[simfo['field']]
     proj_field_list =[simfo['field'], ('gas', 'Projected_Velocity_x'), ('gas', 'Projected_Velocity_y'), ('gas', 'Projected_Velocity_z'), ('gas', 'Projected_Magnetic_Field_x'), ('gas', 'Projected_Magnetic_Field_y'), ('gas', 'Projected_Magnetic_Field_z')]
     
     for field in yt.parallel_objects(proj_field_list):
         proj = yt.OffAxisProjectionPlot(ds, L, field, width=(x_width/2, 'AU'), weight_field=weight_field, method='integrate', center=(center_pos, 'AU'), depth=(args.slice_thickness, 'AU'))
         if 'mag' in str(field):
             if weight_field == None:
                 proj_array = np.array(proj.frb.data[field].in_units('cm*gauss')/thickness.in_units('cm'))
             else:
                 proj_array = np.array(proj.frb.data[field].in_units('gauss'))
         else:
            if weight_field == None:
                 proj_array = np.array(proj.frb.data[field].in_cgs()/thickness.in_units('cm'))
            else:
                proj_array = np.array(proj.frb.data[field].in_cgs())
         if rank == proj_root_rank:
             proj_dict[field[1]] = proj_array
         else:
             file = open(pickle_file.split('.pkl')[0] + '_proj_data_' + str(proj_root_rank)+ str(proj_dict_keys.index(field[1])) + '.pkl', 'wb')
             pickle.dump((field[1], proj_array), file)
# Script section: make an off-axis projection of `field` along L = [0, 0, 1]
# for `ds` and save it as 'off_axis_yt.png'.
# NOTE(review): `ds` is not defined in this section; presumably a yt dataset
# loaded earlier in the script -- confirm.

# To list every field available from the hydro (plt) and particle (part)
# files, uncomment some of the following (`ad` is presumably
# `ds.all_data()`; it is not defined here):
# for e in ds.derived_field_list:
#     print(e)
# ds.print_stats()
# print(ds.derived_field_list)
# print("-"*20 + "\n", ad['x'])
# print(len(ad['x']))
# print(len(ad['temperature']))

# Field to project (an off-axis projection, not a slice -- see below).
field = 'density'  # dens, temp, and pres are some shorthand strings recognized by yt.
# ax = 'y' # the axis our slice plot will be "looking down on".

L = [0, 0, 1]  # projection direction (line of sight), here the z axis

plot_ = yt.OffAxisProjectionPlot(ds, L, field)
im_name = 'off_axis_yt.png'
plot_.save(im_name)
# Disabled alternatives (on-axis slice / projection) and extra annotations:
# plot_ = yt.SlicePlot(ds, ax, field)
# plot_ = yt.ProjectionPlot(ds, ax, field)
# plot_.set_cmap(field, "binary")
#
# plot_.annotate_timestamp()
# plot_.annotate_grids()
#
# plot_.annotate_title('Testplot')
# plot_.annotate_scale()

#plot_.annotate_magnetic_field()

#plot_.show()
Esempio n. 7
0
	width = 2*radial_extent

	log("Reading amiga center for halo in %s" % fn)
	c = read_amiga_center(amiga_data, fn, ds)
	rvir = read_amiga_rvir(amiga_data, fn, ds)

	log('Finding Angular Momentum of Galaxy')
	sp = ds.sphere(c, (15, 'kpc'))
	L = sp.quantities.angular_momentum_vector(use_gas=False, use_particles=True, particle_type='PartType0')
	L, E1, E2 = ortho_find(L)

	one = ds.arr([.5, .5, .5], 'Mpc')
	box = ds.box(c-one, c+one)

	log('Generating Plot 1')
	p1 = yt.OffAxisProjectionPlot(ds, E1, ('gas', 'H_number_density'), center=c, 
		width=width, data_source=box, north_vector=L, weight_field=None)
	log('Generating Plot 2')
	p2 = yt.OffAxisProjectionPlot(ds, E1, ('gas', 'Mg_p1_number_density'), center=c, 
		width=width, data_source=box, north_vector=L, weight_field=None)
	log('Generating Plot 3')
	p3 = yt.OffAxisProjectionPlot(ds, E1, ('gas', 'O_p5_number_density'), center=c, 
		width=width, data_source=box, north_vector=L, weight_field=None)
	log('Generating Plot 4')
	p4 = yt.OffAxisProjectionPlot(ds, E1, ('gas', 'Si_p3_number_density'), center=c,
		width=width, data_source=box, north_vector=L, weight_field=None)

	print("Stitching together")
	fig = GridFigure(2, 2, top_buffer=0.01, bottom_buffer=0.01, left_buffer=0.01, right_buffer=0.13, vertical_buffer=0.01, horizontal_buffer=0.01, figsize=(9,8))

	# Actually plot in the different axes
	plot1 = fig[0].imshow(p1.frb['H_number_density'], norm=LogNorm())
Esempio n. 8
0
def main():
    rank = CW.Get_rank()
    size = CW.Get_size()
    args = parse_inputs()
    prev_args = args
    print("Starting mosaic_mod_script on rank", rank)

    # Read in directories:
    input_file = args.input_file
    #save_dir = args.save_directory
    #if os.path.exists(save_dir) == False:
    #    os.makedirs(save_dir)

    # Read in input file
    print("Reading in input mosaic file on rank", rank)
    positions = []
    paths = []
    args_dict = []
    with open(input_file, 'rU') as mosaic_file:
        reader = csv.reader(mosaic_file)
        for row in reader:
            if row[0] == 'Grid_inputs:':
                glr = float(row[1])
                grl = float(row[2])
                glw = float(row[3])
                ghspace = float(row[4])
            elif row[0][0] != '#':
                positions.append((int(row[0]), int(row[1])))
                paths.append(row[2])
                dict = ""
                for col in row[3:]:
                    dict = dict + col
                    if col != row[-1]:
                        dict = dict + ','
                dict = ast.literal_eval(dict)
                args_temp = argparse.Namespace(**vars(args))
                for key in list(dict.keys()):
                    if key in args:
                        exec("args_temp."+ key + " = " + "str(dict[key])")
                args_dict.append(args_temp)
                del args_temp
                args = prev_args
                
    import pdb
    pdb.set_trace()

    positions = np.array(positions)

    c = define_constants()
    mym.set_global_font_size(args.text_font)
    files = []
    simfo = []
    X = []
    Y = []
    X_vel = []
    Y_vel = []
    sim_files = []
    L = None
    for pit in range(len(paths)):
        fs = get_files(paths[pit], args_dict[pit])
        files.append(fs)

        #print "paths =", paths
        #print "fs =", fs
        #print "args_dict =", args_dict
        sfo = sim_info(paths[pit], fs[-1], args_dict[pit])
        simfo.append(sfo)

        if args_dict[pit].yt_proj == False:
            x, y, x_vel, y_vel, cl = mym.initialise_grid(files[pit][-1], zoom_times=args_dict[pit].zoom_times)
            X.append(x)
            Y.append(y)
            X_vel.append(x_vel)
            Y_vel.append(y_vel)
        else:
            x = np.linspace(sfo['xmin'], sfo['xmax'], sfo['dimension'])
            y = np.linspace(sfo['ymin'], sfo['ymax'], sfo['dimension'])
            x, y  = np.meshgrid(x, y)
            
            annotate_space = (simfo[pit]['xmax'] - simfo[pit]['xmin'])/31.
            x_ind = []
            y_ind = []
            counter = 0
            while counter < 31:
                val = annotate_space*counter + annotate_space/2. + simfo[pit]['xmin']
                x_ind.append(int(val))
                y_ind.append(int(val))
                counter = counter + 1
            x_vel, y_vel = np.meshgrid(x_ind, y_ind)
            if args_dict[pit].projection_orientation != None:
                y_val = 1./np.tan(np.deg2rad(float(args_dict[pit].projection_orientation)))
                if np.isinf(y_val):
                    y_val = 0.0
                L = [1.0, y_val, 0.0]
            else:
                if has_particles == False or len(dd['particle_posx']) == 1:
                    L = [0.0, 1.0, 0.0]
                else:
                    pos_vec = [np.diff(dd['particle_posx'].value)[0], np.diff(dd['particle_posy'].value)[0]]
                    L = [-1*pos_vec[-1], pos_vec[0]]
                    L.append(0.0)
                    if L[0] > 0.0:
                        L = [-1.0*L[0], -1.0*L[1], 0.0]
            print("SET PROJECTION ORIENTATION L=", L)
            L = np.array(L)
            X.append(x)
            Y.append(y)
            X_vel.append(x_vel)
            Y_vel.append(y_vel)
        if rank == 0:
            print("shape of x, y", np.shape(x), np.shape(y))

        if args_dict[pit].yt_proj == False and args_dict[pit].image_center != 0:
            sim_fs = sorted(glob.glob(paths[pit] + 'WIND_hdf5_plt_cnt*'))
        elif args_dict[pit].yt_proj != False and args_dict[pit].image_center != 0:
            sim_fs = files
        else:
            sim_fs = []
        sim_files.append(sim_fs)
    #myf.set_normal(L)
    #print "SET PROJECTION ORIENTATION L=", myf.get_normal()

    # Initialise Grid and build lists
    if args.plot_time != None:
        m_times = [args.plot_time]
    else:
        m_times = mym.generate_frame_times(files[0], args.time_step, presink_frames=args.presink_frames, end_time=args.end_time)
    no_frames = len(m_times)
    m_times = m_times[args.start_frame:]
    sys.stdout.flush()
    CW.Barrier()

    usable_files = []
    usable_sim_files = []
    for pit in range(len(paths)):
        usable_fs = mym.find_files(m_times, files[pit])
        usable_files.append(usable_fs)
        if args_dict[pit].image_center != 0 and args_dict[pit].yt_proj == False:
            usable_sfs = mym.find_files(m_times, sim_files[pit])
            usable_sim_files.append(usable_fs)
            del sim_files[pit]
        else:
            usable_sim_files.append([])
    sys.stdout.flush()
    CW.Barrier()
    frames = list(range(args.start_frame, no_frames))

    sink_form_time = []
    for pit in range(len(paths)):
        sink_form = mym.find_sink_formation_time(files[pit])
        print("sink_form_time", sink_form_time)
        sink_form_time.append(sink_form)
    del files

    # Define colourbar bounds
    cbar_max = args.colourbar_max
    cbar_min = args.colourbar_min

    if L is None:
        if args.axis == 'xy':
            L = [0.0, 0.0, 1.0]
        else:
            L = [1.0, 0.0, 0.0]
        L = np.array(L)
    if args.axis == 'xy':
        y_int = 1
    else:
        y_int = 2

    sys.stdout.flush()
    CW.Barrier()
    rit = args.working_rank
    for frame_val in range(len(frames)):
        if rank == rit:
            time_val = m_times[frame_val]
            plt.clf()
            columns = np.max(positions[:,0])
            rows = np.max(positions[:,1])

            width = float(columns)*(14.5/3.)
            height = float(rows)*(17./4.)
            fig =plt.figure(figsize=(width, height))
            
            gs_left = gridspec.GridSpec(rows, columns-1)
            gs_right = gridspec.GridSpec(rows, 1)

            gs_left.update(right=glr, wspace=glw, hspace=ghspace)
            gs_right.update(left=grl, hspace=ghspace)
            
            axes_dict = {}
            counter = 1

            for pit in range(len(paths)):
                
                try:
                    title_parts = args_dict[pit].title
                except:
                    title_parts = args_dict[pit]['title']
                title = ''
                for part in title_parts:
                    if part != title_parts[-1]:
                        title = title + part + ' '
                    else:
                        title = title + part
            
                ax_label = 'ax' + str(counter)
                yit = np.where(positions[:,1] == positions[pit][1])[0][0]
                if positions[pit][0] == 1 and positions[pit][1] == 1:
                    if columns > 1:
                        axes_dict.update({ax_label:fig.add_subplot(gs_left[0,0])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    else:
                        axes_dict.update({ax_label:fig.add_subplot(gs_right[0,0])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                elif positions[pit][0] != columns:
                    if args.share_x and args.share_y:
                        if yit >= len(axes_dict):
                            axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[pit][1]-1,positions[pit][0]-1], sharex=axes_dict['ax1'])})
                            #print "ADDED SUBPLOT:", counter, "on rank", rank
                        else:
                            axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[pit][1]-1,positions[pit][0]-1], sharex=axes_dict['ax1'], sharey=axes_dict[list(axes_dict.keys())[yit]])})
                            #print "ADDED SUBPLOT:", counter, "on rank", rank
                    elif args.share_x:
                        axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[it][1]-1,positions[pit][0]-1], sharex=axes_dict['ax1'])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    elif args.share_y and positions[pit][0]!=1:
                        yit = np.where(positions[:,1] == positions[pit][1])[0][0]
                        axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[pit][1]-1,positions[pit][0]-1], sharey=axes_dict[list(axes_dict.keys())[yit]])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    elif args.share_y:
                        axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[pit][1]-1,positions[pit][0]-1])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    else:
                        axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[pit][1]-1,positions[pit][0]-1])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                else:
                    if args.share_x and args.share_y:
                        yit = np.where(positions[:,1] == positions[pit][1])[0][0]
                        axes_dict.update({ax_label:fig.add_subplot(gs_right[positions[pit][1]-1,0], sharex=axes_dict['ax1'], sharey=axes_dict[list(axes_dict.keys())[yit]])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    elif args.share_x:
                        axes_dict.update({ax_label:fig.add_subplot(gs_right[positions[pit][1]-1,0], sharex=axes_dict['ax1'])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    elif args.share_y:
                        yit = np.where(positions[:,1] == positions[pit][1])[0][0]
                        axes_dict.update({ax_label:fig.add_subplot(gs_right[positions[pit][1]-1,0], sharey=axes_dict[list(axes_dict.keys())[yit]])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    else:
                        axes_dict.update({ax_label:fig.add_subplot(gs_right[positions[pit][1]-1,0])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank

                counter = counter + 1
                axes_dict[ax_label].set(adjustable='box-forced', aspect='equal')
                

                if args.yt_proj and args.plot_time==None and os.path.isfile(paths[pit] + "movie_frame_" + ("%06d" % frames[frame_val]) + ".pkl"):
                    pickle_file = paths[pit] + "movie_frame_" + ("%06d" % frames[frame_val]) + ".pkl"
                    print("USING PICKLED FILE:", pickle_file)
                    file = open(pickle_file, 'r')
                    #weight_fieldstuff = pickle.load(file)
                    X[pit], Y[pit], image, magx, magy, X_vel[pit], Y_vel[pit], velx, vely, part_info, args_dict[pit], simfo[pit] = pickle.load(file)

                    #file_time = stuff[17]
                    file.close()

                else:
                    time_val = m_times[frame_val]
                    print("FILE =", usable_files[pit][frame_val])
                    has_particles = has_sinks(usable_files[pit][frame_val])
                    if has_particles:
                        part_info = mym.get_particle_data(usable_files[pit][frame_val], args_dict[pit].axis, proj_or=L)
                    else:
                        part_info = {}
                    center_vel = [0.0, 0.0, 0.0]
                    if args.image_center != 0 and has_particles:
                        original_positions = [X[pit], Y[pit], X_vel[pit], y_vel[pit]]
                        x_pos = np.round(part_info['particle_position'][0][args.image_center - 1]/cl)*cl
                        y_pos = np.round(part_info['particle_position'][1][args.image_center - 1]/cl)*cl
                        pos = np.array([part_info['particle_position'][0][args.image_center - 1], part_info['particle_position'][1][args.image_center - 1]])
                        X[pit] = X[pit] + x_pos
                        Y[pit] = Y[pit] + y_pos
                        X_vel[pit] = X_vel[pit] + x_pos
                        Y_vel[pit] = Y_vel[pit] + y_pos
                        if args.yt_proj == False:
                            sim_file = usable_sim_files[frame_val][:-12] + 'part' + usable_sim_files[frame_val][-5:]
                        else:
                            sim_file = part_file
                        if len(part_info['particle_mass']) == 1:
                            part_ind = 0
                        else:
                            min_dist = 1000.0
                            for part in range(len(part_info['particle_mass'])):
                                f = h5py.File(sim_file, 'r')
                                temp_pos = np.array([f[list(f.keys())[11]][part][13]/c['au'], f[list(f.keys())[11]][part][13+y_int]/c['au']])
                                f.close()
                                dist = np.sqrt(np.abs(np.diff((temp_pos - pos)**2)))[0]
                                if dist < min_dist:
                                    min_dist = dist
                                    part_ind = part
                        f = h5py.File(sim_file, 'r')
                        center_vel = [f[list(f.keys())[11]][part_ind][18], f[list(f.keys())[11]][part_ind][19], f[list(f.keys())[11]][part_ind][20]]
                        f.close()
                    xabel, yabel, xlim, ylim = image_properties(X[pit], Y[pit], args_dict[pit], simfo[pit])
                    if args_dict[pit].axis == 'xy':
                        center_vel=center_vel[:2]
                    else:
                        center_vel=center_vel[::2]
                    
                    if args_dict[pit].ax_lim != None:
                        if has_particles and args_dict[pit].image_center != 0:
                            xlim = [-1*args_dict[pit].ax_lim + part_info['particle_position'][0][args_dict[pit].image_center - 1], args_dict[pit].ax_lim + part_info['particle_position'][0][args_dict[pit].image_center - 1]]
                            ylim = [-1*args_dict[pit].ax_lim + part_info['particle_position'][1][args_dict[pit].image_center - 1], args_dict[pit].ax_lim + part_info['particle_position'][1][args_dict[pit].image_center - 1]]
                        else:
                            xlim = [-1*args_dict[pit].ax_lim, args_dict[pit].ax_lim]
                            ylim = [-1*args_dict[pit].ax_lim, args_dict[pit].ax_lim]

                    if args.yt_proj == False:
                        f = h5py.File(usable_files[pit][frame_val], 'r')
                        image = get_image_arrays(f, simfo[pit]['field'], simfo[pit], args_dict[pit], X[pit], Y[pit])
                        magx = get_image_arrays(f, 'mag'+args.axis[0]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis, simfo[pit], args_dict[pit], X[pit], Y[pit])
                        magy = get_image_arrays(f, 'mag'+args.axis[1]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis, simfo[pit], args_dict[pit], X[pit], Y[pit])
                        x_pos_min = int(np.round(np.min(X[pit]) - simfo[pit]['xmin_full'])/simfo[pit]['cell_length'])
                        y_pos_min = int(np.round(np.min(Y[pit]) - simfo[pit]['xmin_full'])/simfo[pit]['cell_length'])
                        if np.shape(f['vel'+args.axis[0]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis]) == (2048, 2048):
                            velocity_data = [f['vel'+args.axis[0]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis], f['vel'+args.axis[1]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis]]
                        elif args.axis == 'xy':
                            velocity_data = [f['vel'+args.axis[0]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis][:,:,0], f['vel'+args.axis[1]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis][:,:,0]]
                        else:
                            velocity_data = [f['vel'+args.axis[0]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis][:,0,:], f['vel'+args.axis[1]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis][:,0,:]]
                        velx, vely = mym.get_quiver_arrays(y_pos_min, x_pos_min, X[pit], velocity_data[0], velocity_data[1], center_vel=center_vel)
                    else:
                        if args_dict[pit].image_center == 0 or has_particles == False:
                            center_pos = np.array([0.0, 0.0, 0.0])
                        else:
                            dd = f.all_data()
                            center_pos = np.array([dd['particle_posx'][args.image_center-1].in_units('AU'), dd['particle_posy'][args.image_center-1].in_units('AU'), dd['particle_posz'][args.image_center-1].in_units('AU')])
                        x_width = (xlim[1] -xlim[0])
                        y_width = (ylim[1] -ylim[0])
                        thickness = yt.YTArray(args.slice_thickness, 'AU')
                        
                        proj = yt.OffAxisProjectionPlot(f, L, [simfo[pit]['field'], 'cell_mass', 'velz_mw', 'magz_mw', 'Projected_Magnetic_Field_mw', 'Projected_Velocity_mw'], center=(center_pos, 'AU'), width=(x_width, 'AU'), depth=(args.slice_thickness, 'AU'))
                        image = (proj.frb.data[simfo[pit]['field']]/thickness.in_units('cm')).value
                        velx_full = (proj.frb.data[('gas', 'Projected_Velocity_mw')].in_units('g*cm**2/s')/thickness.in_units('cm')).value
                        vely_full = (proj.frb.data[('gas', 'velz_mw')].in_units('g*cm**2/s')/thickness.in_units('cm')).value
                        magx = (proj.frb.data[('gas', 'Projected_Magnetic_Field_mw')].in_units('g*gauss*cm')/thickness.in_units('cm')).value
                        magy = (proj.frb.data[('gas', 'magz_mw')].in_units('g*gauss*cm')/thickness.in_units('cm')).value
                        mass = (proj.frb.data[('gas', 'cell_mass')].in_units('cm*g')/thickness.in_units('cm')).value
                        
                        velx_full = velx_full/mass
                        vely_full = vely_full/mass
                        magx = magx/mass
                        magy = magy/mass
                        del mass

                        velx, vely = mym.get_quiver_arrays(0.0, 0.0, X[pit], velx_full, vely_full, center_vel=center_vel)
                        del velx_full
                        del vely_full

                        if len(frames) == 1:
                            if rank == 0:
                                pickle_file = paths[pit] + "movie_frame_" + ("%06d" % frames[frame_val]) + ".pkl"
                                file = open(pickle_file, 'w+')
                                pickle.dump((X[pit], Y[pit], image, magx, magy, X_vel[pit], Y_vel[pit], velx, vely, xlim, ylim, has_particles, part_info, simfo[pit], time_val,xabel, yabel), file)
                                file.close()
                                print("Created Pickle:", pickle_file, "for  file:", usable_files[pit][frame_val])
                        else:
                            pickle_file = paths[pit] + "movie_frame_" + ("%06d" % frames[frame_val]) + ".pkl"
                            file = open(pickle_file, 'w+')
                            pickle.dump((X[pit], Y[pit], image, magx, magy, X_vel[pit], Y_vel[pit], velx, vely, xlim, ylim, has_particles, part_info, simfo[pit], time_val,xabel, yabel), file)
                            file.close()
                            print("Created Pickle:", pickle_file, "for  file:", usable_files[pit][frame_val])
                    
                    f.close()

                plot = axes_dict[ax_label].pcolormesh(X[pit], Y[pit], image, cmap=plt.cm.gist_heat, norm=LogNorm(vmin=cbar_min, vmax=cbar_max), rasterized=True)
                plt.gca().set_aspect('equal')
                if frame_val > 0 or time_val > -1.0:
                    axes_dict[ax_label].streamplot(X[pit], Y[pit], magx, magy, density=4, linewidth=0.25, arrowstyle='-', minlength=0.5)
                else:
                    axes_dict[ax_label].streamplot(X[pit], Y[pit], magx, magy, density=4, linewidth=0.25, minlength=0.5)

                xlim = args_dict[pit]['xlim']
                ylim = args_dict[pit]['ylim']
                mym.my_own_quiver_function(axes_dict[ax_label], X_vel[pit], Y_vel[pit], velx, vely, plot_velocity_legend=bool(args_dict[pit]['annotate_velocity']), limits=[xlim, ylim], standard_vel=args.standard_vel)
                if args_dict[pit]['has_particles']:
                    if args.annotate_particles_mass == True:
                        mym.annotate_particles(axes_dict[ax_label], part_info['particle_position'], part_info['accretion_rad'], limits=[xlim, ylim], annotate_field=part_info['particle_mass'])
                    else:
                        mym.annotate_particles(axes_dict[ax_label], part_info['particle_position'], part_info['accretion_rad'], limits=[xlim, ylim], annotate_field=None)
                if args.plot_lref == True:
                    r_acc = np.round(part_info['accretion_rad'])
                    axes_dict[ax_label].annotate('$r_{acc}$='+str(r_acc)+'AU', xy=(0.98*simfo[pit]['xmax'], 0.93*simfo[pit]['ymax']), va="center", ha="right", color='w', fontsize=args_dict[pit].text_font)
                if args.annotate_time == "True" and pit == 0:
                    print("ANNONTATING TIME:", str(int(time_val))+'yr')
                    time_text = axes_dict[ax_label].text((xlim[0]+0.01*(xlim[1]-xlim[0])), (ylim[1]-0.03*(ylim[1]-ylim[0])), '$t$='+str(int(time_val))+'yr', va="center", ha="left", color='w', fontsize=args.text_font)
                    time_text.set_path_effects([path_effects.Stroke(linewidth=3, foreground='black'), path_effects.Normal()])
                    #ax.annotate('$t$='+str(int(time_val))+'yr', xy=(xlim[0]+0.01*(xlim[1]-xlim[0]), ylim[1]-0.03*(ylim[1]-ylim[0])), va="center", ha="left", color='w', fontsize=args.text_font)
                title_text = axes_dict[ax_label].text((np.mean(xlim)), (ylim[1]-0.03*(ylim[1]-ylim[0])), title, va="center", ha="center", color='w', fontsize=(args.text_font+2))
                title_text.set_path_effects([path_effects.Stroke(linewidth=3, foreground='black'), path_effects.Normal()])

                if positions[pit][0] == columns:
                    cbar = plt.colorbar(plot, pad=0.0, ax=axes_dict[ax_label])
                    cbar.set_label('Density (gcm$^{-3}$)', rotation=270, labelpad=14, size=args.text_font)
                axes_dict[ax_label].set_xlabel(args_dict[pit]['xabel'], labelpad=-1, fontsize=args.text_font)
                if positions[pit][0] == 1:
                    axes_dict[ax_label].set_ylabel(args_dict[pit]['yabel'], labelpad=-20, fontsize=args.text_font)
                axes_dict[ax_label].set_xlim(xlim)
                axes_dict[ax_label].set_ylim(ylim)
                for line in axes_dict[ax_label].xaxis.get_ticklines():
                    line.set_color('white')
                for line in axes_dict[ax_label].yaxis.get_ticklines():
                    line.set_color('white')

                plt.tick_params(axis='both', which='major', labelsize=16)
                for line in axes_dict[ax_label].xaxis.get_ticklines():
                    line.set_color('white')
                for line in axes_dict[ax_label].yaxis.get_ticklines():
                    line.set_color('white')

                if positions[pit][0] != 1:
                    yticklabels = axes_dict[ax_label].get_yticklabels()
                    plt.setp(yticklabels, visible=False)

                if positions[pit][0] == 1:
                    axes_dict[ax_label].tick_params(axis='y', which='major', labelsize=args.text_font)
                if positions[pit][1] == rows:
                    axes_dict[ax_label].tick_params(axis='x', which='major', labelsize=args.text_font)
                    if positions[pit][0] != 1:
                        xticklabels = axes_dict[ax_label].get_xticklabels()
                        plt.setp(xticklabels[0], visible=False)

                if len(usable_files[pit]) > 1:
                    if args.output_filename == None:
                        import pdb
                        pdb.set_trace()
                        file_name = save_dir + "movie_frame_" + ("%06d" % frames[frame_val])
                    else:
                        file_name = args.output_filename + "_" + str(int(time_val))
                else:
                    if args.output_filename != None:
                        file_name = args.output_filename
                    else:
                        import pdb
                        pdb.set_trace()
                        file_name = save_dir + "time_" + str(args.plot_time)

                plt.savefig(file_name + ".eps", format='eps', bbox_inches='tight')
                #plt.savefig(file_name + ".pdf", format='pdf', bbox_inches='tight')
                
                #plt.savefig(file_name + ".jpg", format='jpeg', bbox_inches='tight')
                call(['convert', '-antialias', '-quality', '100', '-density', '200', '-resize', '100%', '-flatten', file_name+'.eps', file_name+'.jpg'])
                os.remove(file_name + '.eps')

                del image
                del magx
                del magy
                del velx
                del vely
                
                if args.image_center != 0 and has_particles:
                    X[pit], Y[pit], X_vel[pit], Y_vel[pit] = original_positions
            print('Created frame', (frames[frame_val]), 'of', str(frames[-1]), 'on rank', rank, 'at time of', str(time_val), 'to save_dir:', file_name + '.eps')

        rit = rit +1
        if rit == size:
            rit = 0

    print("completed making movie frames on rank", rank)
Esempio n. 9
0
                myf.set_north_vector(north_unit)
                
                #These numbers are just creative pickle file management
                div_32 = int(rank/32)
                rem_32 = np.remainder(rank,32)
                div_4 = int(rem_32/4)
                proj_root_rank = div_32*32 + div_4*4

                #Fields to project
                field_list = [field, ('gas', 'Radial_Velocity'), ('gas', 'Proj_x_velocity'), ('gas', 'Proj_y_velocity')]
                proj_dict = {field[1]:[], 'Radial_Velocity':[], 'Proj_x_velocity':[], 'Proj_y_velocity':[]}
                
                proj_dict_keys = str(proj_dict.keys()).split("['")[1].split("']")[0].split("', '")
                for field in yt.parallel_objects(field_list):
                    print("Calculating projection with normal", proj_vector_unit, "for field", field, "on rank", rank)
                    proj = yt.OffAxisProjectionPlot(ds, proj_vector_unit, field, width=(2*args.ax_lim, 'AU'), weight_field=weight_field, method='integrate', center=(center_pos.value, 'AU'), depth=(args.slice_thickness, 'AU'), north_vector=north_unit)
                    
                    if args.resolution != 800:
                        proj.set_buff_size([args.resolution, args.resolution])
                
                    if args.field in str(field):
                        if weight_field == None:
                            if args.divide_by_proj_thickness == "True":
                                proj_array = np.array((proj.frb.data[field]/thickness.in_units('cm')).in_units(args.field_unit))
                            else:
                                proj_array = np.array(proj.frb.data[field].in_units(args.field_unit+"*cm"))
                        else:
                            if args.divide_by_proj_thickness == "True":
                                proj_array = np.array(proj.frb.data[field].in_units(args.field_unit))
                            else:
                                proj_array = np.array(proj.frb.data[field].in_units(args.field_unit)*thickness.in_units('cm'))
Esempio n. 10
0
def make_two_plots(file_name, ds, sph, center, L, Lx, z, sfr, zoom=70):
    """Save a density / metallicity / temperature panel figure with an SFH strip.

    A density projection plus metallicity and temperature slices, all taken
    along ``Lx`` with ``L`` as the north vector, are drawn side by side on a
    single matplotlib figure through an ``AxesGrid``. A star-formation-history
    panel (``sfr`` against redshift ``z``) is added beneath them. The figure is
    written to ``plot_dir + file_name[-6:] + '_plot.png'`` and then closed.

    Parameters
    ----------
    file_name : str
        Snapshot name; its last six characters label the output file.
    ds : yt dataset to plot.
    sph : yt data container used as ``data_source`` for the two slices.
    center : centre of every panel.
    L : north vector of every panel.
    Lx : projection / slice normal of every panel.
    z, sfr : sequences plotted against each other in the SFH panel.
    zoom : float, optional
        Panel width in kpc (default 70).
    """
    fig = plt.figure()

    # Grid holding the three image panels, each with its own colorbar.
    grid = AxesGrid(fig, (.5, .5, 1.5, 1.5),
                    nrows_ncols=(1, 3),
                    axes_pad=.8,
                    label_mode="L",
                    cbar_location="right",
                    cbar_mode="each",
                    cbar_set_cax=False,
                    cbar_size="3%",
                    cbar_pad="0%",
                    direction='column')

    sns.set_style("whitegrid", {'axes.grid': False})

    # Every panel follows the same order:
    #   1. configure the yt plot and add ALL of its annotations;
    #   2. point it at its grid axes;
    #   3. call _setup_plots().
    # yt applies annotation callbacks when a plot is (re)drawn. The figure is
    # saved with plt.savefig, so nothing redraws the panels after
    # _setup_plots(). Annotations added after that call would not be rendered.

    # --- density projection ---
    pro = yt.OffAxisProjectionPlot(ds,
                                   Lx, ('gas', 'density'),
                                   center=center,
                                   width=(zoom, 'kpc'),
                                   north_vector=L)
    pro.set_font({'size': 10})
    pro.hide_axes()
    pro.set_unit(('gas', 'density'), 'Msun/pc**2')
    pro.set_zlim(('gas', 'density'), 2, 500)
    cmap = sns.blend_palette(("black", "#984ea3", "#d73027", "darkorange",
                              "#ffe34d", "#4daf4a", "white"),
                             n_colors=60,
                             as_cmap=True)
    pro.set_cmap(("gas", "density"), cmap)
    pro.annotate_scale(size_bar_args={'color': 'white'})

    pro.plots[('gas', 'density')].figure = fig
    pro.plots[('gas', 'density')].axes = grid[0].axes
    pro.plots[('gas', 'density')].cax = grid.cbar_axes[0]
    pro._setup_plots()

    # --- metallicity slice ---
    oap = yt.SlicePlot(ds,
                       Lx, ("gas", "metallicity"),
                       center=center,
                       width=(zoom, 'kpc'),
                       north_vector=L,
                       data_source=sph)
    oap.annotate_cquiver('r_flux', 'cyl_flux', 16)
    oap.annotate_scale()
    oap.set_font({'size': 10})
    oap.hide_axes()
    oap.set_zlim(("gas", "metallicity"), 2e-4, 10.5)
    cmap = sns.blend_palette(
        ("black", "#984ea3", "#4575b4", "#4daf4a", "#ffe34d", "darkorange"),
        as_cmap=True)
    oap.set_cmap(("gas", "metallicity"), cmap)

    oap.plots[("gas", "metallicity")].figure = fig
    oap.plots[("gas", "metallicity")].axes = grid[1].axes
    oap.plots[("gas", "metallicity")].cax = grid.cbar_axes[1]
    oap._setup_plots()

    # --- temperature slice ---
    tp = yt.SlicePlot(ds,
                      Lx, ("gas", "temperature"),
                      center=center,
                      width=(zoom, 'kpc'),
                      north_vector=L,
                      data_source=sph)
    tp.annotate_scale()
    tp.set_font({'size': 10})
    tp.hide_axes()
    tp.set_zlim(("gas", "temperature"), 400, 8e7)
    cmap = sns.blend_palette(("black", "#d73027", "darkorange", "#ffe34d"),
                             n_colors=50,
                             as_cmap=True)
    tp.set_cmap(("gas", "temperature"), cmap)

    tp.plots[('gas', 'temperature')].figure = fig
    tp.plots[('gas', 'temperature')].axes = grid[2].axes
    tp.plots[('gas', 'temperature')].cax = grid.cbar_axes[2]
    tp._setup_plots()

    # SFH plot below the three image panels on the main figure.
    ax3 = fig.add_subplot(211, position=[.5, .8, 1.5, .17])

    ax3.plot(z, sfr)
    ax3.set_xlim(1.05, .45)  # reversed: redshift decreases to the right
    ax3.set_ylim(-2, 33)
    ax3.set_xlabel('Redshift')
    ax3.set_ylabel(r'SFR (M$_\odot$/yr)')

    plt.savefig(plot_dir + file_name[-6:] + '_plot.png', bbox_inches='tight')
    # Release the figure so repeated calls (e.g. one per snapshot) do not
    # accumulate open matplotlib figures.
    plt.close(fig)
Esempio n. 11
0
def ks_law_new(gas_cylinder,
               star_cylinder,
               ytsnap,
               star_center,
               image_width=800,
               field="stars",
               n_bins=100,
               r_max=None):
    """Compute pixel-by-pixel Kennicutt-Schmidt data for one galaxy.

    Projects gas density through an 80 kpc wide region around
    ``gas_cylinder.center``, looking along its "normal" field parameter. It
    reads the ``("deposit", "io_cic")`` map from the same projection buffer.
    It returns the log surface densities of the pixels whose centres lie
    within ``r_max`` of the image centre.

    Parameters
    ----------
    gas_cylinder : yt data container
        Supplies the line of sight (field parameter "normal") and the centre.
        For the default ``r_max`` it also supplies the dataset.
    star_cylinder : unused; kept for backward compatibility.
    ytsnap : dataset
        Used to build unit-ful constants and passed to ``gen_data_source``.
    star_center : unused; kept for backward compatibility.
    image_width : unused; kept for backward compatibility.
    field : unused; kept for backward compatibility.
    n_bins : unused; kept for backward compatibility. The pixel count is
        taken from the projected image.
    r_max : YTQuantity, optional
        Radial cut; defaults to 15 kpc.

    Returns
    -------
    gas_density, star_density : numpy.ndarray
        ``gas_density`` is log10 of the gas surface density [Msun/pc**2].
        ``star_density`` is log10 of the ``io_cic`` surface density
        [Msun/kpc**2] divided by a 3 Gyr window expressed in yr.
        Zero-valued pixels give -inf.
    """
    # Hard-coded analysis settings (really should be some sort of config).
    width = ytsnap.arr(80, "kpc")
    # Increasing the age window increases the number of resolved points
    # (reduces the "artificial lines").
    age = ytsnap.arr(3, "Gyr")
    max_age = age.in_units("yr").value

    if r_max is None:
        # Reducing this reduces the width of the artificial lines.
        r_max = gas_cylinder.ds.arr(15, "kpc")

    print("Computing KS law")
    print("ks 1) Gathering initial projections")

    width = ((width.v, width.units), (width.v, width.units))

    normal = gas_cylinder.get_field_parameter("normal")
    data_source = gen_data_source(0, gas_cylinder, ytsnap, width, width[0],
                                  "kpc")

    # NOTE(review): the projection depth equals the full 80 kpc image width.
    # Confirm a thinner slab (e.g. 20 kpc) was not intended.
    # north_vector is left to yt. The radial cut below does not depend on the
    # in-plane orientation, and a north vector equal to the line of sight does
    # not define an orientation.
    plot = yt.OffAxisProjectionPlot(data_source.ds,
                                    normal,
                                    [("gas", "density")],
                                    center=gas_cylinder.center,
                                    width=width,
                                    depth=width[0],
                                    weight_field=None)
    images = plot.frb
    gas_image = images[('gas', 'density')]
    # Only gas density was requested above. The stellar map is presumably
    # projected by the buffer on first access; confirm with your yt version.
    star_image = images[('deposit', 'io_cic')]

    print("ks 2) collecting data")

    # Radial distance (kpc, the units of `width`) of each pixel centre from
    # the image centre. Centres are built from an integer pixel index, so there
    # is exactly one per pixel. A float np.arange(left, right, step) can return
    # an extra element through round-off, which breaks the mask below.
    box = width[0][0]
    n_rows, n_cols = gas_image.shape[:2]
    row_centres = (np.arange(n_rows) + 0.5) * (box / float(n_rows)) - box / 2.0
    col_centres = (np.arange(n_cols) + 0.5) * (box / float(n_cols)) - box / 2.0
    radius = np.sqrt(np.add.outer(row_centres ** 2, col_centres ** 2))

    # Pixels whose centre lies inside r_max (row-major, like the images).
    radius_filter = radius.flatten() < r_max.in_units("kpc").v

    gas_density = np.array(gas_image.in_units("Msun/pc**2").value).flatten()
    star_density = np.array(star_image.in_units("Msun/kpc**2").value).flatten()

    # Filter by radius; empty pixels become -inf under log10.
    gas_density = np.log10(gas_density[radius_filter])
    star_density = np.log10(star_density[radius_filter] / max_age)

    return gas_density, star_density
Esempio n. 12
0
def plot_face_edge(isnap=28, selected_field='h2density', sizekpc=7., cutLow=1.e-5, f_out=None, save_plot=True,
                   overplot_clumps=True, incut=6., field_cut='h2density', n_cell_min=8, largeNum=1.e+42,
                   data=None, ds=None, dd=None, leaf_clumps=None, plotClumpID=False
                   ):
    """
    Plot two density-weighted off-axis projections of a snapshot side by side.

    The two panels view the snapshot along two principal axes of its inertia
    tensor and share one colorbar. Clump contours and IDs can optionally be
    overplotted.

    Parameters
    ----------
      isnap: int
             number of snapshot selected for the plot

      selected_field: str
             field for image plot ('density' or 'h2density')

      sizekpc: float
              size of the extracted region from isnap

      cutLow: float
              dynamical range for the colorbar (zmin = cutLow * zmax)

      f_out: str
             output filename

      save_plot: bool
             print to file

      overplot_clumps: bool
             overplot clumps

      incut: float
             threshold for clump definition

      field_cut: str
             field out of which clumps are identified

      n_cell_min: int
             min cell to define a clump

      largeNum: float
             large upper contour limit, so that only `incut` matters
             (needed to make yt cooperate)

      data, ds, dd, leaf_clumps:
             precomputed inputs; anything left as None is computed here
             (ds and dd are rebuilt together unless both are given)

      plotClumpID: bool
             mark and number each leaf clump (requires overplot_clumps)

    Returns
    -------
      prj, axes: the yt plot of the last (second) panel and its matplotlib axes

    Raises
    ------
      TypeError: if selected_field is not 'density' or 'h2density'
    """
    # Units per supported field. This is checked up front so that an
    # unsupported field fails before any data is loaded.
    units_by_field = {'density': 'Msun/pc**3', 'h2density': '1/cm**3'}
    if selected_field not in units_by_field:
        raise TypeError('unit not implemented for the field')
    selected_unit = units_by_field[selected_field]

    if f_out is None:
        f_out = 'test_' + selected_field + '_out' + \
            str(isnap) + '_yt_unit_plot.png'

    # axes_grid1 provides AxesGrid. The old mpl_toolkits.axes_grid module
    # was deprecated and is no longer shipped by recent Matplotlib.
    from mpl_toolkits.axes_grid1 import AxesGrid
    import matplotlib.pyplot as plt

    # Prepare input data, if not already passed as arguments.
    if data is None:
        data = import_fetch_gal(isnap=isnap)
    # Both ds and dd are required below, so rebuild them unless both were given.
    if (ds is None) or (dd is None):
        ds, dd = prepare_unigrid(data=data,
                                 add_unit=True,
                                 regionsize_kpc=sizekpc,
                                 debug=False)
    if overplot_clumps and (leaf_clumps is None):
        __, leaf_clumps = ytclumpfind_H2(ds, dd, field_cut, incut,
                                         c_max=None, step=1e+6,
                                         N_cell_min=n_cell_min, save=False,
                                         plot=False, saveplot=None, fold_out='./')

    if overplot_clumps:
        # Clump indices ordered by increasing summed "density".
        id_sorted = sorted(range(len(leaf_clumps)),
                           key=lambda x: np.sum(leaf_clumps[x]["density"]))

    # Principal axes of the inertia tensor (the eigenvalues are not needed).
    _, e_vectors = calculate_eigenvektoren(data=data, sizekpc=sizekpc)
    # references
    # los_vec = e_vectors[0,:] # face on
    # los_vec = e_vectors[1,:] # edge on
    # los_vec = e_vectors[2,:] # face on (again, perpendicular direction)
    #
    # Set the camera axes along the principal axes of the inertia tensor.
    vec_list = [e_vectors[2, :], e_vectors[1, :]]
    up_list = [e_vectors[0, :], e_vectors[0, :]]

    fig = plt.figure()
    grid = AxesGrid(fig, (0.075, 0.075, 0.85, 0.85),
                    nrows_ncols=(1, 2),
                    axes_pad=0.7,
                    label_mode="1",
                    share_all=True,
                    cbar_location="right",
                    cbar_mode="single",
                    cbar_size="3%",
                    cbar_pad=0.03)

    # Colour range shared by both panels.
    zmax = dd[selected_field].max().to(selected_unit)
    zmin = cutLow * zmax

    # Plot both views (see vec_list / up_list above).
    for iplot, (los_vec, up_vec) in enumerate(zip(vec_list, up_list)):

        prj = yt.OffAxisProjectionPlot(ds=ds, center=[0, 0, 0], normal=los_vec,
                                       fields=selected_field, width=(4, 'kpc'),
                                       north_vector=up_vec, weight_field='density')

        prj.set_cmap(field=selected_field, cmap='inferno')
        prj.set_unit(selected_field, selected_unit)
        prj.set_zlim(selected_field, zmin, zmax)

        # Annotate clumps if asked for.
        if overplot_clumps:
            # Contour instead of annotate_clump(), to deal w/ a yt bug.
            prj.annotate_contour(field=field_cut, ncont=1, factor=1,
                                 clim=(incut, largeNum),
                                 plot_args={'colors': 'white'})

            if plotClumpID:
                for ileaf in id_sorted:
                    # Marker position: mean of the clump's cell coordinates.
                    _fc = np.mean(leaf_clumps[ileaf].data.fcoords[:], axis=0)
                    prj.annotate_marker(_fc,
                                        coord_system='data',
                                        plot_args={'color': 'red', 's': 50})
                    prj.annotate_text(_fc,
                                      ileaf + 1,
                                      coord_system='data',
                                      text_args={'color': 'red', 'size': 25},
                                      inset_box_args={'boxstyle': 'square',
                                                      'facecolor': 'white',
                                                      'linewidth': 1.0,
                                                      'edgecolor': 'white',
                                                      'alpha': 0.})

        prj.set_ylabel('kpc')
        prj.set_xlabel('kpc')
        prj.set_font({'family': 'Times', 'size': 26})

        # Move the yt plot onto this panel of the shared grid.
        plot = prj.plots[selected_field]
        plot.figure = fig
        plot.axes = grid[iplot].axes
        plot.cax = grid.cbar_axes[iplot]

        if iplot == 1:
            # The second panel shares the first panel's axes; hide its own.
            plot.axes.set_axis_off()
            plot.axes.set_xlabel('')
            plot.axes.set_xticks(())
            plot.axes.set_xticklabels('')

        # Finally, this actually redraws the plot.
        prj._setup_plots()

    for cax in grid.cbar_axes:
        cax.toggle_label(True)

    if save_plot:
        prj.save(f_out, mpl_kwargs={'bbox_inches': 'tight'})
        print('dump to %s' % (f_out,))

    return prj, plot.axes
idx_start = args.idx_start
idx_end = args.idx_end
didx = args.didx
prefix_in = args.prefix_in
prefix_out = args.prefix_out

yt.enable_parallelism()
ts = yt.load([
    prefix_in + '/Data_%06d' % idx
    for idx in range(idx_start, idx_end + 1, didx)
])

for ds in ts.piter():

    # Project the data. The yt plot is bound to `proj` rather than `plt`, so
    # the conventional matplotlib.pyplot alias is not clobbered for code that
    # runs afterwards.
    proj = yt.OffAxisProjectionPlot(ds,
                                    normal=normal,
                                    fields=field,
                                    north_vector=north)

    # Save the image.
    proj.set_cmap(field, colormap)
    proj.annotate_timestamp(time_unit='code_time', corner='upper_right')
    proj.save(prefix_out + '_' + ds.basename + '.png', mpl_kwargs={'dpi': dpi})

    # Access the image buffer through the public `frb` property, which
    # (re)builds the buffer when needed, instead of the private `_frb`.
    data = proj.frb.data[field]
    print("units = %s" % data.units)
    print("max   = %s" % data.d.max())
Esempio n. 14
0
# %% density
# ----------
# Grid of panels: 2 rows x one column per model, sharing a single colorbar.
# The first loop fills row 1; the second loop builds row 2.
field = "density"
fig = plt.figure()
grid = AxesGrid(fig, (0.075,0.075,0.85,0.85),
                nrows_ncols = (2, len(models)),
                axes_pad = 0.05,
                label_mode = "L",
                share_all = True,
                cbar_location="right",
                cbar_mode="single",
                cbar_size="5%",
                cbar_pad="1%")

# Row 1: project along [0,0,1] with [0,1,0] as the north vector.
# NOTE(review): width=45 has no unit attached, while the axis labels say kpc.
# Confirm the width is really interpreted in kpc.
for i in range(len(models)):
  px = yt.OffAxisProjectionPlot(ds[i], [0,0,1], ('gas', '{}'.format(field)), center = center[i], width=45, north_vector=[0,1,0])
  px.set_xlabel("x [kpc]")
  px.set_ylabel("y [kpc]")
  px.set_font({'size':20})
  px.set_zlim(field="{}".format(field), zmin=5e-6, zmax=1e-1)
  # models[i][1] is drawn as the panel label (presumably the model name).
  px.annotate_text([-20, 18], models[i][1], coord_system="plot")
  # Re-parent the yt plot onto panel i of the shared grid; _setup_plots()
  # then redraws it there.
  plot = px.plots['{}'.format(field)]
  plot.figure = fig
  plot.axes = grid[i].axes
  plot.cax = grid.cbar_axes[i]
  px.set_background_color(field)
  px._setup_plots()
# Row 2: project along [0,1,0] with [0,0,-1] as the north vector.
# NOTE(review): in the lines visible here, this loop never attaches its plots
# to the grid. The body appears truncated; confirm against the full source.
for i in range(len(models)):
  px = yt.OffAxisProjectionPlot(ds[i], [0,1,0], ('gas', '{}'.format(field)), center = center[i], width=45, north_vector=[0,0,-1])
  px.set_xlabel("x [kpc]")
  px.set_ylabel("z [kpc]")
Esempio n. 15
0
def ProjectionPlot(Param_Dict, worker):
    """Takes a DataSet object loaded with yt and performs a projectionPlot on
    it, then annotates and draws the result.

    Depending on Param_Dict, this makes one of three plot types:
    - a particle projection ("ParticlePlot");
    - an axis-aligned projection along "NAxis";
    - an off-axis projection ("NormVecMode" other than "Axis-Aligned"). Its
      normal and north vectors come from the "<X|Y|Z>NormDir" and
      "<X|Y|Z>NormNorth" entries.

    If "DomainDiv" is set, a derived field "Normed <ZAxis>" is registered on
    the dataset and projected instead. That field is the original field
    divided by the domain height.

    Parameters:
        Param_Dict: Dict with Parameters (only read here).
        worker: forwarded to emitStatus and finallyDrawPlot.
    """
    ds = Param_Dict["CurrentDataSet"]
    field = Param_Dict["ZAxis"]
    if Param_Dict["DomainDiv"]:
        # In case the user wants to divide everything by the domain_height,
        # we define a new field which is just the old field divided by the
        # height, and then do a projectionPlot for that.
        sourceField = field
        height = Param_Dict["FieldMaxs"]["DomainHeight"] - Param_Dict["FieldMins"]["DomainHeight"]
        field = "Normed " + sourceField
        unit = yt.units.unit_object.Unit(Param_Dict["ZUnit"] + "/cm")
        realHeight = height.to_value("au")  # Important! Bugs occur if we just used "height"

        def _NormField(field, data):
            # Use the field name captured above rather than re-reading
            # Param_Dict["ZAxis"] whenever yt evaluates this field. yt may
            # evaluate it again later (e.g. when an off-axis plot is
            # re-rendered), by which time the dict may hold another field.
            return data[sourceField]/realHeight/yt.units.au

        if Param_Dict["ParticlePlot"]:
            ds.add_field(("io", field), function=_NormField,
                         units="auto", dimensions=unit.dimensions,
                         force_override=True, particle_type=True)
        else:
            ds.add_field(("gas", field), function=_NormField,
                         units="auto", dimensions=unit.dimensions,
                         force_override=True)

    gridUnit = Param_Dict["GridUnit"]
    c0 = yt.YTQuantity(Param_Dict["XCenter"], gridUnit)
    c1 = yt.YTQuantity(Param_Dict["YCenter"], gridUnit)
    if Param_Dict["Geometry"] == "cartesian":
        c2 = yt.YTQuantity(Param_Dict["ZCenter"], gridUnit)
    else:
        c2 = Param_Dict["ZCenter"]
    width = (Param_Dict["HorWidth"], gridUnit)
    height = (Param_Dict["VerWidth"], gridUnit)

    # Keyword arguments shared by all three plot types.
    plotKwargs = dict(axes_unit=gridUnit,
                      weight_field=Param_Dict["WeightField"],
                      fontsize=14, center=[c0, c1, c2],
                      width=(width, height))

    if Param_Dict["ParticlePlot"]:
        plot = yt.ParticleProjectionPlot(ds, Param_Dict["NAxis"], field,
                                         **plotKwargs)
    elif Param_Dict["NormVecMode"] == "Axis-Aligned":
        plot = yt.ProjectionPlot(ds, Param_Dict["NAxis"], field, **plotKwargs)
    else:
        normVec = [Param_Dict[axis + "NormDir"] for axis in ["X", "Y", "Z"]]
        northVec = [Param_Dict[axis + "NormNorth"] for axis in ["X", "Y", "Z"]]
        plot = yt.OffAxisProjectionPlot(ds, normVec, field,
                                        north_vector=northVec,
                                        **plotKwargs)

    emitStatus(worker, "Setting projection plot modifications")
    # Set min, max, unit log and color scheme:
    setAxisSettings(plot, Param_Dict, "Z")
    plot.zoom(Param_Dict["Zoom"])
    emitStatus(worker, "Annotating the projection plot")
    annotatePlot(Param_Dict, plot)
    finallyDrawPlot(plot, Param_Dict, worker)
Esempio n. 16
0
def ks_law(ytdataset,
           raw_snapshot,
           width,
           max_age,
           image_width=800,
           field="stars",
           n_bins=100,
           r_max=None):
    """
    Compute the Kennicutt-Schmidt law for each individual pixel within a galaxy.

    Projects gas density and ``("deposit", "<field>_density")`` through
    ``ytdataset``, looking along its "normal" field parameter. It returns the
    log surface densities of the pixels whose centres lie within ``r_max`` of
    the image centre.

    Parameters
    ----------
    ytdataset : yt data container
        Must carry a "normal" field parameter. Its ``.center`` is the
        projection centre.
    raw_snapshot : passed to ``gen_data_source``.
    width : YTQuantity
        Image width; also used as the projection depth.
    max_age : float
        Divisor applied to the stellar surface density. It is presumably an
        age window in yr that turns mass into a rate; confirm with callers.
    image_width : unused; kept for backward compatibility.
    field : str
        Deposit field prefix, e.g. "stars" -> ("deposit", "stars_density").
    n_bins : unused; kept for backward compatibility. The pixel count is
        taken from the projected image.
    r_max : YTQuantity, optional
        Radial cut; defaults to 25 kpc.

    Returns
    -------
    gas_density, star_density : numpy.ndarray
        ``gas_density`` is log10 of the gas surface density [Msun/pc**2].
        ``star_density`` is log10 of (stellar surface density [Msun/kpc**2]
        / max_age). Zero-valued pixels give -inf.
    """
    weight_field = None

    if r_max is None:
        r_max = ytdataset.ds.arr(25, "kpc")

    print("Computing KS law")
    print("ks 1) Gathering initial projections")

    # Work in kpc throughout, so the pixel radii below are directly
    # comparable with r_max (which is converted to kpc).
    width = width.in_units("kpc")
    width = ((width.v, width.units), (width.v, width.units))

    normal = ytdataset.get_field_parameter("normal")
    data_source = gen_data_source(0, ytdataset, raw_snapshot, width, width[0],
                                  "kpc")

    star_field = ("deposit", "%s_density" % field)
    # north_vector is left to yt. The radial cut below does not depend on the
    # in-plane orientation, and a north vector equal to the line of sight does
    # not define an orientation.
    plot = yt.OffAxisProjectionPlot(data_source.ds,
                                    normal,
                                    [("gas", "density"), star_field],
                                    center=ytdataset.center,
                                    width=width,
                                    depth=width[0],
                                    weight_field=weight_field)
    plot.set_axes_unit("kpc")
    plot.set_unit(("gas", "density"), "Msun/pc**2")
    plot.set_unit(star_field, "Msun/kpc**2")
    images = plot.frb
    gas_image = images[('gas', 'density')]
    star_image = images[star_field]

    print("ks 2) collecting data")

    # Radial distance (kpc) of each pixel centre from the image centre.
    # Centres are built from an integer pixel index, so there is exactly one
    # per pixel. A float np.arange(left, right, step) can return an extra
    # element through round-off, which breaks the mask below.
    box = width[0][0]
    n_rows, n_cols = gas_image.shape[:2]
    row_centres = (np.arange(n_rows) + 0.5) * (box / float(n_rows)) - box / 2.0
    col_centres = (np.arange(n_cols) + 0.5) * (box / float(n_cols)) - box / 2.0
    radius = np.sqrt(np.add.outer(row_centres ** 2, col_centres ** 2))

    # Pixels whose centre lies inside r_max (row-major, like the images).
    radius_filter = radius.flatten() < r_max.in_units("kpc").v

    gas_density = np.array(gas_image.in_units("Msun/pc**2").value).flatten()
    star_density = np.array(star_image.in_units("Msun/kpc**2").value).flatten()

    # Filter by radius; empty pixels become -inf under log10.
    gas_density = np.log10(gas_density[radius_filter])
    star_density = np.log10(star_density[radius_filter] / max_age)

    return gas_density, star_density
Esempio n. 17
0
                                                   nu=nu,
                                                   proj_axis=proj_axis,
                                                   extend_cells=gc)
            ds_sync = ds

        yt.mylog.info('Making projection plots from file: %s', sync_fname)
        if proj_axis == 'x':
            plot = yt.ProjectionPlot(ds_sync,
                                     proj_axis,
                                     stokes.IQU,
                                     center=[0, 0, 0],
                                     width=width)
        else:
            plot = yt.OffAxisProjectionPlot(ds_sync,
                                            proj_axis,
                                            stokes.IQU,
                                            center=[0, 0, 0],
                                            width=width,
                                            north_vector=[0, 0, 1])
        plot.set_buff_size(res)
        plot.set_axes_unit('kpc')

        # Setting up colormaps
        # Use "hot" for intensity plot and seismic for Q and U plots
        for field in stokes.IQU:
            if 'nn_emissivity_i' in field[1]:
                plot.set_zlim(field, 1E-3 / norm, 1E1 / norm)
                #cmap = plt.cm.get_cmap("algae")
                #cmap.set_bad((80./256., 0.0, 80./256.))
                cmap = plt.cm.hot
                cmap.set_bad('k')
                plot.set_cmap(field, cmap)
Esempio n. 18
0
    gas_ang_mom_norm = gas_ang_mom / gas_ang_mom_tot
    sim_dict.update([('gas L vector', gas_ang_mom_norm)])

    edge_on_gas = np.random.randn(3)
    edge_on_gas -= edge_on_gas.dot(gas_ang_mom_norm) * gas_ang_mom_norm / np.linalg.norm(gas_ang_mom_norm)**2
    sim_dict.update([('gas edge-on', edge_on_gas)])

# Set cylinder
# NOTE(review): cyl_rad is not used in the lines below; the disk radius is gal_r90.
cyl_rad = np.ceil(gal_r90.value)
cyl_h = args.cyl_h
cyl = ds.disk(center=cen_cen,
              normal=star_ang_mom_norm,
              radius=gal_r90,
              height=(cyl_h, 'kpc'))

# Making the projections: currently each one renders in a bit over half an
# hour, so the whole section is switched off by the `if False` guard.
if False:
    # (line-of-sight normal, projected field, output path template)
    projection_specs = (
        (star_ang_mom_norm, ('deposit', 'stars_density'), '{}/StarsFaceOn{}-z={}.png'),
        (edge_on_dir, ('deposit', 'stars_density'), '{}/StarsEdgeOn{}-z={}.png'),
        (gas_ang_mom_norm, 'density', '{}/GasFaceOn{}-z={}.png'),
        (edge_on_gas, 'density', '{}/GasEdgeOn{}-z={}.png'),
    )
    for view_normal, view_field, out_template in projection_specs:
        prj = yt.OffAxisProjectionPlot(ds,
                                       normal=view_normal,
                                       fields=view_field,
                                       center=cen_cen,
                                       width=(proj_width, 'kpc'),
                                       data_source=cyl)
        prj.save(out_template.format(figdir, sim, redshift))

# Making profile plots for velocity dispersion in the cylinder
def worker_fn(file):
    """Project synchrotron Stokes maps of one snapshot and plot polarization.

    For every frequency in the module-level ``nus`` this projects the Stokes
    I/Q/U fields (``stokes.IQU``) of the synchrotron dataset belonging to
    ``file`` and smooths each map with ``gaussian_filter(..., sigma)``. It
    saves an image of ``log10(I + 0.01)`` overlaid with polarization vectors,
    then saves one figure of polarization-fraction / polarization-angle
    histograms covering all frequencies.

    Relies on module-level settings: ``ptype``, ``proj_axis``, ``nus``,
    ``gc``, ``sigma``, ``zoom_fac`` and ``north_vector``.

    Args:
        file: object exposing ``fullpath`` (snapshot to load) and
            ``pathname`` (directory under which output folders are made).

    Returns:
        ``(ds.directory, message)`` when the synchrotron file is missing,
        otherwise ``(file.pathname, ds.basename[-4:])``.
    """
    ds = yt.load(file.fullpath)

    sync_file = synchrotron_filename(ds, extend_cells=gc)
    if not os.path.isfile(sync_file):
        return ds.directory, 'sync file not found %s' % sync_file

    maindir = os.path.join(file.pathname, 'cos_synchrotron_QU_nn_%s/' % ptype)
    if proj_axis != 'x':
        # Off-axis projection: proj_axis is a 3-component direction vector.
        maindir = os.path.join(maindir, '%i_%i_%i' % tuple(proj_axis))
        histdir = os.path.join(
            maindir, 'histogram_gaussian%i_%i%i%i' % (sigma, *proj_axis))
    else:
        histdir = os.path.join(maindir,
                               'histogram_gaussian%i_%s' % (sigma, proj_axis))
    # savefig does not create missing directories.
    os.makedirs(histdir, exist_ok=True)

    ds_sync = yt.load(sync_file)
    width = ds_sync.domain_width[1:] / zoom_fac
    # Image resolution: finest-level cell count across the (zoomed) view, halved.
    res = ds_sync.domain_dimensions[
        1:] * ds_sync.refine_by**ds_sync.index.max_level // zoom_fac // 2
    psi, frac = {}, {}
    I_bin = {}
    for nu in nus:
        stokes = StokesFieldName(ptype, nu, proj_axis, field_type='flash')
        if proj_axis == 'x':
            proj = yt.ProjectionPlot(ds_sync,
                                     proj_axis,
                                     stokes.IQU,
                                     center=[0, 0, 0],
                                     width=width).set_buff_size(res)
        else:
            proj = yt.OffAxisProjectionPlot(
                ds_sync,
                proj_axis,
                stokes.IQU,
                width=width,
                north_vector=north_vector).set_buff_size(res)

        frb_I = gaussian_filter(proj.frb.data[stokes.I].v, sigma)
        frb_Q = gaussian_filter(proj.frb.data[stokes.Q].v, sigma)
        frb_U = gaussian_filter(proj.frb.data[stokes.U].v, sigma)

        # Optional block-summing of pixels (factor=1 keeps full resolution).
        factor = 1
        nx = res[0] // factor
        ny = res[1] // factor

        I_bin[nu] = frb_I.reshape(ny, factor, nx, factor).sum(3).sum(1)
        Q_bin = frb_Q.reshape(ny, factor, nx, factor).sum(3).sum(1)
        U_bin = frb_U.reshape(ny, factor, nx, factor).sum(3).sum(1)

        # angle between the polarization and horizontal axis
        # (or angle between the magnetic field and vertical axis)
        psi[nu] = 0.5 * np.arctan2(U_bin, Q_bin)
        # Pixels with I == 0 give inf/nan here; they are masked out below.
        with np.errstate(divide='ignore', invalid='ignore'):
            frac[nu] = np.sqrt(Q_bin**2 + U_bin**2) / I_bin[nu]

        fig = plt.figure(figsize=(8, 16))
        i_plot = fig.add_subplot(111)
        i_plot.imshow(np.log10(frb_I + 1e-2), vmin=-1, vmax=1, origin='lower')

        xx0, xx1 = i_plot.get_xlim()
        yy0, yy1 = i_plot.get_ylim()
        X, Y = np.meshgrid(np.linspace(xx0, xx1, nx, endpoint=True),
                           np.linspace(yy0, yy1, ny, endpoint=True))

        # Suppress vectors where the total intensity is too faint.
        mask = I_bin[nu] < 0.1
        frac[nu][mask] = 0
        psi[nu][mask] = 0

        pixX = frac[nu] * np.cos(psi[nu])  # X-vector
        pixY = frac[nu] * np.sin(psi[nu])  # Y-vector

        # Headless arrows: polarization vectors are undirected.
        quiveropts = dict(headlength=0, headwidth=1, pivot='middle')
        i_plot.quiver(X, Y, pixX, pixY, scale=8, **quiveropts)
        nu_str = '%.1f%s' % nu
        fig.savefig(histdir + '/' + ds.basename + '_I_%s.png' % nu_str)
        # Release the figure; otherwise pyplot keeps every one alive.
        plt.close(fig)

    def plot_polarization_histogram(frac, psi, I_bin, fig=None, label=None):
        """Overlay polarization histograms for one frequency on ``fig``.

        Uses the first three axes of ``fig`` (a 1x3 grid is created when
        ``fig`` is None or has no axes): polarization fraction in percent,
        polarization angle psi, and |psi|. Only pixels with nonzero ``I_bin``
        contribute; the first two histograms are intensity-weighted and
        normalised to unit area.
        """
        if fig is None:
            fig = plt.figure(figsize=(16, 4))
        if not fig.axes:
            for position in (131, 132, 133):
                fig.add_subplot(position)

        signal = I_bin.nonzero()
        weights = I_bin[signal].flatten()

        ax1 = fig.axes[0]
        ax1.hist(frac[signal].flatten() * 100,
                 range=(0, 80),
                 bins=40,
                 alpha=0.5,
                 weights=weights,
                 density=True)
        ax1.set_xlabel('Polarization fraction (%)')
        ax1.set_xlim(0, 80)

        ax2 = fig.axes[1]
        ax2.hist(psi[signal].flatten(),
                 bins=50,
                 range=(-0.5 * np.pi, 0.5 * np.pi),
                 alpha=0.5,
                 weights=weights,
                 density=True)
        x_tick = np.linspace(-0.5, 0.5, 5, endpoint=True)
        x_label = [r"$-\pi/2$", r"$-\pi/4$", r"$0$", r"$+\pi/4$", r"$+\pi/2$"]
        ax2.set_xlim(-0.5 * np.pi, 0.5 * np.pi)
        ax2.set_xticks(x_tick * np.pi)
        ax2.set_xticklabels(x_label)

        ax3 = fig.axes[2]
        ax3.hist(np.abs(psi[signal].flatten()),
                 bins=25,
                 range=(0.0, 0.5 * np.pi),
                 alpha=0.5,
                 label=label)
        ax3.legend()
        ax3.set_xlim(0.0, 0.5 * np.pi)
        ax3.set_xticks(x_tick[2:] * np.pi)
        ax3.set_xticklabels(x_label[2:])

        return fig

    fig = plt.figure(figsize=(16, 4))
    ax1 = fig.add_subplot(131)
    ax2 = fig.add_subplot(132)
    fig.add_subplot(133)  # filled by plot_polarization_histogram

    for nu in nus:
        nu_str = '%.1f%s' % nu
        fig = plot_polarization_histogram(frac[nu],
                                          psi[nu],
                                          I_bin[nu],
                                          fig=fig,
                                          label=nu_str)

    ax1.set_title(file.pathname)
    # `nu` here is the last frequency of the loop above.
    ax2.set_title(ds.basename + '  %.1f %s' % nu + ' gaussian %i' % sigma)

    fig.savefig(histdir + '/' + ds.basename)
    plt.close(fig)

    return file.pathname, ds.basename[-4:]
Esempio n. 20
0
import yt

# Load the dataset.
ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")

# Sphere of radius 15 kpc at the centre of the simulation volume; its net
# angular-momentum vector becomes the line of sight of the projection.
sp = ds.sphere(center="center", radius=(15.0, "kpc"))
L = sp.quantities.angular_momentum_vector()

print(f"Angular momentum vector: {L}")

# Off-axis density projection along L, centred on the sphere, 25 kpc across.
p = yt.OffAxisProjectionPlot(
    ds,
    normal=L,
    fields="density",
    center=sp.center,
    width=(25, "kpc"),
)
p.save()
Esempio n. 21
0
                    Zmetal = data['PartType0', 'Metallicity_00']
                    ZHe = data['PartType0', 'Metallicity_01']
                    mass = data['PartType0', 'Masses'].in_cgs()
                    density = data['PartType0', 'Density'].in_cgs()
                    protonmass_in_g_yt = ds.quan(protonmass_in_g,'g')
                    edensity = density/protonmass_in_g_yt*(1.0-Zmetal-ZHe)*ne
                    print 'edensity', edensity
                    return edensity
                ds.add_field(('PartType0', 'electrondensity'), function=_electrondensity,
                    units="1/cm**3", particle_type=True)

                add_volume_weighted_smoothed_field('PartType0', 'Coordinates', 'Masses',
                                   'SmoothingLength', 'Density',
                                   'electrondensity', ds.field_info)
                if mode=='projected':
                        px = yt.OffAxisProjectionPlot(ds, projectionaxis, ('deposit', 'PartType0_smoothed_electrondensity'), center=center, width=new_box_size)
                if mode=='Slice':
                        if rotface==1:
                                px = yt.OffAxisSlicePlot(ds, projectionaxis, ('deposit', 'PartType0_smoothed_electrondensity'), center=center, width=new_box_size,north_vector=north_vector)
                        else:
                                px = yt.SlicePlot(ds, projectionaxis, ('deposit', 'PartType0_smoothed_electrondensity'), center=center, width=new_box_size)
                #px.set_buff_size(128)
                px.set_buff_size(64)
                # Extract X, Y, U, V from the frb
                px_frb = px.frb
                px_dens = np.array(px_frb[('deposit', 'PartType0_smoothed_electrondensity')])
                cbcolor = 'YlGnBu'
                if mode=='projected': 
                    cblabel = r'${\rm \log (n_{\rm e} [cm^{-2}])}$'                
                if mode=='Slice': 
                    cblabel = r'${\rm \log (n_{\rm e} [cm^{-3}])}$'
Esempio n. 22
0
def plot_a_region(
    ram,
    out,
    ds,
    center,
    fields='den',
    kind='slc',
    axis='z',
    width=None,
    center_vel=None,
    L=None,
    direcs='face',
    zlims=None,
    l_max=None,
    is_id=False,
    bar_length=2e3,
    sketch=False,
    time_offset='tRelax',
    is_set_size=True,
    is_time=True,
    streamplot_kwargs=None,
    more_kwargs=None,
):
    """Do projection or slice plots in a region.

    Args:
        ram: a ramtools.Ramses instance
        out (int): output frame
        ds (yt.ds): yt.load instance
        center (list_like): the center of the region like in yt.SlicePlot
        fields (str or tuple): the field to plot. The available options are:
            'den' or 'density' - density.
            'logden' - number density rho / (1.4 m_H), log-scaled.
            'T' or 'temperature' - temperature.
            'T_vel_rela' - temperature overplotted with the relative velocity
                field (see 'vel_rela').
            'pressure' - thermal pressure.
            'magstream' - stream lines of magnetic fields on top of a density
                slice; the magnetic strength is indicated by the color of the
                stream lines.
            'beta' - plasma beta parameter.
            'magbeta' - log10 of the plasma beta.
            'AlfMach' - log10 of the Alfvenic Mach number.
            'xHII' - hydrogen ionization fraction.
            'mach' - log10 of the thermal Mach number.
            'mag' - NOTE(review): currently plots density (same branch as
                'den'); intended to be magnetic field strength.
            'vel' - velocity field on top of density slice.
            'vel_rela' - relative velocity field, i.e. the velocity field in
                the frame moving with center_vel, on top of density.
            'mach2d' - log10 of the thermal Mach number in the plane.
            'p_mag' - magnetic pressure.
        kind (str): 'slc' (slice), 'prj' (density-weighted projection) or
            'colden' (unweighted projection). Default: 'slc'
        axis (str or int): One of (0, 1, 2, 'x', 'y', 'z').
        width (float or tuple): width of the field of view. e.g. 0.01,
            (1000, 'AU'). Required.
        center_vel (list_like or None): the velocity of the center in cm/s.
            (Default None)
        L (list_like): the line-of-sight vector. Will overwrite axis.
            Default: None
        direcs (str): 'face' or 'edge'. Will only be used if L is not None.
        zlims (dict): The limits of the fields. e.g. {'density': [1e4, 1e8]},
            {'B': [1e-5, 1e-2]}. Default: None (no limits).
        l_max (int): maximum refinement level used by off-axis projections.
            Default: ds.max_level.
        is_id (bool): toggle marking sink ID (the order of creation)
        bar_length (float): not used (recomputed for velocity plots)
        sketch (bool): If True, will do a faster plot of a simple imshow for
            test purpose. Only affects 'magstream'. Default: False.
        is_time (bool): toggle overplotting time tag
        time_offset (str or int or float): One of the following cases:
            str: 'tRelax'
            int: the frame of which the time is used
            float: time in Myr
        is_set_size (bool): toggle setting figure size to 6 to make texts
            bigger. Default: True.
        streamplot_kwargs (dict): More kwargs for plt.streamplot that is used
            when fields='magstream'. Default: None
        more_kwargs (dict): more field-specific kwargs. Default: None

    Returns:
        yt figure (all fields except 'magstream') or plt.figure ('magstream')

    Raises:
        ValueError: if `direcs` or `kind` is not one of the supported values.
    """

    from matplotlib import colors

    # Fresh dicts per call instead of shared mutable defaults.
    if zlims is None:
        zlims = {}
    if streamplot_kwargs is None:
        streamplot_kwargs = {}
    if more_kwargs is None:
        more_kwargs = {}

    r = ram
    ytfast.set_data_dir(os.path.join(r.ramses_dir, 'h5_data'))
    assert width is not None
    width = to_boxlen(width, r.ds1)
    if isinstance(time_offset, int):
        time_offset = r.get_time(time_offset)
    elif isinstance(time_offset, str) and time_offset == 'tRelax':
        time_offset = r.tRelax
    if l_max is None:
        l_max = ds.max_level
    # TODO: enable use_h5

    # Compute the line of sight (los), north and right vectors.
    if L is None:
        if axis in [0, 1, 2]:
            axis = ['x', 'y', 'z'][axis]
        is_onaxis = True
        los = axis
        axismap = {'x': [1, 0, 0], 'y': [0, 1, 0], 'z': [0, 0, 1]}
        los_v = axismap[axis]
        north = axismap[{'x': 'z', 'y': 'x', 'z': 'y'}[axis]]
    else:
        is_onaxis = False
        # Normalise a float copy: never modify the caller's array in place,
        # and accept plain lists.
        L = np.asarray(L, dtype=float)
        L = L / norm(L)
        tmpright = np.cross([0, 0, 1], L)
        tmpright = tmpright / norm(tmpright)
        if direcs == 'face':
            los = L
            north = -1 * tmpright
        elif direcs == 'edge':
            los = tmpright
            north = L
        else:
            raise ValueError(
                "direcs must be 'face' or 'edge', got {!r}".format(direcs))
        los_v = los
    right = np.cross(north, los_v)
    logger.debug('is_onaxis = {}'.format(is_onaxis))
    logger.info(f'doing {r.jobPath}, los={los}, north={north}, right={right}')

    # ------------------ doing the plot --------------------
    if 'density' in zlims:
        den_zlim = zlims['density']
    elif 'den' in zlims:
        den_zlim = zlims['den']
    else:
        den_zlim = [None, None]
    gas_field = fields
    is_log = True
    zlim = None
    cb_label = None
    cmap = 'viridis'
    if fields in ['T', 'temperature', 'T_vel_rela']:
        gas_field = 'temperature'
    if fields in ['pressure', 'p_mag']:
        cmap = 'gist_heat'
    if fields in ['beta', 'magbeta', 'AlfMach']:
        cmap = 'seismic'
        add_beta_corrected(ds)
        zlim = [-3, 3]
        is_log = False
        if fields == 'magbeta':
            gas_field = "logPlasmaBeta"

            def _logBeta(field, data):
                return np.log10(data['beta'])

            ds.add_field(("gas", gas_field), function=_logBeta)
            cb_label = "log (Plasma Beta)"
        if fields == 'AlfMach':
            gas_field = ("gas", "LogMachA")

            def _LogMachA(field, data):
                """ MachA = v_rms / v_A = gamma / 2 * Mach^2 * beta """
                gamma = 5 / 3
                return np.log10(gamma / 2 * data["mach"]**2 * data["beta"])

            ds.add_field(gas_field, function=_LogMachA)
            cb_label = "log $M_A$"
    if fields in ['xHII']:
        zlim = [1e-6, 1]
    if fields in ['mach', 'AlfMach']:
        add_mach(ds, center_vel)
        if gas_field == 'mach':
            gas_field = 'logmach'
            add_logmach(ds)
            cb_label = r'log ($\mathcal{M}_s$)'
            is_log = False
            cmap = 'seismic'
            zlim = [-2, 2]
    if fields in ['den', 'density', 'mag', 'vel', 'vel_rela']:
        gas_field = 'density'
    if fields in ['logden', ('gas', 'logden')]:

        def _logden(_field, data):
            mu = 1.4
            return data[('gas', 'density')] / (
                mu * yt.physical_constants.mass_hydrogen)

        gas_field = ('gas', 'logden')
        ds.add_field(gas_field,
                     function=_logden,
                     take_log=True,
                     sampling_type='cell',
                     units="cm**(-3)")
    if fields == 'magstream':
        # In-plane magnetic field components along `right` and `north`, using
        # the cell-centred average of the left/right face fields.

        def _rel_B_1(_field, data):
            Bx = (data["x-Bfield-left"] + data["x-Bfield-right"]) / 2.
            By = (data["y-Bfield-left"] + data["y-Bfield-right"]) / 2.
            Bz = (data["z-Bfield-left"] + data["z-Bfield-right"]) / 2.
            B = yt.YTArray([Bx, By, Bz])
            return np.tensordot(right, B, 1)

        def _rel_B_2(_field, data):
            Bx = (data["x-Bfield-left"] + data["x-Bfield-right"]) / 2.
            By = (data["y-Bfield-left"] + data["y-Bfield-right"]) / 2.
            Bz = (data["z-Bfield-left"] + data["z-Bfield-right"]) / 2.
            B = yt.YTArray([Bx, By, Bz])
            return np.tensordot(north, B, 1)

        ds.add_field(('gas', 'rel_B_1'), function=_rel_B_1)
        ds.add_field(('gas', 'rel_B_2'), function=_rel_B_2)
    # 'T_vel_rela' also needs rel_vel_1/2 for its velocity arrows.
    if fields in ['vel_rela', 'T_vel_rela', 'mach2d']:
        assert north is not None

        def _rel_vel_1(_field, data):
            rel_vel = yt.YTArray([
                data['velocity_x'] - yt.YTArray(center_vel[0], 'cm/s'),
                data['velocity_y'] - yt.YTArray(center_vel[1], 'cm/s'),
                data['velocity_z'] - yt.YTArray(center_vel[2], 'cm/s')
            ])
            return yt.YTArray(np.tensordot(right, rel_vel, 1), 'cm/s')

        def _rel_vel_2(_field, data):
            rel_vel = yt.YTArray([
                data['velocity_x'] - yt.YTArray(center_vel[0], 'cm/s'),
                data['velocity_y'] - yt.YTArray(center_vel[1], 'cm/s'),
                data['velocity_z'] - yt.YTArray(center_vel[2], 'cm/s')
            ])
            return yt.YTArray(np.tensordot(north, rel_vel, 1), 'cm/s')

        def _mach2d(_field, data):
            rel_speed_sq = data['rel_vel_1'].in_cgs().value ** 2 + \
                           data['rel_vel_2'].in_cgs().value ** 2  # cm/s
            T = data['temperature'].value
            # km/s, assuming gamma=5/3. mu is inside T, therefore this
            # is precise.
            cs = 0.11729 * np.sqrt(T)
            return 1e-5 * np.sqrt(rel_speed_sq) / cs  # dimensionless

        ds.add_field(('gas', 'rel_vel_1'), function=_rel_vel_1, units="cm/s")
        ds.add_field(('gas', 'rel_vel_2'), function=_rel_vel_2, units="cm/s")
        ds.add_field(
            ('gas', 'mach2d'),
            function=_mach2d,
        )
        if fields == 'mach2d':
            gas_field = 'logmach2d'

            def _logmach2d(_field, data):
                return np.log10(data["mach2d"])

            ds.add_field(
                ('gas', gas_field),
                function=_logmach2d,
            )
            cb_label = r'log ($\mathcal{M}_{s, 2D}$)'
            is_log = False
            cmap = "seismic"
            zlim = [-2, 2]

    if fields != 'magstream':
        if kind == 'slc':
            if is_onaxis:
                p = yt.SlicePlot(ds,
                                 los,
                                 gas_field,
                                 width=width,
                                 center=center)
            else:
                p = yt.OffAxisSlicePlot(ds,
                                        los,
                                        gas_field,
                                        width=width,
                                        center=center,
                                        north_vector=north)
        elif kind == 'prj':
            if is_onaxis:
                p = ytfast.ProjectionPlot(ds,
                                          los,
                                          gas_field,
                                          width=width,
                                          center=center,
                                          weight_field='density')
            else:
                print("Doing OffAxis projection plot...")
                p = yt.OffAxisProjectionPlot(ds,
                                             los,
                                             gas_field,
                                             width=width,
                                             center=center,
                                             weight_field='density',
                                             max_level=l_max,
                                             north_vector=north)
                print("Done")
        elif kind == 'colden':
            if is_onaxis:
                p = ytfast.ProjectionPlot(ds,
                                          los,
                                          gas_field,
                                          width=width,
                                          center=center,
                                          weight_field=None)
            else:
                p = yt.OffAxisProjectionPlot(ds,
                                             los,
                                             gas_field,
                                             width=width,
                                             center=center,
                                             weight_field=None,
                                             max_level=l_max,
                                             north_vector=north)
        else:
            raise ValueError(
                "kind must be 'slc', 'prj' or 'colden', got {!r}".format(kind))
        if is_set_size:
            p.set_figure_size(6)
        p.set_axes_unit('AU')
        if gas_field in zlims:
            if zlims[gas_field] is not None:
                p.set_zlim(gas_field, *zlims[gas_field])
        if is_time:
            p.annotate_timestamp(
                time_format='{time:.2f} {units}',
                time_offset=time_offset,
                text_args={
                    'fontsize': 8,
                    'color': 'w'
                },
                corner='upper_left',
            )
        # Overplot sink particles
        args = {}
        if 'sink_colors' in more_kwargs:
            args['colors'] = more_kwargs['sink_colors']
        if 'mass_lims' in more_kwargs:
            args['lims'] = more_kwargs['mass_lims']
        if 'colors' in more_kwargs:
            args['colors'] = more_kwargs['colors']
        r.overplot_sink_with_id(p,
                                out,
                                center,
                                width / 2,
                                is_id=is_id,
                                zorder='mass',
                                withedge=1,
                                **args)

        # set cmap, zlim, and other styles
        p.set_log(gas_field, is_log)
        p.set_cmap(gas_field, cmap)
        if zlim is not None:
            # set_zlim takes the lower and upper limits as separate arguments.
            p.set_zlim(gas_field, *zlim)
        p.set_colorbar_label(gas_field, cb_label)
        if gas_field == 'density':
            den_setup(p, den_zlim, time_offset=time_offset)
        if gas_field == 'temperature':
            T_setup(p, [3, 1e4], time_offset=time_offset)

        is_vel_plot = fields in ['vel', 'vel_rela', 'T_vel_rela']
        if is_vel_plot:
            # Velocity arrow scale; must be known before the arrows are drawn.
            ruler_v_kms = 4  # km/s
            width_au = width * r.boxlen * 2e5
            bar_length = 2e3 * width_au / 2e4
            coeff = bar_length  # length of ruler in axis units [AU]
            scale = 1e5 * ruler_v_kms / coeff  # number of cm/s per axis units
        if fields == 'vel':
            p.annotate_velocity(factor=30,
                                scale=scale,
                                scale_units='x',
                                plot_args={"color": "cyan"})
            # NOTE(review): labels the density colorbar as a Mach number;
            # confirm this is intended.
            p.set_colorbar_label(gas_field, r'log ($\mathcal{M}_s$)')
        if fields in ['vel_rela', 'T_vel_rela']:
            p.annotate_quiver('rel_vel_1',
                              'rel_vel_2',
                              factor=30,
                              scale=scale,
                              scale_units='x',
                              plot_args={"color": "cyan"})
        if is_vel_plot:
            # add a scale ruler
            p.annotate_scale(coeff=coeff,
                             unit='AU',
                             scale_text_format="{} km/s".format(ruler_v_kms))
        return p

    # ------------- 'magstream': B stream lines over a density slice ---------
    if not sketch:
        sl = ds.cutting(los_v, center, north_vector=north)
        hw = width / 2.
        bounds = [-hw, hw, -hw, hw]
        size = 2**10
        frb = yt.FixedResolutionBuffer(sl, bounds, [size, size])
        m_h = 1.6735575e-24  # hydrogen mass in g
        mu = 1.4
        den = frb['density'].value / (mu * m_h)
        unitB = utilities.get_unit_B(ds)
        logger.info(f"unitB.value = {unitB.value}")
        Bx = frb['rel_B_1'].value * unitB.value  # Gauss
        By = frb['rel_B_2'].value * unitB.value  # Gauss
        Bmag = np.sqrt(Bx**2 + By**2)
        logger.info(f"Bmag min = {Bmag.min()}, max = {Bmag.max()}")
        hw_au = hw * r.unit_l / AU
        grid_base = np.linspace(-hw_au, hw_au, size)
        bounds_au = [-hw_au, hw_au, -hw_au, hw_au]
    else:
        # Cheap random placeholder image for quick layout tests.
        den = np.random.random([3, 3])
        hw_au = 10
        den_zlim = [None, None]
        bounds_au = [-hw_au, hw_au, -hw_au, hw_au]

    if 'figax' in more_kwargs:
        fig, ax = more_kwargs['figax']
    else:
        fig, ax = plt.subplots()
    im = ax.imshow(np.log10(den),
                   cmap='inferno',
                   extent=bounds_au,
                   vmin=den_zlim[0],
                   vmax=den_zlim[1],
                   origin='lower')
    colornorm_den = colors.Normalize(vmin=den_zlim[0], vmax=den_zlim[1])
    ax.set(xlabel="Image x (AU)",
           ylabel="Image y (AU)",
           xlim=[-hw_au, hw_au],
           ylim=[-hw_au, hw_au])
    # 6-level discrete colormap for |B| (pyplot.get_cmap survives the removal
    # of matplotlib.cm.get_cmap in Matplotlib 3.9).
    cmap = plt.get_cmap("Greens", 6)
    Bvmin, Bvmax = -5, -2
    if "Blims_log" in more_kwargs:
        Bvmin, Bvmax = more_kwargs["Blims_log"]
    if "B" in zlims:
        Bvmin = np.log10(zlims['B'][0])
        Bvmax = np.log10(zlims['B'][1])
    colornorm_stream = colors.Normalize(vmin=Bvmin, vmax=Bvmax)
    # Stream line width etc. can be customised through streamplot_kwargs.
    if not sketch:
        ax.streamplot(
            grid_base,
            grid_base,
            Bx,
            By,
            color=np.log10(Bmag),
            density=1.2,
            cmap=cmap,
            norm=colornorm_stream,
            arrowsize=0.5,
            **streamplot_kwargs,
        )
    plot_cb = more_kwargs.get('plot_cb', True)
    plot_cb2 = more_kwargs.get('plot_cb2', True)
    if plot_cb:
        # Density colorbar.
        if 'cb_axis' in more_kwargs:
            cb_den = mpl.colorbar.ColorbarBase(
                more_kwargs['cb_axis'],
                orientation='vertical',
                cmap=plt.get_cmap("inferno"),
                norm=colornorm_den,
            )
        else:
            cb_den = fig.colorbar(im, ax=ax, pad=0)
        cb_den.set_label("log $n$ (cm$^{-3}$)")
    if plot_cb2:
        # Magnetic-field colorbar for the stream lines.
        if 'cb2_axis' in more_kwargs:
            ax2 = more_kwargs['cb2_axis']
        elif plot_cb:
            pos1 = ax.get_position()
            ax2 = fig.add_axes([pos1.x1 + .15, pos1.y0, 0.02, pos1.height])
        else:
            pos1 = ax.get_position()
            ax2 = fig.add_axes([pos1.x0, pos1.y0, 0.02, pos1.height])
        cb_B = mpl.colorbar.ColorbarBase(
            ax2,
            orientation='vertical',
            cmap=cmap,
            norm=colornorm_stream,
        )
        cb_B.set_label("log B (Gauss)")
    return fig
Esempio n. 23
0
def _viz_set_units_and_limits(plot, fields, units, extrema):
    """Set kpc axes plus each field's display unit and colour limits on `plot`."""
    plot.set_axes_unit("kpc")
    for i, field in enumerate(fields):
        plot.set_unit(field, units[i])
        plot.set_zlim(field, extrema[i][0], extrema[i][1])


def _viz_frb(plot, width, axis_unit, image_width):
    """Return a fixed-resolution buffer of `plot`'s data source.

    The buffer spans the first width value (in `axis_unit`) and uses the
    (nx, ny) pixel resolution held in `image_width`.
    """
    return plot.data_source.to_frb(yt.YTQuantity(width[0][0], axis_unit),
                                   [image_width[0], image_width[1]])


def visualisation(viz_type,
                  container,
                  raw_snapshot,
                  module=config.default_module,
                  gas=True,
                  stars=False,
                  dark=False,
                  gas_fields=["temperature", "density"],
                  gas_units=["K", "g/cm**3"],
                  gas_extrema=[[None, None], [None, None]],
                  dark_fields=[('deposit', 'dark_density')],
                  dark_units=["g/cm**3"],
                  dark_extrema=[[None, None]],
                  star_fields=[('deposit', 'stars_density')],
                  star_units=["g/cm**3"],
                  star_extrema=[[None, None]],
                  filter=None,
                  return_objects=False,
                  return_fields=False,
                  callbacks=[],
                  width=width_thing,
                  extra_width=None,
                  depth=depth_thing,
                  name="plot",
                  format=".png",
                  prefix="",
                  normal_vector=[1.0, 0.0, 0.0],
                  axis=[0, 1, 2],
                  plot_images=True,
                  weight_field=None,
                  image_width=1000,
                  extra_image_width=None,
                  suffix=None):
    """Make yt slice or projection plots (on- or off-axis) of a container.

    Parameters
    ----------
    viz_type : str
        One of "slice", "projection", "off_axis_slice" or
        "off_axis_projection". Any other value produces no plots.
    container : yt data object
        Object of interest; ``container.ds`` and ``container.center`` are
        used for every plot.
    raw_snapshot :
        The raw snapshot (whole box or a larger region); passed to
        ``gen_data_source`` when building projection data sources.
    module : str
        Only "yt" is implemented. Any other value prints a message and
        returns None.
    gas, stars, dark : bool
        Whether to include the gas / star / dark-matter fields.
    gas_fields, star_fields, dark_fields : list
        Fields to plot (strings, or (type, name) tuples for yt).
    gas_units, star_units, dark_units : list of str
        Display unit for each field, in the same order as the fields.
    gas_extrema, star_extrema, dark_extrema : list of [min, max]
        Colour limits per field. None becomes "min"/"max", and missing
        entries are padded with [None, None]. The passed lists are not
        modified.
    filter, callbacks, format, prefix :
        Currently unused.
    return_objects, return_fields : bool
        Select what is returned (see Returns).
    width : YTQuantity
        Image width; its unit is also used for the FRB width.
    extra_width : YTQuantity, optional
        Second image extent for non-square images, converted to the unit
        of ``width``.
    depth : YTQuantity
        Depth passed to ``gen_data_source`` for on-axis projections.
        Off-axis projections use the width as depth instead.
    name, suffix :
        Passed to ``plot_viz``; off-axis plots append "_axis_<n>" to name.
    normal_vector : sequence of 3 floats
        Vector from which the orthonormal basis for off-axis plots is built.
    axis : list of int
        Which of 0, 1, 2 to plot: yt axes for on-axis plots, indices into
        the orthonormal basis for off-axis plots.
    plot_images : bool
        If True, call ``plot_viz`` on every plot made.
    weight_field :
        Weight field for projections.
    image_width, extra_image_width : int
        Pixel resolution of the FRBs built for slice and on-axis
        projection plots.

    Returns
    -------
    (plots, frbs, fields) if both return flags are set; (plots, frbs, None)
    with return_objects only; (None, None, fields) with return_fields only;
    otherwise None. ``plots``/``frbs`` are dicts keyed "<n>_plot"/"<n>_frb".

    TODO use regex to rename dark_deposit with io_deposit if age does not exist
    """
    axis_unit = width.units

    # (value, unit) pairs for the two image extents. extra_width is
    # converted to width's unit instead of silently reusing its raw value.
    if extra_width is not None:
        second_width = extra_width.in_units(axis_unit).v
    else:
        second_width = width.v
    width = ((width.v, width.units), (second_width, width.units))

    if extra_image_width is not None:
        image_width = (image_width, extra_image_width)
    else:
        image_width = (image_width, image_width)

    depth = (depth.v, depth.units)

    fields = []
    units = []
    extrema = []
    for wanted, f_list, u_list, e_list in (
            (gas, gas_fields, gas_units, gas_extrema),
            (stars, star_fields, star_units, star_extrema),
            (dark, dark_fields, dark_units, dark_extrema)):
        if f_list and wanted:
            fields.extend(f_list)
            units.extend(u_list)
            # Copy each [min, max] pair so the caller's lists (and the
            # mutable defaults) are not modified by the substitution below.
            extrema.extend(list(pair) for pair in e_list)

    # Fields without explicit limits get [None, None].
    extrema.extend([None, None] for _ in range(len(fields) - len(extrema)))
    for pair in extrema:
        if pair[0] is None:
            pair[0] = "min"
        if pair[1] is None:
            pair[1] = "max"

    print(fields)
    plots = {}
    frb = {}

    basis_vectors = ortho_find(normal_vector)
    print("basis_vectors %s" % (basis_vectors,))
    # NOTE(review): unit choice kept from the original ("check if basis
    # vectors are in right units"); its purpose is not evident here.
    basis_vectors = YTArray(basis_vectors, "g*kpc**2/s")
    print(basis_vectors)

    if module != "yt":
        # pymses/pynbody support is not implemented in this routine.
        print("module not defined, please either use YT, pynbody or pymses")
        return

    wanted_axes = [ax for ax in (0, 1, 2) if ax in axis]

    if viz_type == "slice":
        print("on axis slice plot")
        for ax in wanted_axes:
            print("plotting axis %d" % ax)
            # A slice does not depend on the data source (see e.g.
            # https://bpaste.net/show/f08dda811ae0), so none is built.
            plot = yt.SlicePlot(container.ds,
                                ax,
                                fields,
                                center=container.center,
                                width=width)
            image = _viz_frb(plot, width, axis_unit, image_width)
            _viz_set_units_and_limits(plot, fields, units, extrema)
            plots["%d_plot" % ax] = plot
            frb["%d_frb" % ax] = image
            if plot_images:
                plot_viz(plot, name, suffix)

    elif viz_type == "projection":
        print("on axis projection plot")
        for ax in wanted_axes:
            print("plotting axis %d" % ax)
            # The depth runs along the viewing axis.
            data_source = gen_data_source(ax, container, raw_snapshot,
                                          width, depth, axis_unit)
            plot = yt.ProjectionPlot(container.ds,
                                     ax,
                                     fields,
                                     center=container.center,
                                     width=width,
                                     weight_field=weight_field,
                                     data_source=data_source)
            image = _viz_frb(plot, width, axis_unit, image_width)
            _viz_set_units_and_limits(plot, fields, units, extrema)
            plots["%d_plot" % ax] = plot
            frb["%d_frb" % ax] = image
            if plot_images:
                plot_viz(plot, name, suffix)

    elif viz_type == "off_axis_slice":
        print("off axis slice plot")
        for ax in wanted_axes:
            print("plotting axis %d" % ax)
            # basis_vectors[0] is the normal itself, so it cannot serve as
            # the north vector when it is also the viewing direction.
            north = basis_vectors[1] if ax == 0 else basis_vectors[0]
            plot = yt.OffAxisSlicePlot(container.ds,
                                       basis_vectors[ax],
                                       fields,
                                       center=container.center,
                                       width=width,
                                       north_vector=north)
            image = _viz_frb(plot, width, axis_unit, image_width)
            _viz_set_units_and_limits(plot, fields, units, extrema)
            plots["%d_plot" % ax] = plot
            frb["%d_frb" % ax] = image
            if plot_images:
                plot_viz(plot, name + "_axis_%d" % ax, suffix)

    elif viz_type == "off_axis_projection":
        print("off axis projection plot")
        for ax in wanted_axes:
            print("plotting axis %d" % ax)
            north = basis_vectors[1] if ax == 0 else basis_vectors[0]
            # There is no off-axis box, so the data source is built with the
            # width as its depth (the `depth` argument is ignored here).
            # NOTE(review): only data_source.ds is passed to the plot, not the
            # data source itself — confirm this is intended.
            data_source = gen_data_source(ax, container, raw_snapshot,
                                          width, width[0], axis_unit)
            plot = yt.OffAxisProjectionPlot(data_source.ds,
                                            basis_vectors[ax],
                                            fields,
                                            center=container.center,
                                            width=width,
                                            depth=width[0],
                                            weight_field=weight_field,
                                            north_vector=north)
            _viz_set_units_and_limits(plot, fields, units, extrema)
            if ax == 0:
                # Hard-coded marker in plot coordinates, kept from the
                # original but added once instead of once per field.
                plot.annotate_marker((4.0, 6.92820),
                                     coord_system='plot',
                                     plot_args={
                                         'color': 'black',
                                         's': 400
                                     })
            image = plot.frb
            plots["%d_plot" % ax] = plot
            frb["%d_frb" % ax] = image
            if plot_images:
                plot_viz(plot, name + "_axis_%d" % ax, suffix)

    if return_objects and return_fields:
        print("%s %s %s" % (plots, frb, fields))
        return plots, frb, fields

    if return_objects:
        return plots, frb, None

    if return_fields:
        return None, None, fields