Example no. 1
    def __call__(self, args):
        if args.list:
            self.get_list()
            return
        if not args.filename:
            raise RuntimeError("You need to provide a filename. See --help "
                               "for details or use --list to get available "
                               "datasets.")
        elif not args.location:
            raise RuntimeError(
                "You need to specify download location. See --help for details."
            )
        data_url = f"http://yt-project.org/data/{args.filename}"
        if args.location in ["test_data_dir", "supp_data_dir"]:
            data_dir = ytcfg.get("yt", args.location)
            if data_dir == "/does/not/exist":
                raise RuntimeError(f"'{args.location}' is not configured!")
        else:
            data_dir = args.location
        if not os.path.exists(data_dir):
            print(f"The directory '{data_dir}' does not exist. Creating...")
            ensure_dir(data_dir)
        data_file = os.path.join(data_dir, args.filename)
        if os.path.exists(data_file) and not args.overwrite:
            raise OSError(f"File '{data_file}' exists and overwrite=False!")
        print(f"Attempting to download file: {args.filename}")
        fn = download_file(data_url, data_file)

        if not os.path.exists(fn):
            raise OSError(f"The file '{args.filename}' did not download!!")
        print(f"File: {args.filename} downloaded successfully to {data_file}")
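
Every example in this collection leans on yt's `ensure_dir` helper (imported from `yt.funcs` in the later examples). As a point of reference, here is a minimal sketch of equivalent behavior, written against plain `os.makedirs` and not taken from yt itself: create the directory if it is missing and hand back the path.

import os

def ensure_dir_sketch(path):
    # Create the directory (and any missing parents) if it does not exist yet.
    os.makedirs(path, exist_ok=True)
    # Return the path so the call can be chained, as in
    # self.output_dir = ensure_dir(output_dir) in the examples below.
    return path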
Example no. 2
def get_my_tree_rare():
    """
    Halo with the highest relative growth.
    """

    ### Get the tree with the highest relative growth.
    ds = yt.load("halo_catalogs_nosub/RD0041/RD0041.0.h5")
    a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")

    # halo with highest relative growth
    gsort = ds.r["relative_growth_max"].argsort()[::-1]
    hid = ds.r["particle_identifier"][gsort][0].d

    # # halo with second highest relative growth
    # gsort = ds.r["relative_growth_max"].argsort()[::-1]
    # hid = ds.r["particle_identifier"][gsort][1].d

    t = a[a["Orig_halo_ID"] == hid][0]
    uid = t["uid"]

    data_dir = "halo_%d" % uid
    ensure_dir(data_dir)

    fn = t.save_tree(filename=os.path.join(data_dir, "tree_%d" % uid))
    return fn
Example no. 3
    def add_callback(self, callback, *args, **kwargs):
        r"""
        Add a callback to the halo catalog action list.

        A callback is a function that accepts and operates on a Halo object and
        does not return anything.  Callbacks must exist within the callback_registry.
        Give additional args and kwargs to be passed to the callback here.

        Parameters
        ----------
        callback : string
            The name of the callback.

        Examples
        --------

        >>> # Here, a callback is defined and added to the registry.
        >>> def _say_something(halo, message):
        ...     my_id = halo.quantities['particle_identifier']
        ...     print("Halo %d: here is a message - %s." % (my_id, message))
        >>> add_callback("hello_world", _say_something)

        >>> # Now this callback is accessible to the HaloCatalog object
        >>> hc.add_callback("hello_world", "this is my message")

        """
        callback = callback_registry.find(callback, *args, **kwargs)
        if "output_dir" in kwargs:
            ensure_dir(os.path.join(self.output_dir, kwargs["output_dir"]))
        self.actions.append(("callback", callback))
Example no. 4
    def __init__(self, parameter_filename, simulation_type,
                 near_redshift, far_redshift,
                 observer_redshift=0.0,
                 use_minimum_datasets=True, deltaz_min=0.0,
                 minimum_coherent_box_fraction=0.0,
                 time_data=True, redshift_data=True,
                 find_outputs=False, set_parameters=None,
                 output_dir="LC", output_prefix="LightCone"):

        self.near_redshift = near_redshift
        self.far_redshift = far_redshift
        self.observer_redshift = observer_redshift
        self.use_minimum_datasets = use_minimum_datasets
        self.deltaz_min = deltaz_min
        self.minimum_coherent_box_fraction = minimum_coherent_box_fraction
        if set_parameters is None:
            self.set_parameters = {}
        else:
            self.set_parameters = set_parameters
        self.output_dir = output_dir
        self.output_prefix = output_prefix

        # Create output directory.
        ensure_dir(self.output_dir)

        # Calculate light cone solution.
        CosmologySplice.__init__(self, parameter_filename, simulation_type,
                                 find_outputs=find_outputs)
        self.light_cone_solution = \
          self.create_cosmology_splice(self.near_redshift, self.far_redshift,
                                       minimal=self.use_minimum_datasets,
                                       deltaz_min=self.deltaz_min,
                                       time_data=time_data,
                                       redshift_data=redshift_data)
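
For orientation, a construction matching the signature above might look like the following. The parameter file, simulation type, and import path are assumptions for illustration (the module location has moved between yt versions) and are not taken from the snippet itself.

# Illustrative only: the import path and parameter file are assumptions.
from yt.analysis_modules.cosmological_observation.api import LightCone

lc = LightCone("32Mpc_32.enzo", "Enzo",
               near_redshift=0.0, far_redshift=0.1,
               observer_redshift=0.0,
               output_dir="LC", output_prefix="LightCone")
# The constructor creates output_dir via ensure_dir and builds the
# cosmology splice between far_redshift and near_redshift.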
Example no. 6
    def save(self, name=None, suffix=None, mpl_kwargs=None):
        """saves the plot to disk.

        Parameters
        ----------
        name : string
           The base of the filename.  If name is a directory or if name is not
           set, the filename of the dataset is used.
        suffix : string
           Specify the image type by its suffix. If not specified, the output
           type will be inferred from the filename. Defaults to PNG.
        mpl_kwargs : dict
           A dict of keyword arguments to be passed to matplotlib.

        >>> slc.save(mpl_kwargs={'bbox_inches':'tight'})

        """
        names = []
        if mpl_kwargs is None:
            mpl_kwargs = {}
        if name is None:
            name = str(self.ds)
        name = os.path.expanduser(name)
        if name[-1] == os.sep and not os.path.isdir(name):
            ensure_dir(name)
        if os.path.isdir(name) and name != str(self.ds):
            name = name + (os.sep if name[-1] != os.sep else '') + str(self.ds)
        if suffix is None:
            suffix = get_image_suffix(name)
            if suffix != '':
                for k, v in iteritems(self.plots):
                    names.append(v.save(name, mpl_kwargs))
                return names
        if hasattr(self.data_source, 'axis'):
            axis = self.ds.coordinates.axis_name.get(
                self.data_source.axis, '')
        else:
            axis = None
        weight = None
        type = self._plot_type
        if type in ['Projection', 'OffAxisProjection']:
            weight = self.data_source.weight_field
            if weight is not None:
                weight = weight[1].replace(' ', '_')
        if 'Cutting' in self.data_source.__class__.__name__:
            type = 'OffAxisSlice'
        for k, v in iteritems(self.plots):
            if isinstance(k, tuple):
                k = k[1]
            if axis:
                n = "%s_%s_%s_%s" % (name, type, axis, k.replace(' ', '_'))
            else:
                # for cutting planes
                n = "%s_%s_%s" % (name, type, k.replace(' ', '_'))
            if weight:
                n += "_%s" % (weight)
            if suffix != '':
                n = ".".join([n,suffix])
            names.append(v.save(n, mpl_kwargs))
        return names
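
The `ensure_dir` branch above is what makes a trailing-separator `name` behave as a directory. A short hypothetical usage, assuming `slc` is an existing yt plot object such as a SlicePlot:

# Hypothetical usage; `slc` is assumed to be an existing SlicePlot.
# A trailing separator means "directory": it is created if missing and the
# dataset name plus plot type, axis, and field are appended to build the name.
slc.save("frames/")
# Giving an explicit suffix in the filename selects the output format.
slc.save("frames/density_slice.pdf")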
Example no. 7
def _determine_output_filename(path, suffix):
    if path.endswith(suffix):
        dirname = os.path.dirname(path)
        filename = path[:-len(suffix)]
    else:
        dirname = path
        filename = os.path.join(dirname, os.path.basename(path))
    ensure_dir(dirname)
    return filename
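
To make the two branches above concrete, a hypothetical pair of calls (the paths are illustrative, not from the snippet):

# Path already ends with the suffix: strip it and ensure the parent directory.
fn = _determine_output_filename("halo_catalogs/catalog.0.h5", ".0.h5")
# fn == "halo_catalogs/catalog"; "halo_catalogs" is created if missing.

# Path does not end with the suffix: treat it as a directory and reuse its
# basename for the file, ensuring the directory itself exists.
fn = _determine_output_filename("halo_catalogs/catalog", ".0.h5")
# fn == "halo_catalogs/catalog/catalog"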
Example no. 8
def get_my_tree():
    """
    Halo with the most black holes from the normal region.
    """

    ### Get the tree with the most black holes.
    ds = yt.load("halo_catalogs/RD0077/RD0077.0.h5")
    a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")

    i = ds.r["n_black_holes"].argmax()
    hid = ds.r["particle_identifier"][i].d

    t = a[a["Orig_halo_ID"] == hid][0]
    uid = t["uid"]

    data_dir = "halo_%d" % uid
    ensure_dir(data_dir)

    fn = t.save_tree(filename=os.path.join(data_dir, "tree_%d" % uid))
    return fn
Example no. 9
    def __init__(
        self,
        halos_ds=None,
        data_ds=None,
        data_source=None,
        halo_field_type="all",
        finder_method=None,
        finder_kwargs=None,
        output_dir=None,
    ):

        super().__init__()

        self.halos_ds = halos_ds
        self.data_ds = data_ds
        self.halo_field_type = halo_field_type

        if halos_ds is None:
            if data_ds is None:
                raise RuntimeError(
                    "Must specify a halos_ds, data_ds, or both.")
            if finder_method is None:
                raise RuntimeError(
                    "Must specify a halos_ds or a finder_method.")

        if data_source is None and halos_ds is not None:
            data_source = halos_ds.all_data()
        self.data_source = data_source

        if output_dir is None:
            if finder_method == "rockstar":
                output_dir = finder_kwargs.get("outbase", "rockstar_halos")
            else:
                output_dir = "halo_catalogs"

        self.output_basedir = ensure_dir(output_dir)
        self.pipeline = AnalysisPipeline(output_dir=self.output_dir)
        self.quantities = self.pipeline.quantities

        self.finder_method_name = finder_method
        if finder_kwargs is None:
            finder_kwargs = {}
        if finder_method is not None:
            finder_method = finding_method_registry.find(
                finder_method, **finder_kwargs)
        self.finder_method = finder_method

        self._add_default_quantities()
Example no. 10
    def __init__(self,
                 halos_ds=None,
                 data_ds=None,
                 data_source=None,
                 finder_method=None,
                 finder_kwargs=None,
                 output_dir="halo_catalogs/catalog"):
        ParallelAnalysisInterface.__init__(self)
        self.halos_ds = halos_ds
        self.data_ds = data_ds
        self.output_dir = ensure_dir(output_dir)
        if os.path.basename(self.output_dir) != ".":
            self.output_prefix = os.path.basename(self.output_dir)
        else:
            self.output_prefix = "catalog"

        if halos_ds is None:
            if data_ds is None:
                raise RuntimeError(
                    "Must specify a halos_ds, data_ds, or both.")
            if finder_method is None:
                raise RuntimeError(
                    "Must specify a halos_ds or a finder_method.")

        if data_source is None:
            if halos_ds is not None:
                halos_ds.index
                data_source = halos_ds.all_data()
            else:
                data_source = data_ds.all_data()
        self.data_source = data_source

        self.finder_method_name = finder_method
        if finder_kwargs is None:
            finder_kwargs = {}
        if finder_method is not None:
            finder_method = finding_method_registry.find(
                finder_method, **finder_kwargs)
        self.finder_method = finder_method

        # all of the analysis actions to be performed: callbacks, filters, and quantities
        self.actions = []
        # fields to be written to the halo catalog
        self.quantities = []
        if self.halos_ds is not None:
            self.add_default_quantities()
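
A typical construction under the signature above might look like the following. The dataset path is illustrative, and the import path is an assumption for the yt 3.x analysis_modules layout this snippet appears to come from.

import yt
# Assumed import path for the yt 3.x analysis_modules layout.
from yt.analysis_modules.halo_analysis.api import HaloCatalog

ds = yt.load("Enzo_64/RD0006/RedshiftOutput0006")  # illustrative dataset
hc = HaloCatalog(data_ds=ds, finder_method="hop",
                 output_dir="halo_catalogs/catalog")
# Assumes the "hello_world" callback from Example no. 3 has been registered.
hc.add_callback("hello_world", "this is my message")
hc.create()  # run the finder, apply callbacks, and write the catalog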
Example no. 11
    def _save(self, ds=None, data=None, extra_attrs=None, field_types=None):
        "Save new halo catalog."

        if ds is None:
            ds = self.source_ds
        else:
            self._source_ds = ds

        data_dir = ensure_dir(self.output_dir)
        filename = os.path.join(data_dir,
                                f"{self.output_basename}.{self.comm.rank}.h5")

        if data is None:
            n_halos = len(self.catalog)
            data = {}
            if n_halos > 0:
                for key in self.quantities:

                    if hasattr(self.catalog[0][key], "units"):
                        registry = self.catalog[0][key].units.registry
                        my_arr = functools.partial(unyt_array,
                                                   registry=registry)
                    else:
                        my_arr = np.array

                    data[key] = my_arr([halo[key] for halo in self.catalog])
        else:
            n_halos = data[self._id_field].size

        mylog.info("Saving %d halos: %s.", n_halos, filename)

        if field_types is None:
            field_types = {key: "." for key in self.quantities}

        if extra_attrs is None:
            extra_attrs = {}
        extra_attrs_d = {"data_type": "halo_catalog", "num_halos": n_halos}
        extra_attrs_d.update(extra_attrs)

        with quiet():
            save_as_dataset(ds,
                            filename,
                            data,
                            field_types=field_types,
                            extra_attrs=extra_attrs_d)
Example no. 12
    def __init__(self, halos_ds=None, data_ds=None,
                 data_source=None, finder_method=None,
                 finder_kwargs=None,
                 output_dir="halo_catalogs/catalog"):
        ParallelAnalysisInterface.__init__(self)
        self.halos_ds = halos_ds
        self.data_ds = data_ds
        self.output_dir = ensure_dir(output_dir)
        if os.path.basename(self.output_dir) != ".":
            self.output_prefix = os.path.basename(self.output_dir)
        else:
            self.output_prefix = "catalog"

        if halos_ds is None:
            if data_ds is None:
                raise RuntimeError("Must specify a halos_ds, data_ds, or both.")
            if finder_method is None:
                raise RuntimeError("Must specify a halos_ds or a finder_method.")

        if data_source is None:
            if halos_ds is not None:
                halos_ds.index
                data_source = halos_ds.all_data()
            else:
                data_source = data_ds.all_data()
        self.data_source = data_source

        if finder_kwargs is None:
            finder_kwargs = {}
        if finder_method is not None:
            finder_method = finding_method_registry.find(finder_method,
                        **finder_kwargs)
        self.finder_method = finder_method

        # all of the analysis actions to be performed: callbacks, filters, and quantities
        self.actions = []
        # fields to be written to the halo catalog
        self.quantities = []
        if self.halos_ds is not None:
            self.add_default_quantities()
Example no. 13
    def save(self, name=None, suffix=None, mpl_kwargs=None):
        """saves the plot to disk.

        Parameters
        ----------
        name : string or tuple
           The base of the filename. If name is a directory or if name is not
           set, the filename of the dataset is used. For a tuple, the
           resulting path will be given by joining the elements of the
           tuple
        suffix : string
           Specify the image type by its suffix. If not specified, the output
           type will be inferred from the filename. Defaults to PNG.
        mpl_kwargs : dict
           A dict of keyword arguments to be passed to matplotlib.

        >>> slc.save(mpl_kwargs={'bbox_inches':'tight'})

        """
        names = []
        if mpl_kwargs is None:
            mpl_kwargs = {}
        if isinstance(name, (tuple, list)):
            name = os.path.join(*name)
        if name is None:
            name = str(self.ds)
        name = os.path.expanduser(name)
        if name[-1] == os.sep and not os.path.isdir(name):
            ensure_dir(name)
        if os.path.isdir(name) and name != str(self.ds):
            name = name + (os.sep if name[-1] != os.sep else "") + str(self.ds)
        if suffix is None:
            suffix = get_image_suffix(name)
            if suffix != "":
                for v in self.plots.values():
                    names.append(v.save(name, mpl_kwargs))
                return names
        if hasattr(self.data_source, "axis"):
            axis = self.ds.coordinates.axis_name.get(self.data_source.axis, "")
        else:
            axis = None
        weight = None
        type = self._plot_type
        if type in ["Projection", "OffAxisProjection"]:
            weight = self.data_source.weight_field
            if weight is not None:
                weight = weight[1].replace(" ", "_")
        if "Cutting" in self.data_source.__class__.__name__:
            type = "OffAxisSlice"
        for k, v in self.plots.items():
            if isinstance(k, tuple):
                k = k[1]
            if axis:
                n = f"{name}_{type}_{axis}_{k.replace(' ', '_')}"
            else:
                # for cutting planes
                n = f"{name}_{type}_{k.replace(' ', '_')}"
            if weight:
                n += f"_{weight}"
            if suffix != "":
                n = ".".join([n, suffix])
            names.append(v.save(n, mpl_kwargs))
        return names
Example no. 14
from numpy.testing import assert_array_equal
import os
import shutil
import tempfile
from unittest import TestCase
from yt.funcs import ensure_dir
from yt.testing import assert_rel_equal
from trident.config import parse_config

# If TRIDENT_GENERATE_TEST_RESULTS=1, just generate test results.
generate_results = int(os.environ.get("TRIDENT_GENERATE_TEST_RESULTS", 0)) == 1
answer_test_data_dir = ensure_dir(
    os.path.abspath(os.path.expanduser(parse_config('answer_test_data_dir'))))
test_results_dir = os.path.join(answer_test_data_dir, "test_results")
if generate_results:
    ensure_dir(test_results_dir)


class TempDirTest(TestCase):
    """
    A test class that runs in a temporary directory and removes it afterward.
    """
    def setUp(self):
        self.curdir = os.getcwd()
        self.tmpdir = tempfile.mkdtemp()
        os.chdir(self.tmpdir)

    def tearDown(self):
        # Return to the original working directory and discard the temporary one.
        os.chdir(self.curdir)
        shutil.rmtree(self.tmpdir)
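
A hypothetical subclass illustrating how TempDirTest is meant to be used: the test body can create files and directories freely because it runs inside a fresh temporary directory that is discarded afterward.

class ExampleDirTest(TempDirTest):
    def test_ensure_dir(self):
        # Runs inside the temporary directory created in setUp.
        ensure_dir("my_output")
        assert os.path.isdir("my_output")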
Example no. 15
    fn = t.save_tree(filename=os.path.join(data_dir, "tree_%d" % uid))
    return fn


if __name__ == "__main__":
    ### Get the target tree here.
    # a_fn = "halo_2170858/tree_2170858/tree_2170858.h5"
    # if not os.path.exists(a_fn):
    #     get_my_tree()
    a_fn = get_my_tree_rare()

    a = ytree.load(a_fn)

    uid = a["uid"]
    data_dir = "halo_%d" % uid
    ensure_dir(data_dir)

    es = yt.load(sys.argv[1])
    exi, fns = get_existing_datasets(es, "%s")

    last_snap_idx = int(a[0]["Snap_idx"])
    idx_offset = fns.size - last_snap_idx - 1
    fields = ["n_black_holes", "BH_to_Eddington"]
    field_units = ["", "1/Msun"]
    for f, u in zip(fields, field_units):
        if f not in a.field_list:
            yt.mylog.info("Initializing analysis field: %s." % f)
            a.add_analysis_field(f, u)
            for t in a:
                for halo in t["tree"]:
                    halo[f] = -1
Example no. 16
    def trace_descendents(self, halo_type, fields=None, filename=None):
        """
        Trace the descendents of all halos.

        A merger-tree for all halos will be created, starting
        with the first halo catalog and moving forward.

        Parameters
        ----------
        halo_type : string
            The type of halo, typically "FOF" for FoF groups or
            "Subfind" for subhalos.
        fields : optional, list of strings
            List of additional fields to be saved to halo catalogs.
        filename : optional, string
            Directory in which merger-tree catalogs will be saved.
        """

        output_dir = os.path.dirname(filename)
        if self.comm.rank == 0 and len(output_dir) > 0:
            ensure_dir(output_dir)

        all_outputs = self.ts.outputs[:]
        ds1 = ds2 = None

        for i, fn2 in enumerate(all_outputs[1:]):
            fn1 = all_outputs[i]
            target_filename = get_output_filename(
                filename, "%s.%d" % (_get_tree_basename(fn1), 0), ".h5")
            catalog_filename = get_output_filename(
                filename, "%s.%d" % (_get_tree_basename(fn2), 0), ".h5")
            if os.path.exists(target_filename):
                continue

            if ds1 is None:
                ds1 = self._load_ds(fn1, index_ptype=halo_type)
            ds2 = self._load_ds(fn2, index_ptype=halo_type)

            if self.comm.rank == 0:
                _print_link_info(ds1, ds2)

            target_halos = []
            if ds1.index.particle_count[halo_type] == 0:
                self._save_catalog(filename, ds1, target_halos, fields)
                ds1 = ds2
                continue

            target_ids = \
              ds1.r[halo_type, "particle_identifier"].d.astype(np.int64)

            njobs = min(self.comm.size, target_ids.size)
            pbar = get_pbar("Linking halos", target_ids.size, parallel=True)
            my_i = 0
            for halo_id in parallel_objects(target_ids, njobs=njobs):
                my_halo = ds1.halo(halo_type, halo_id)

                target_halos.append(my_halo)
                self._find_descendent(my_halo, ds2)
                my_i += njobs
                pbar.update(my_i)
            pbar.finish()

            self._save_catalog(filename, ds1, target_halos, fields)
            ds1 = ds2
            clear_id_cache()

        if os.path.exists(catalog_filename):
            return

        if ds2 is None:
            ds2 = self._load_ds(fn2, index_ptype=halo_type)
        if self.comm.rank == 0:
            self._save_catalog(filename, ds2, halo_type, fields)
Example no. 17
    cfield_min = float(sys.argv[3])

    data_dir = os.path.join(
        a.filename.split("/")[-3], "clumps_%.0e" % cfield_min)
    last_snap_idx = int(a[0]["Snap_idx"])
    idx_offset = fns.size - last_snap_idx - 1

    for i, fn in yt.parallel_objects(enumerate(fns)):
        if not os.path.exists(fn):
            continue
        ds = yt.load(fn)

        c_min = ds.quan(cfield_min, "1/Msun")

        output_dir = os.path.join(data_dir, ds.directory)
        ensure_dir(output_dir)
        setup_ds = True

        search_idx = i - idx_offset
        sel = 'tree["tree", "Snap_idx"] == %d' % search_idx
        my_halos = a.select_halos(sel, fields=["Snap_idx"])

        yt.mylog.info("Finding clumps for %d halos." % my_halos.size)
        for halo in my_halos:
            if halo["phantom"] > 0:
                continue
            if halo["n_black_holes"] <= 0:
                continue
            output_filename = os.path.join(output_dir,
                                           "halo_%d_clumps.h5" % halo.uid)
            if os.path.exists(output_filename):
Example no. 18
    def trace_ancestors(self, halo_type, root_ids, fields=None, filename=None):
        """
        Trace the ancestry of a given set of halos.

        A merger-tree for a specific set of halos will be created,
        starting with the last halo catalog and moving backward.

        Parameters
        ----------
        halo_type : string
            The type of halo, typically "FOF" for FoF groups or
            "Subfind" for subhalos.
        root_ids : integer or array of integers
            The halo IDs from the last halo catalog for the
            targeted halos.
        fields : optional, list of strings
            List of additional fields to be saved to halo catalogs.
        filename : optional, string
            Directory in which merger-tree catalogs will be saved.
        """

        output_dir = os.path.dirname(filename)
        if self.comm.rank == 0 and len(output_dir) > 0:
            ensure_dir(output_dir)

        all_outputs = self.ts.outputs[::-1]
        ds1 = None

        for i, fn2 in enumerate(all_outputs[1:]):
            fn1 = all_outputs[i]
            target_filename = get_output_filename(
                filename, "%s.%d" % (_get_tree_basename(fn1), 0), ".h5")
            catalog_filename = get_output_filename(
                filename, "%s.%d" % (_get_tree_basename(fn2), 0), ".h5")
            if os.path.exists(catalog_filename):
                continue

            if ds1 is None:
                ds1 = self._load_ds(fn1, index_ptype=halo_type)
            ds2 = self._load_ds(fn2, index_ptype=halo_type)

            if self.comm.rank == 0:
                _print_link_info(ds1, ds2)

            if ds2.index.particle_count[halo_type] == 0:
                mylog.info("%s has no halos of type %s, ending." %
                           (ds2, halo_type))
                break

            if i == 0:
                target_ids = root_ids
                if not iterable(target_ids):
                    target_ids = np.array([target_ids])
                if isinstance(target_ids, YTArray):
                    target_ids = target_ids.d
                if target_ids.dtype != np.int64:
                    target_ids = target_ids.astype(np.int64)
            else:
                mylog.info("Loading target ids from %s.", target_filename)
                ds_target = yt_load(target_filename)
                target_ids = \
                  ds_target.r["halos",
                              "particle_identifier"].d.astype(np.int64)
                del ds_target

            id_store = []
            target_halos = []
            ancestor_halos = []

            njobs = min(self.comm.size, target_ids.size)
            pbar = get_pbar("Linking halos", target_ids.size, parallel=True)
            my_i = 0
            for halo_id in parallel_objects(target_ids, njobs=njobs):
                my_halo = ds1.halo(halo_type, halo_id)

                target_halos.append(my_halo)
                my_ancestors = self._find_ancestors(my_halo,
                                                    ds2,
                                                    id_store=id_store)
                ancestor_halos.extend(my_ancestors)
                my_i += njobs
                pbar.update(my_i)
            pbar.finish()

            if i == 0:
                for halo in target_halos:
                    halo.descendent_identifier = -1
                self._save_catalog(filename, ds1, target_halos, fields)
            self._save_catalog(filename, ds2, ancestor_halos, fields)

            if len(ancestor_halos) == 0:
                break

            ds1 = ds2
            clear_id_cache()
Example no. 19
from numpy.testing import assert_array_equal
from unittest import TestCase
from yt.funcs import ensure_dir
from yt.testing import assert_rel_equal

_base_file = os.path.basename(__file__)

# If GENERATE_TEST_RESULTS="true", just generate test results.
generate_results = os.environ.get("GENERATE_TEST_RESULTS", "false").lower() == "true"
yt.mylog.info(f"{_base_file}: {generate_results=}")

_results_dir = os.environ.get("TEST_RESULTS_DIR", "~/enzoe_test_results")
test_results_dir = os.path.abspath(os.path.expanduser(_results_dir))
yt.mylog.info(f"{_base_file}: {test_results_dir=}")
if generate_results:
    ensure_dir(test_results_dir)
else:
    if not os.path.exists(test_results_dir):
        raise RuntimeError(
            f"Test results directory not found: {test_results_dir}.")

# Set the path to charmrun
_charm_path = os.environ.get("CHARM_PATH", "")
if not _charm_path:
    raise RuntimeError(
        "Specify path to charm with CHARM_PATH environment variable.")
charmrun_path = os.path.join(_charm_path, "charmrun")
yt.mylog.info(f"{_base_file}: {charmrun_path=}")
if not os.path.exists(charmrun_path):
    raise RuntimeError(
        f"No charmrun executable found in {_charm_path}.")
Example no. 20
    def save(self, name=None, suffix=".png", mpl_kwargs=None):
        """saves the plot to disk.

        Parameters
        ----------
        name : string or tuple
           The base of the filename. If name is a directory or if name is not
           set, the filename of the dataset is used. For a tuple, the
           resulting path will be given by joining the elements of the
           tuple
        suffix : string
           Specify the image type by its suffix. If not specified, the output
           type will be inferred from the filename. Defaults to PNG.
        mpl_kwargs : dict
           A dict of keyword arguments to be passed to matplotlib.

        >>> slc.save(mpl_kwargs={"bbox_inches": "tight"})

        """
        names = []
        if mpl_kwargs is None:
            mpl_kwargs = {}
        if isinstance(name, (tuple, list)):
            name = os.path.join(*name)
        if name is None:
            name = str(self.ds)
        name = os.path.expanduser(name)
        if name[-1] == os.sep and not os.path.isdir(name):
            ensure_dir(name)
        if os.path.isdir(name) and name != str(self.ds):
            name = name + (os.sep if name[-1] != os.sep else "") + str(self.ds)

        new_name = validate_image_name(name, suffix)
        if new_name == name:
            for v in self.plots.values():
                out_name = v.save(name, mpl_kwargs)
                names.append(out_name)
            return names

        name = new_name
        prefix, suffix = os.path.splitext(name)

        if hasattr(self.data_source, "axis"):
            axis = self.ds.coordinates.axis_name.get(self.data_source.axis, "")
        else:
            axis = None
        weight = None
        plot_type = self._plot_type
        if plot_type in ["Projection", "OffAxisProjection"]:
            weight = self.data_source.weight_field
            if weight is not None:
                weight = weight[1].replace(" ", "_")
        if "Cutting" in self.data_source.__class__.__name__:
            plot_type = "OffAxisSlice"
        for k, v in self.plots.items():
            if isinstance(k, tuple):
                k = k[1]

            name_elements = [prefix, plot_type]
            if axis:
                name_elements.append(axis)
            name_elements.append(k.replace(" ", "_"))
            if weight:
                name_elements.append(weight)

            name = "_".join(name_elements) + suffix
            names.append(v.save(name, mpl_kwargs))
        return names
Example no. 21
    def save(
        self,
        name: Optional[Union[str, List[str], Tuple[str, ...]]] = None,
        suffix: Optional[str] = None,
        mpl_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """saves the plot to disk.

        Parameters
        ----------
        name : string or tuple, optional
           The base of the filename. If name is a directory or if name is not
           set, the filename of the dataset is used. For a tuple, the
           resulting path will be given by joining the elements of the
           tuple
        suffix : string, optional
           Specify the image type by its suffix. If not specified, the output
           type will be inferred from the filename. Defaults to '.png'.
        mpl_kwargs : dict, optional
           A dict of keyword arguments to be passed to matplotlib.

        >>> slc.save(mpl_kwargs={"bbox_inches": "tight"})

        """
        names = []
        if mpl_kwargs is None:
            mpl_kwargs = {}
        elif "format" in mpl_kwargs:
            new_suffix = mpl_kwargs.pop("format")
            if new_suffix != suffix:
                warnings.warn(
                    f"Overriding suffix {suffix!r} with mpl_kwargs['format'] = {new_suffix!r}. "
                    "Use the `suffix` argument directly to suppress this warning."
                )
            suffix = new_suffix

        if name is None:
            name = str(self.ds)
        elif isinstance(name, (list, tuple)):
            if not all(isinstance(_, str) for _ in name):
                raise TypeError(
                    f"Expected a single str or an iterable of str, got {name!r}"
                )
            name = os.path.join(*name)

        name = os.path.expanduser(name)

        parent_dir, _, prefix1 = name.replace(os.sep, "/").rpartition("/")
        parent_dir = parent_dir.replace("/", os.sep)

        if parent_dir and not os.path.isdir(parent_dir):
            ensure_dir(parent_dir)

        if name.endswith(("/", os.path.sep)):
            name = os.path.join(name, str(self.ds))

        new_name = validate_image_name(name, suffix)
        if new_name == name:
            # somehow mypy thinks we may not have a plots attr yet, hence we turn it off here
            for v in self.plots.values():  # type: ignore
                out_name = v.save(name, mpl_kwargs)
                names.append(out_name)
            return names

        name = new_name
        prefix, suffix = os.path.splitext(name)

        if hasattr(self.data_source, "axis"):
            axis = self.ds.coordinates.axis_name.get(self.data_source.axis, "")
        else:
            axis = None
        weight = None
        plot_type = self._plot_type
        if plot_type in ["Projection", "OffAxisProjection"]:
            weight = self.data_source.weight_field
            if weight is not None:
                weight = weight[1].replace(" ", "_")
        if "Cutting" in self.data_source.__class__.__name__:
            plot_type = "OffAxisSlice"

        # somehow mypy thinks we may not have a plots attr yet, hence we turn it off here
        for k, v in self.plots.items():  # type: ignore
            if isinstance(k, tuple):
                k = k[1]

            if plot_type is None:
                # implemented this check to make mypy happy, because we can't use str.join
                # with PlotContainer._plot_type = None
                raise TypeError(
                    f"{self.__class__} is missing a _plot_type value (str)")

            name_elements = [prefix, plot_type]
            if axis:
                name_elements.append(axis)
            name_elements.append(k.replace(" ", "_"))
            if weight:
                name_elements.append(weight)

            name = "_".join(name_elements) + suffix
            names.append(v.save(name, mpl_kwargs))
        return names
Example no. 22
    def save_arbor(self,
                   filename="arbor",
                   fields=None,
                   trees=None,
                   max_file_size=524288):
        r"""
        Save the arbor to a file.

        The saved arbor can be re-loaded as an arbor.

        Parameters
        ----------
        filename : optional, string
            Output file keyword.  Main header file will be named
            <filename>/<filename>.h5.
            Default: "arbor".
        fields : optional, list of strings
            The fields to be saved.  If not given, all
            fields will be saved.

        Returns
        -------
        filename : string
            The filename of the saved arbor.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> fn = a.save_arbor()
        >>> # reload it
        >>> a2 = ytree.load(fn)

        """

        if trees is None:
            all_trees = True
            trees = self.trees
            roots = trees
        else:
            all_trees = False
            # assemble unique tree roots for getting fields
            trees = np.asarray(trees)
            roots = []
            root_uids = []
            for tree in trees:
                if tree.root == -1:
                    my_root = tree
                else:
                    my_root = tree.root
                if my_root.uid not in root_uids:
                    roots.append(my_root)
                    root_uids.append(my_root.uid)
            roots = np.array(roots)
            del root_uids

        if fields in [None, "all"]:
            # If a field has an alias, get that instead.
            fields = []
            for field in self.field_list + self.analysis_field_list:
                fields.extend(self.field_info[field].get("aliases", [field]))
        else:
            fields.extend([f for f in ["uid", "desc_uid"] if f not in fields])

        ds = {}
        for attr in ["hubble_constant", "omega_matter", "omega_lambda"]:
            if hasattr(self, attr):
                ds[attr] = getattr(self, attr)
        extra_attrs = {
            "box_size": self.box_size,
            "arbor_type": "YTreeArbor",
            "unit_registry_json": self.unit_registry.to_json()
        }

        self._node_io_loop(self._setup_tree,
                           root_nodes=roots,
                           pbar="Setting up trees")
        self._root_io.get_fields(self, fields=fields)

        # determine file layout
        nn = 0  # node count
        nt = 0  # tree count
        nnodes = []
        ntrees = []
        tree_size = np.array([tree.tree_size for tree in trees])
        for ts in tree_size:
            nn += ts
            nt += 1
            if nn > max_file_size:
                nnodes.append(nn - ts)
                ntrees.append(nt - 1)
                nn = ts
                nt = 1
        if nn > 0:
            nnodes.append(nn)
            ntrees.append(nt)
        nfiles = len(nnodes)
        nnodes = np.array(nnodes)
        ntrees = np.array(ntrees)
        tree_end_index = ntrees.cumsum()
        tree_start_index = tree_end_index - ntrees

        # write header file
        fieldnames = [field.replace("/", "_") for field in fields]
        myfi = {}
        rdata = {}
        rtypes = {}
        for field, fieldname in zip(fields, fieldnames):
            fi = self.field_info[field]
            myfi[fieldname] = \
              dict((key, fi[key])
                   for key in ["units", "description"]
                   if key in fi)
            if all_trees:
                rdata[fieldname] = self._root_field_data[field]
            else:
                rdata[fieldname] = self.arr([t[field] for t in trees])
            rtypes[fieldname] = "data"
        # all saved trees will be roots
        if not all_trees:
            rdata["desc_uid"][:] = -1
        extra_attrs["field_info"] = json.dumps(myfi)
        extra_attrs["total_files"] = nfiles
        extra_attrs["total_trees"] = trees.size
        extra_attrs["total_nodes"] = tree_size.sum()
        hdata = {
            "tree_start_index": tree_start_index,
            "tree_end_index": tree_end_index,
            "tree_size": ntrees
        }
        hdata.update(rdata)
        htypes = dict((f, "index") for f in hdata)
        htypes.update(rtypes)

        ensure_dir(filename)
        header_filename = os.path.join(filename, "%s.h5" % filename)
        save_as_dataset(ds,
                        header_filename,
                        hdata,
                        field_types=htypes,
                        extra_attrs=extra_attrs)

        # write data files
        ftypes = dict((f, "data") for f in fieldnames)
        for i in range(nfiles):
            my_nodes = trees[tree_start_index[i]:tree_end_index[i]]
            self._node_io_loop(self._node_io.get_fields,
                               pbar="Getting fields [%d/%d]" % (i + 1, nfiles),
                               root_nodes=my_nodes,
                               fields=fields,
                               root_only=False)
            fdata = dict((field, np.empty(nnodes[i])) for field in fieldnames)
            my_tree_size = tree_size[tree_start_index[i]:tree_end_index[i]]
            my_tree_end = my_tree_size.cumsum()
            my_tree_start = my_tree_end - my_tree_size
            pbar = get_pbar("Creating field arrays [%d/%d]" % (i + 1, nfiles),
                            len(fields) * nnodes[i])
            c = 0
            for field, fieldname in zip(fields, fieldnames):
                for di, node in enumerate(my_nodes):
                    if node.is_root:
                        ndata = node._tree_field_data[field]
                    else:
                        ndata = node["tree", field]
                        if field == "desc_uid":
                            # make sure it's a root when loaded
                            ndata[0] = -1
                    fdata[fieldname][my_tree_start[di]:my_tree_end[di]] = ndata
                    c += my_tree_size[di]
                    pbar.update(c)
            pbar.finish()
            fdata["tree_start_index"] = my_tree_start
            fdata["tree_end_index"] = my_tree_end
            fdata["tree_size"] = my_tree_size
            for ft in ["tree_start_index", "tree_end_index", "tree_size"]:
                ftypes[ft] = "index"
            my_filename = os.path.join(filename, "%s_%04d.h5" % (filename, i))
            save_as_dataset({}, my_filename, fdata, field_types=ftypes)

        return header_filename