Code example #1
File: analyze.py  Project: forest80/wdmerger
def vol_renders():
    """Volume render the simulations."""

    import os
    import yt
    import wdmerger  # project-local helpers (get_plotfiles, vol_render_density) used below
    yt.enable_parallelism()

    results_base = 'results/approximate/'
    plots_base = 'plots/'

    for mass_P in ['0.90']:
        for mass_S in ['0.60', '0.90']:
            for roche in ['0.90', '1.00']:
                for rot in ['0', '1']:
                    for hybrid in ['0', '1']:
                        for ncell in ['256', '512', '1024']:

                            mass_string = "_m_P_" + mass_P + "_m_S_" + mass_S

                            results_dir = 'mass_P_' + mass_P + '/mass_S_' + mass_S + '/roche' + roche + '/rot' + rot + '/hybrid' + hybrid + '/n' + ncell + '/'

                            output_dir = results_base + results_dir + '/output/'

                            if not os.path.exists(output_dir):
                                continue

                            plot_list = wdmerger.get_plotfiles(output_dir, prefix='smallplt')

                            plot_list = [output_dir + plot for plot in plot_list]

                            plot_dir = plots_base + results_dir

                            if not os.path.exists(plot_dir) and yt.is_root():
                                os.makedirs(plot_dir)

                            ts = yt.DatasetSeries(plot_list)

                            for ds in ts.piter():

                                pltname = ds.basename

                                outfile_name = plot_dir + 'approximate' + mass_string + '_roche_' + roche + '_rot_' + rot + '_hybrid_' + hybrid + '_n_' + ncell + '_' + pltname + '.png'

                                if os.path.isfile(outfile_name):
                                    continue

                                wdmerger.vol_render_density(outfile_name, ds)
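
Note: wdmerger.vol_render_density is a project-specific helper, not part of yt. For orientation, a minimal stand-in using yt's built-in volume-rendering API could look like the sketch below (the density field and the sigma_clip value are assumptions, not necessarily what wdmerger does):

import yt

def vol_render_density_sketch(outfile_name, ds):
    # build a default volume-rendering scene for the gas density field
    sc = yt.create_scene(ds, field=("gas", "density"))
    # sigma_clip trims the brightest pixels for a cleaner image
    sc.save(outfile_name, sigma_clip=4.0)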
Code example #2
'''
This code uses yt's clump-finding feature to detect regions of high density in a particular field. The bounds provided are the defaults shown in the documentation; consider tweaking them for better results.
'''

import yt
from yt.analysis_modules.level_sets.api import *
import numpy as np
import os

yt.enable_parallelism()

d_path = '/mnt/research/galaxies-REU/sims/isolated-galaxies/MW_1638kpcBox_800pcCGM_200pcDisk_lowres/'
d_list = []

lower = 3
upper = 1776
make_dir = False

fields = ['H_p0_number_density']

for i in range(lower, upper):
    name = "DD%04d" % i
    d_list.append(d_path + name + "/" + name)

# load only the outputs enumerated above
d_series = yt.DatasetSeries(d_list)

axes = ['x', 'z']
storage = {}
if make_dir:
    for axis in axes:
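
The excerpt ends before the clump finder is actually invoked. A minimal sketch of the missing step, following the yt clump-finding documentation (the bounds mirror the documented defaults; looping over d_series is an assumption):

for ds in d_series.piter():
    ad = ds.all_data()
    field = fields[0]
    master_clump = Clump(ad, field)
    # documented default bounds: decade-rounded min/max of the field
    c_min = 10 ** np.floor(np.log10(ad[field].d).min())
    c_max = 10 ** np.floor(np.log10(ad[field].d).max() + 1)
    find_clumps(master_clump, c_min, c_max, 2.0)  # refine by factors of 2
    leaf_clumps = get_lowest_clumps(master_clump)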
Code example #3
File: yt_analysis.py  Project: ecostriker/py_TIGRESS
                        default='/tigress/changgoo/',
                        help='base working directory')
    parser.add_argument('-d',
                        '--directory',
                        type=str,
                        default='',
                        help='working directory')
    parser.add_argument('-i', '--id', type=str, help='id of dataset')
    parser.add_argument('-p',
                        '--parallel',
                        action='store_true',
                        help='parallel')
    parser.add_argument('-ph',
                        '--phase',
                        action='store_true',
                        help='phase diagram analysis')
    parser.add_argument('-ro',
                        '--rotation',
                        type=float,
                        default=28.,
                        help='rotational velocity')
    parser.add_argument('-r',
                        '--range',
                        type=str,
                        default='',
                        help='time range, start:end:skip')
    args = parser.parse_args()

    if vars(args)['parallel']: yt.enable_parallelism()
    main(**vars(args))
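
The --range option encodes a time range as 'start:end:skip'. A hypothetical helper (not part of the original file) for turning that string into a slice might be:

def parse_range(range_str):
    """Convert 'start:end:skip' into a slice; empty fields become None."""
    if not range_str:
        return slice(None)
    parts = [int(p) if p else None for p in range_str.split(':')]
    return slice(*parts)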
Code example #4
            line = [line for line in lines if '#a =' in line]
            scale_in_file = round(float(line[0].split('=')[-1]), 4)
            if scale == scale_in_file:
                halo_file = this_file
                break
    return halo_file


if __name__ == "__main__":

    args = parse()

    import yt
    from yt.analysis_modules.sunrise_export import sunrise_octree_exporter
    from yt.analysis_modules.halo_finding.halo_objects import RockstarHaloList  
    yt.enable_parallelism()
    
    if yt.is_root():
        print '\nStarting ' + sys.argv[0]
        print 'Parsed arguments: '
        print args
        print

    # Get parsed values
    sim_dirs, snap_base = args['sim_dirs'], args['snap_base']
    print 'Analyzing ', sim_dirs

    out_dir = args['out_dir']
    modify_outdir = 0
    if  'sim_dir' in out_dir: 
        out_dir = out_dir.replace('sim_dir','')
Code example #5
from galaxy_analysis.analysis import time_average_phase_diagram as tapd
from galaxy_analysis.utilities import utilities
from galaxy_analysis.analysis import Galaxy
from galaxy_analysis.yt_fields import field_generators as fg
import yt
import numpy as np
import glob as glob
import os, sys

print(yt.enable_parallelism(), " ------------------------")


def plot_time_average_PD(wdir,
                         t_min,
                         t_max,
                         nbin=100,
                         plots=['nT', 'G_o', 'Q_o'],
                         outdir=None):

    x = utilities.select_data_by_time(wdir, tmin=0.0, tmax=np.inf)
    sim_files = {'files': x[0], 'Time': x[1]}
    #                    'DD_files' : np.sort(glob.glob(wdir + '/DD????/DD????'))}
    #iremove = len(sim_files['files'])
    #sim_files['DD_files'] = sim_files['DD_files'][-iremove:]
    # --- get data sets associated with output files (this is gross - I am sorry)
    sim_files['DD_files'] = np.array([
        x.split('_galaxy_data')[0] + '/' + x.split('_galaxy_data')[0][-6:]
        for x in sim_files['files']
    ])
    # --- make sure they exist
    sim_files['Time'] = np.array([
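
The excerpt is cut off mid-statement. Since galaxy_analysis is project-specific, a generic phase-diagram sketch in plain yt may help orient the reader (the dataset path and yt-3-style field names are assumptions):

import yt

ds = yt.load("DD0100/DD0100")  # hypothetical output
ad = ds.all_data()
# number density vs. temperature, colored by total cell mass per bin
pp = yt.PhasePlot(ad, "number_density", "temperature", ["cell_mass"],
                  weight_field=None)
pp.save("phase_nT.png")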
Code example #6
import sys
import os
import matplotlib
matplotlib.use('Agg')
matplotlib.rcParams['font.family'] = 'stixgeneral'
matplotlib.rcParams['figure.dpi'] = 150
import matplotlib.pyplot as plt
import yt
yt.mylog.setLevel("INFO")

from mpi4py import MPI
from mpl_toolkits.axes_grid1 import make_axes_locatable
from particles.particle_filters import *

comm = MPI.COMM_WORLD

yt.enable_parallelism(communicator=comm)

dir = './'

ls = {
    (0, 10): ['solid', 2],
    (10, 20): ['dotted', 2],
    (20, 30): ['dashed', 1],
    (30, 60): ['solid', 1],
    (60, 100): ['dotted', 1]
}

rmin, rmax = -10, 100

try:
    ind = int(sys.argv[1])
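
This example passes an explicit communicator, which lets yt run on a subset of MPI ranks. A minimal sketch of splitting COMM_WORLD into independent yt workgroups (the even/odd split is purely illustrative):

from mpi4py import MPI
import yt

comm = MPI.COMM_WORLD
# split even and odd ranks into two independent groups
subcomm = comm.Split(color=comm.rank % 2, key=comm.rank)
yt.enable_parallelism(communicator=subcomm)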
Code example #7
# You must run this job in parallel.
# Several MPI flags can be useful to make it run correctly.
# It requires at least 3 processors in order to run because of the way in which
# rockstar divides up the work.  Make sure you have mpi4py installed as per
# http://yt-project.org/docs/dev/analyzing/parallel_computation.html#setting-up-parallel-yt

# Usage: mpirun -np <num_procs> --mca btl ^openib python this_script.py

import yt
from yt.extensions.astro_analysis.halo_analysis.halo_catalog import HaloCatalog
from yt.data_objects.particle_filters import add_particle_filter
from yt.extensions.astro_analysis.halo_finding.rockstar.api import RockstarHaloFinder
yt.enable_parallelism()  # rockstar halofinding requires parallelism

# Create a dark matter particle filter
# This will be code dependent, but this function here is true for enzo


def DarkMatter(pfilter, data):
    filter = data[("all", "particle_type")] == 1  # DM = 1, Stars = 2
    return filter

add_particle_filter("dark_matter", function=DarkMatter, filtered_type='all', \
                    requires=["particle_type"])

# First, we make sure that this script is being run using mpirun with
# at least 3 processors as indicated in the comments above.
assert (yt.communication_system.communicators[-1].size >= 3)

# Load the dataset and apply dark matter filter
fn = "Enzo_64/DD0043/data0043"
Code example #8
def createProfWithTimeForX(Param_Dict, worker):
    """Make a profile plot having the time as the x-Axis.
    Parameters:
        Param_Dict: For the fields and DataSets to be plotted.
    Returns:
        arr: list containing two YTArrays having the time as the first and the
             y-field as second entry
    """
    GUILogger.info("This may take some...time...")
    ts = Param_Dict["DataSeries"]
    timeMin = Param_Dict["XMin"]  # they should already be converted to xunit.
    timeMax = Param_Dict["XMax"]
    times = []
    datasets = []
    emitStatus(worker, "Gathering time data")
    for ds in ts:
        # use the times we have already calculated for each dataset
        time = Param_Dict["DataSetDict"][str(ds) + "Time"].to_value(Param_Dict["XUnit"])
        timecompare = float("{:.3g}".format(time))
        if timeMin <= timecompare <= timeMax:
            times.append(time)
            datasets.append(str(ds))
    GUILogger.log(29, "Iterating over the whole series from {:.3g} to {:.3g} {}..."
          .format(timeMin, timeMax, Param_Dict["XUnit"]))
    calcQuan = getCalcQuanName(Param_Dict)
    field = Param_Dict["YAxis"]
    calcQuanString = getCalcQuanString(Param_Dict)
    storage = {}
    i = 0
    length = len(times)
    if Param_Dict["YAxis"] in Param_Dict["NewDerFieldDict"].keys():
        for ds in ts:
            if str(ds) in datasets:
                try:
                    yResult = Param_Dict["DataSetDict"][str(ds) + field + calcQuan]
                except KeyError:
                    ad = ds.all_data()
                    yResult = eval(calcQuanString)
                    # save the plotpoints for later use
                    value = yt.YTQuantity(yResult, Param_Dict["YUnit"]).to_value(Param_Dict["FieldUnits"][field])
                    Param_Dict["DataSetDict"][str(ds) + field + calcQuan] = value
                storage[str(i)] = yResult  # this is kind of clunky, but this way we don't run into problems later
                i += 1
                progString = f"{i}/{length} data points calculated"
                emitStatus(worker, progString)
                if i % ceil(length/10) == 0:  # maximum of 10 updates
                    GUILogger.info(f"Progress: {progString}.")
    else:  # We want to use parallel iteration if possible
        yt.enable_parallelism(suppress_logging=True)
        newTS = yt.load(Param_Dict["Directory"] + "/" + Param_Dict["Seriesname"])
        for store, ds in newTS.piter(storage=storage):
            try:
                yResult = Param_Dict["DataSetDict"][str(ds) + field + calcQuan]
            except KeyError:
                ad = ds.all_data()  # This is needed for the following command
                yResult = eval(calcQuanString)
                # save the plotpoints for later use
                value = yt.YTQuantity(yResult, Param_Dict["YUnit"]).to_value(Param_Dict["FieldUnits"][field])
                Param_Dict["DataSetDict"][str(ds) + field + calcQuan] = value
            store.result = yResult
            i += 1
            progString = f"{i}/{length} data points calculated"
            emitStatus(worker, progString)
            if i % ceil(length/10) == 0:  # maximum of 10 updates
                GUILogger.info(f"Progress: {progString}.")
    labels = [field]
    # Convert the storage dictionary values to an array, so they can be
    # easily plotted
    arr_x = yt.YTArray(times, Param_Dict["XUnit"])
    arr_y = yt.YTArray(list(storage.values()), Param_Dict["YUnit"])
    arr = [arr_x, arr_y]
#    print(arr)
    return arr, labels
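
The parallel branch above relies on yt's storage pattern: piter(storage=...) hands each rank a proxy object, and whatever is assigned to store.result (keyed by store.result_id) is gathered into the dictionary on every rank once the loop finishes. A self-contained sketch of just that pattern (the file pattern is an assumption):

import yt
yt.enable_parallelism()

ts = yt.load("DD????/DD????")  # hypothetical series
storage = {}
for store, ds in ts.piter(storage=storage):
    store.result_id = str(ds)
    store.result = float(ds.current_time.in_units("Myr"))

if yt.is_root():
    for name, t in sorted(storage.items()):
        print(name, t)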
Code example #9
def createMultipleProfiles(Param_Dict, worker):
    """Make a profile plot for each of the requested times and return them so
    they can be plotted.
    Parameters:
        Param_Dict: For the fields and DataSets to be plotted.
    Returns:
        arr: list containing YTArrays having the x-field as the first and the
             y-fields as second and following entries
        labels: the labels for the rows.
    """
    GUILogger.info("This may take some time.")
    onlyEvery = Param_Dict["ProfOfEvery"]
    if onlyEvery == 1:
        numString = ""
    else:
        suf = lambda n: "%d%s "%(n,{1:"st",2:"nd",3:"rd"}.get(n if n<20 else n%10,"th"))
        numString = suf(onlyEvery)
    GUILogger.log(29, "Creating a profile for every {}dataset of the series...".format(numString))
    # the user can input to only plot every nth file:
    yt.enable_parallelism(suppress_logging=True)
    storage = {}
    labels = []
    i = 0
    ts = Param_Dict["DataSeries"]
    length = ceil(len(ts)/onlyEvery)
    if Param_Dict["YAxis"] in Param_Dict["NewDerFieldDict"].keys():
        for ds in ts:
            if i % onlyEvery == 0:
                # Create a data container to hold the whole dataset.
                ad = ds.all_data()
                # Create a 1d profile of xfield vs. yfield:
                prof = yt.create_profile(ad, Param_Dict["XAxis"],
                                         fields=[Param_Dict["YAxis"]],
                                         weight_field=Param_Dict["WeightField"])
                # Add labels
                time = Param_Dict["DataSetDict"][str(ds) + "Time"]
                label = "{} at {:.3g} ".format(Param_Dict["YAxis"], time.value)
                label += str(time.units)
                labels.append(label)
                storage[str(i)] = prof[Param_Dict["YAxis"]]
                progString = f"{int(i/onlyEvery+1)}/{length} profiles done"
                emitStatus(worker, progString)
                if i % ceil(length/10) == 0:  # maximum of 10 updates
                    GUILogger.info(f"Progress: {progString}.")
            i += 1
    else:  # We want to use parallel iteration if possible
        ts = yt.load(Param_Dict["Directory"] + "/" + Param_Dict["Seriesname"])
        for store, ds in ts.piter(storage=storage):
            if i % onlyEvery == 0:
                ad = ds.all_data()
                prof = yt.create_profile(ad, Param_Dict["XAxis"],
                                         fields=[Param_Dict["YAxis"]],
                                         weight_field=Param_Dict["WeightField"])
                # Add labels
                time = Param_Dict["DataSetDict"][str(ds) + "Time"]
                label = "{} at {:.3g} ".format(Param_Dict["YAxis"], time.value)
                label += str(time.units)
                labels.append(label)
                store.result = prof[Param_Dict["YAxis"]]
                progString = f"{int(i/onlyEvery+1)}/{length} profiles done"
                emitStatus(worker, progString)
                GUILogger.info(f"Progress: {progString}.")
            i += 1
    # Convert the storage dictionary values to an array with x-axis as first
    # row and then the results of y-field as following rows.
    arr_x = prof.x
    arr = [arr_x]
    for arr_y in storage.values():
        if arr_y is not None:
            arr.append(arr_y)
    return arr, labels
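
For reference, the core call in both branches is yt.create_profile; a minimal standalone usage (dataset path and yt-3-style field names are assumptions) looks like:

import yt

ds = yt.load("DD0100/DD0100")  # hypothetical output
ad = ds.all_data()
prof = yt.create_profile(ad, "density", fields=["temperature"],
                         weight_field="cell_mass")
# bin centers and binned values come back as unit-aware arrays
print(prof.x, prof["temperature"])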
Code example #10
#!/usr/bin/env python
import sys
import os
import yt
yt.mylog.setLevel("INFO")
import yt_synchrotron_emissivity as sync

yt.enable_parallelism(suppress_logging=True)

dir = './'
ptype = 'lobe'
proj_axis = 'x'
#proj_axis = [1,0,2]
extend_cells = 8
nus = [(nu, 'MHz') for nu in [100, 1400, 8000]]

try:
    ind = int(sys.argv[1])
    #ts = yt.DatasetSeries(os.path.join(dir,'*_hdf5_plt_cnt_%04d' % ind), parallel=1, setup_function=sync.setup_part_file)

    # This works for 1 file at a time. Thus parallel=1
    ts = yt.DatasetSeries(os.path.join(dir, 'data/*_hdf5_plt_cnt_%04d' % ind),
                          parallel=1)
except IndexError:
    ts = yt.DatasetSeries(os.path.join(dir, 'data/*_hdf5_plt_cnt_????'),
                          parallel=1)

for ds in ts.piter():
    if '0000' in ds.basename: continue
    for nu in nus:
        # The two projection axes cannot be completed at the same time
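
A note on the parallel keyword used above: parallel=1 keeps all ranks in a single group working on one dataset at a time, while parallel=N splits the ranks into N groups that work on different datasets simultaneously. A minimal sketch (file pattern reused from above):

import yt
yt.enable_parallelism()

# two processor groups, each working through its share of the datasets
ts = yt.DatasetSeries("data/*_hdf5_plt_cnt_????", parallel=2)
for ds in ts.piter():
    print(ds.basename)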
Code example #11
anlsT = tst.anlzT(ts, anlzD, outdir="/my/plot/path", plotT=plotT, trypkl=True)

######### END OF EXAMPLE SCRIPT #########
"""

import os
import sys
import numpy as np
if (sys.version_info > (3, 0)):
    # Python 3 code in this block
    import pickle
else:
    # Python 2 code in this block
    import cPickle as pickle
import yt
yt.enable_parallelism() # Tap into yt's mpi4py parallelism (e.g. now can call via mpirun -np 10 python <blah>.py)
yt.funcs.mylog.setLevel(30) # This sets the output notification threshold to 30, WARNING. Default is 20, INFO.
from mpi4py import MPI
import numbers
import inspect

comm = MPI.COMM_WORLD
rank = comm.rank

# TODO: Incorporate this into the new analysis schema
def tsinf(ts):
	""" Walks through the time series and returns a dict with some basic info like time steps """
	numfiles = len(ts)

	ts_inf = {}
	ts_inf['times_ns'] = np.zeros(numfiles)
Code example #12
# You must run this job in parallel.  
# Several MPI flags can be useful to make it run correctly.
# It requires at least 3 processors in order to run because of the way in which 
# rockstar divides up the work.  Make sure you have mpi4py installed as per 
# http://yt-project.org/docs/dev/analyzing/parallel_computation.html#setting-up-parallel-yt
    
# Usage: mpirun -np <num_procs> --mca btl ^openib python this_script.py

import yt
from yt.analysis_modules.halo_analysis.halo_catalog import HaloCatalog
from yt.data_objects.particle_filters import add_particle_filter
from yt.analysis_modules.halo_finding.rockstar.api import RockstarHaloFinder
yt.enable_parallelism() # rockstar halofinding requires parallelism

# Create a dark matter particle filter
# This will be code dependent, but this function here is true for enzo

def DarkMatter(pfilter, data):
    filter = data[("all", "particle_type")] == 1 # DM = 1, Stars = 2
    return filter

add_particle_filter("dark_matter", function=DarkMatter, filtered_type='all', \
                    requires=["particle_type"])

# First, we make sure that this script is being run using mpirun with
# at least 3 processors as indicated in the comments above.
assert(yt.communication_system.communicators[-1].size >= 3)

# Load the dataset and apply dark matter filter
fn = "Enzo_64/DD0043/data0043"
ds = yt.load(fn)
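
The excerpt stops right after loading. Following the yt documentation for running rockstar through a HaloCatalog, the continuation would typically look something like this sketch (the finder kwargs are assumptions, not taken from the original script):

ds.add_particle_filter("dark_matter")

hc = HaloCatalog(data_ds=ds, finder_method="rockstar",
                 finder_kwargs={"num_readers": 1, "num_writers": 1,
                                "particle_type": "dark_matter"})
hc.create()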
Code example #13
def run_sightlines(outputfilename,save_after_num,parallel,\
                   simulation_dest = None,run = 'default',throwerrors = 'warn'):
    if run not in ['default', 'test']:
        print('unknown option for "run" %s.'%run+\
              ' Please restart with "run = default"'+\
              ' or "run = test".')
    #do not print out anything from yt (it prints plenty)
    yt.funcs.mylog.setLevel(50)
    if parallel:
        yt.enable_parallelism()
    readvalsoutput = simulation_quasar_sphere.read_values(outputfilename)
    #by creating a QuasarSphere, it knows all its metadata and other
    #information from simparams and scanparams (first lines of file at
    #'filename')
    q = simulation_quasar_sphere.SimQuasarSphere(
        start_up_info_packet=readvalsoutput)
    if q.simparams[6] is None:
        if simulation_dest:
            q.simparams[6] = simulation_dest
        else:
            raise NoSimulationError('Simulation file location unknown, '+\
                                    'run with "simulation_dest" to process')
    else:
        simulation_dest = q.simparams[6]
    ds,fields_to_keep = code_specific_setup.load_and_setup(simulation_dest,\
                                                           q.fullname,ions = q.ions,\
                                                           redshift = q.redshift)
    set_up_general(ds, q.code, q.center, q.bulk_velocity, q.Rvir)
    code_specific_setup.check_redshift(ds, outputfilename=outputfilename)
    num_bin_vars = q.gasbins.get_length()
    #Can start at a position further than 0 if reached
    starting_point = q.length_reached
    bins = np.append(np.arange(starting_point, q.length, save_after_num),
                     q.length)
    #first for loop is non-parallel. If 32 processors available, it will break up
    #into bins of size 32 at a time for example. At end, saves data from all 32.
    #this is (~12 bins) in usual circumstances
    for i in range(0, len(bins) - 1):
        current_info = q.info[bins[i]:bins[i + 1]]
        if yt.is_root():
            tprint("%s-%s /%s" % (bins[i], bins[i + 1], len(q.info)))
        my_storage = {}
        #2nd for loop is parallel. Each vector goes to a different processor, and creates
        #a separate trident sightline (~32 sightlines [in a bin]).
        #the longest processing step is ray = trident.make_simple_ray, and it's
        #the only step which actually takes any time (below two for loops go by fast)
        for sto, in_vec in yt.parallel_objects(current_info,
                                               storage=my_storage):
            vector = np.copy(in_vec)
            index = vector[0]
            toprint = "line %s, (r = %.0f) densities " % (str(
                int(index)), vector[3])
            tprint("<line %d, starting process> " % index)
            ident = str(index)
            start = ds.arr(tuple(vector[5:8]), 'unitary')
            end = ds.arr(tuple(vector[8:11]), 'unitary')
            try:
                ray = trident.make_simple_ray(ds,
                                              start_position=start,
                                              end_position=end,
                                              data_filename="ray" + ident +
                                              ".h5",
                                              fields=fields_to_keep,
                                              ftype='gas')
            except KeyboardInterrupt:
                print('skipping sightline %s ...' % index)
                print('Interrupt again within 5 seconds to *actually* end')
                time.sleep(5)
                continue
            except ValueError as e:
                throw_errors_if_allowed(
                    e, throwerrors,
                    'ray has shape %s - %s, but size 0' % (start, end))
                continue  # no ray was made, so skip the processing below
            except Exception as e:
                throw_errors_if_allowed(e, throwerrors,
                                        'problem with making ray')
                continue
            trident.add_ion_fields(ray, q.ions)
            field_data = ray.all_data()
            dl = field_data['gas', 'dl']
            #3rd for loop is for processing each piece of info about each ion
            #including how much that ion is in each bin according to gasbinning
            #here just process topline data (column densities and ion fractions)
            #(~10 ions)
            for j in range(len(q.ions)):
                ion = q.ions[j]
                ionfield = field_data["gas", ion_to_field_name(ion)]
                cdens = np.sum((ionfield * dl).in_units('cm**-2')).value
                vector[11 + j * (num_bin_vars + 2)] = cdens
                total_nucleus = np.sum(ionfield[ionfield>0]/\
                                           field_data["gas",ion_to_field_name(ion,'ion_fraction')][ionfield>0]\
                                            * dl[ionfield>0])
                vector[11 + j * (num_bin_vars + 2) + 1] = cdens / total_nucleus
                #4th for loop is processing each gasbin for the current ion
                #(~20 bins)
                for k in range(num_bin_vars):
                    try:
                        variable_name, edges, units = q.gasbins.get_field_binedges_for_num(
                            k, ion)
                        if variable_name is None:
                            vector[11 + j * (num_bin_vars + 2) + k +
                                   2] = np.nan
                        elif variable_name in ray.derived_field_list:
                            if units:
                                data = field_data[variable_name].in_units(
                                    units)
                            else:
                                data = field_data[variable_name]
                            abovelowerbound = data > edges[0]
                            belowupperbound = data < edges[1]
                            withinbounds = np.logical_and(
                                abovelowerbound, belowupperbound)
                            coldens_in_line = (ionfield[withinbounds]) * (
                                dl[withinbounds])
                            coldens_in_bin = np.sum(coldens_in_line)
                            vector[11 + j * (num_bin_vars + 2) + k +
                                   2] = coldens_in_bin / cdens
                        else:
                            print(
                                str(variable_name) +
                                " not in ray.derived_field_list")
                    except Exception as e:
                        throw_errors_if_allowed(
                            e, throwerrors,
                            'Could not bin into %s with edges %s' %
                            (variable_name, edges))
                toprint += "%s:%e " % (ion, cdens)
            #gets some more information from the general sightline.
            #metallicity, average density (over the whole sightline)
            #mass-weighted temperature
            try:
                if ('gas', "H_nuclei_density") in ray.derived_field_list:
                    Z = np.sum(field_data['gas',"metal_density"]*dl)/ \
                        np.sum(field_data['gas',"H_nuclei_density"]*mh*dl)
                else:
                    Z = np.sum(field_data['gas',"metal_density"]*dl)/ \
                        np.sum(field_data['gas',"number_density"]*mh*dl)
                vector[-1] = Z
            except Exception as e:
                throw_errors_if_allowed(e, throwerrors,
                                        'problem with average metallicity')
            try:
                n = np.sum(
                    field_data['gas', 'number_density'] * dl) / np.sum(dl)
                vector[-2] = n
            except Exception as e:
                throw_errors_if_allowed(e, throwerrors,
                                        'problem with average density')
            try:
                T = np.average(field_data['gas','temperature'],\
                               weights=field_data['gas','density']*dl)
                vector[-3] = T
            except Exception as e:
                throw_errors_if_allowed(e, throwerrors,
                                        'problem with average temperature')
            try:
                os.remove("ray" + ident + ".h5")
            except:
                pass
            tprint(toprint)
            #'vector' now contains real data, not just '-1's
            sto.result_id = index
            sto.result = vector
        #save all parallel sightlines after they finish (every 32 lines are saved at once)
        if yt.is_root():
            keys = my_storage.keys()
            for key in keys:
                q.info[int(key)] = my_storage[key]
            q.scanparams[6] += (bins[i + 1] - bins[i])
            q.length_reached = q.scanparams[6]
            if run != 'test':
                outputfilename = q.save_values(oldfilename=outputfilename)
                tprint("file saved to " + outputfilename + ".")
Code example #14
    This script additionally contains a couple of simple example usages for 
    both a typical phase diagram (n,T) and a spatial phase diagram (e.g.
    a gas profile in a galaxy disk).
"""

__author__ = "Andrew Emerick"

from galaxy_analysis import Galaxy
import numpy as np
import yt
import glob
import os

# attempt to enable parallelism
# result is True if this works
parallel_on = yt.enable_parallelism()


def _mag_z(field, data):
    return np.abs(data['cylindrical_z'].to('pc'))


yt.add_field(("gas", "mag_z"), function=_mag_z, units="pc")


def _create_region(ds, region_type, prop):
    """
    Helper function to construct a region in yt given some arguments. This
    is used to ensure consistent regions across data set. Mainly parses user
    input into appropriate region type and args/kwargs with some
    assumptions on defaults if certain fields are not provided.