def __call__(self, args):
    if args.list:
        self.get_list()
        return
    if not args.filename:
        raise RuntimeError(
            "You need to provide a filename. See --help "
            "for details or use --list to get available "
            "datasets.")
    elif not args.location:
        raise RuntimeError(
            "You need to specify download location. See --help for details."
        )
    data_url = f"http://yt-project.org/data/{args.filename}"
    if args.location in ["test_data_dir", "supp_data_dir"]:
        data_dir = ytcfg.get("yt", args.location)
        if data_dir == "/does/not/exist":
            raise RuntimeError(f"'{args.location}' is not configured!")
    else:
        data_dir = args.location
    if not os.path.exists(data_dir):
        print(f"The directory '{data_dir}' does not exist. Creating...")
        ensure_dir(data_dir)
    data_file = os.path.join(data_dir, args.filename)
    if os.path.exists(data_file) and not args.overwrite:
        raise OSError(f"File '{data_file}' exists and overwrite=False!")
    print(f"Attempting to download file: {args.filename}")
    fn = download_file(data_url, data_file)
    if not os.path.exists(fn):
        raise OSError(f"The file '{args.filename}' did not download!!")
    print(f"File: {args.filename} downloaded successfully to {data_file}")

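# --- Usage sketch (not from the original source) ---
# A minimal sketch of driving the download handler above, assuming it is the
# __call__ of a yt command-line subcommand class. "DownloadData" and the
# dataset name are illustrative stand-ins, not names confirmed by the source.
from argparse import Namespace

_download_args = Namespace(
    list=False,                        # don't just print the available datasets
    filename="IsolatedGalaxy.tar.gz",  # hypothetical dataset name
    location="test_data_dir",          # resolved through ytcfg to the test data dir
    overwrite=False,                   # refuse to clobber an existing file
)
# DownloadData()(_download_args)  # would fetch http://yt-project.org/data/IsolatedGalaxy.tar.gz
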
def get_my_tree_rare():
    """
    Halo with the highest relative growth from the normal region.
    """

    ### Get the tree with the highest relative growth.
    ds = yt.load("halo_catalogs_nosub/RD0041/RD0041.0.h5")
    a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")

    # halo with highest relative growth
    gsort = ds.r["relative_growth_max"].argsort()[::-1]
    hid = ds.r["particle_identifier"][gsort][0].d

    # # halo with second highest relative growth
    # gsort = ds.r["relative_growth_max"].argsort()[::-1]
    # hid = ds.r["particle_identifier"][gsort][1].d

    t = a[a["Orig_halo_ID"] == hid][0]
    uid = t["uid"]

    data_dir = "halo_%d" % uid
    ensure_dir(data_dir)
    fn = t.save_tree(filename=os.path.join(data_dir, "tree_%d" % uid))
    return fn

def add_callback(self, callback, *args, **kwargs):
    r"""
    Add a callback to the halo catalog action list.

    A callback is a function that accepts and operates on a Halo object and
    does not return anything. Callbacks must exist within the
    callback_registry. Give additional args and kwargs to be passed to the
    callback here.

    Parameters
    ----------
    callback : string
        The name of the callback.

    Examples
    --------

    >>> # Here, a callback is defined and added to the registry.
    >>> def _say_something(halo, message):
    ...     my_id = halo.quantities['particle_identifier']
    ...     print("Halo %d: here is a message - %s." % (my_id, message))
    >>> add_callback("hello_world", _say_something)

    >>> # Now this callback is accessible to the HaloCatalog object
    >>> hc.add_callback("hello_world", "this is my message")

    """
    callback = callback_registry.find(callback, *args, **kwargs)
    if "output_dir" in kwargs:
        ensure_dir(os.path.join(self.output_dir, kwargs["output_dir"]))
    self.actions.append(("callback", callback))

def __init__(self, parameter_filename, simulation_type,
             near_redshift, far_redshift,
             observer_redshift=0.0,
             use_minimum_datasets=True, deltaz_min=0.0,
             minimum_coherent_box_fraction=0.0,
             time_data=True, redshift_data=True,
             find_outputs=False, set_parameters=None,
             output_dir="LC", output_prefix="LightCone"):

    self.near_redshift = near_redshift
    self.far_redshift = far_redshift
    self.observer_redshift = observer_redshift
    self.use_minimum_datasets = use_minimum_datasets
    self.deltaz_min = deltaz_min
    self.minimum_coherent_box_fraction = minimum_coherent_box_fraction
    if set_parameters is None:
        self.set_parameters = {}
    else:
        self.set_parameters = set_parameters
    self.output_dir = output_dir
    self.output_prefix = output_prefix

    # Create output directory.
    ensure_dir(self.output_dir)

    # Calculate light cone solution.
    CosmologySplice.__init__(self, parameter_filename, simulation_type,
                             find_outputs=find_outputs)
    self.light_cone_solution = \
        self.create_cosmology_splice(
            self.near_redshift, self.far_redshift,
            minimal=self.use_minimum_datasets,
            deltaz_min=self.deltaz_min,
            time_data=time_data,
            redshift_data=redshift_data)

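# --- Usage sketch (not from the original source) ---
# A minimal sketch of constructing the light cone class that owns the
# __init__ above (the defaults suggest it is called LightCone). The Enzo
# parameter file and redshift range are illustrative; any simulation type
# supported by CosmologySplice should work the same way.
lc = LightCone("enzo_tiny_cosmology/32Mpc_32.enzo", "Enzo",
               near_redshift=0.0, far_redshift=0.1,
               output_dir="LC", output_prefix="LightCone")
# At this point ensure_dir("LC") has been called and the splice solution computed.
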
def save(self, name=None, suffix=None, mpl_kwargs=None):
    """saves the plot to disk.

    Parameters
    ----------
    name : string
        The base of the filename. If name is a directory or if name is not
        set, the filename of the dataset is used.
    suffix : string
        Specify the image type by its suffix. If not specified, the output
        type will be inferred from the filename. Defaults to PNG.
    mpl_kwargs : dict
        A dict of keyword arguments to be passed to matplotlib.

    >>> slc.save(mpl_kwargs={'bbox_inches':'tight'})

    """
    names = []
    if mpl_kwargs is None:
        mpl_kwargs = {}
    if name is None:
        name = str(self.ds)
    name = os.path.expanduser(name)
    if name[-1] == os.sep and not os.path.isdir(name):
        ensure_dir(name)
    if os.path.isdir(name) and name != str(self.ds):
        name = name + (os.sep if name[-1] != os.sep else '') + str(self.ds)
    if suffix is None:
        suffix = get_image_suffix(name)
        if suffix != '':
            for k, v in iteritems(self.plots):
                names.append(v.save(name, mpl_kwargs))
            return names
    if hasattr(self.data_source, 'axis'):
        axis = self.ds.coordinates.axis_name.get(
            self.data_source.axis, '')
    else:
        axis = None
    weight = None
    type = self._plot_type
    if type in ['Projection', 'OffAxisProjection']:
        weight = self.data_source.weight_field
        if weight is not None:
            weight = weight[1].replace(' ', '_')
    if 'Cutting' in self.data_source.__class__.__name__:
        type = 'OffAxisSlice'
    for k, v in iteritems(self.plots):
        if isinstance(k, tuple):
            k = k[1]
        if axis:
            n = "%s_%s_%s_%s" % (name, type, axis, k.replace(' ', '_'))
        else:
            # for cutting planes
            n = "%s_%s_%s" % (name, type, k.replace(' ', '_'))
        if weight:
            n += "_%s" % (weight)
        if suffix != '':
            n = ".".join([n, suffix])
        names.append(v.save(n, mpl_kwargs))
    return names

def _determine_output_filename(path, suffix):
    if path.endswith(suffix):
        dirname = os.path.dirname(path)
        filename = path[:-len(suffix)]
    else:
        dirname = path
        filename = os.path.join(dirname, os.path.basename(path))
    ensure_dir(dirname)
    return filename

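# --- Behavior sketch (not from the original source) ---
# Illustrative calls for _determine_output_filename; the paths are hypothetical.
# If the path already ends in the suffix, the suffix is stripped; otherwise the
# path is treated as a directory and its basename is reused inside it. In both
# cases ensure_dir() creates the directory if needed.
#
#   _determine_output_filename("catalogs/halos.0.h5", ".0.h5")
#       -> "catalogs/halos"        (creates "catalogs/")
#   _determine_output_filename("catalogs/halos", ".0.h5")
#       -> "catalogs/halos/halos"  (creates "catalogs/halos/")
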
def get_my_tree():
    """
    Halo with the most black holes from the normal region.
    """

    ### Get the tree with the most black holes.
    ds = yt.load("halo_catalogs/RD0077/RD0077.0.h5")
    a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")

    i = ds.r["n_black_holes"].argmax()
    hid = ds.r["particle_identifier"][i].d
    t = a[a["Orig_halo_ID"] == hid][0]
    uid = t["uid"]

    data_dir = "halo_%d" % uid
    ensure_dir(data_dir)
    fn = t.save_tree(filename=os.path.join(data_dir, "tree_%d" % uid))
    return fn

def __init__(
    self,
    halos_ds=None,
    data_ds=None,
    data_source=None,
    halo_field_type="all",
    finder_method=None,
    finder_kwargs=None,
    output_dir=None,
):
    super().__init__()

    self.halos_ds = halos_ds
    self.data_ds = data_ds
    self.halo_field_type = halo_field_type

    if halos_ds is None:
        if data_ds is None:
            raise RuntimeError(
                "Must specify a halos_ds, data_ds, or both.")
        if finder_method is None:
            raise RuntimeError(
                "Must specify a halos_ds or a finder_method.")

    if data_source is None and halos_ds is not None:
        data_source = halos_ds.all_data()
    self.data_source = data_source

    if output_dir is None:
        if finder_method == "rockstar":
            output_dir = finder_kwargs.get("outbase", "rockstar_halos")
        else:
            output_dir = "halo_catalogs"
    self.output_basedir = ensure_dir(output_dir)

    self.pipeline = AnalysisPipeline(output_dir=self.output_dir)
    self.quantities = self.pipeline.quantities

    self.finder_method_name = finder_method
    if finder_kwargs is None:
        finder_kwargs = {}
    if finder_method is not None:
        finder_method = finding_method_registry.find(
            finder_method, **finder_kwargs)
    self.finder_method = finder_method

    self._add_default_quantities()

def __init__(self, halos_ds=None, data_ds=None,
             data_source=None, finder_method=None,
             finder_kwargs=None,
             output_dir="halo_catalogs/catalog"):
    ParallelAnalysisInterface.__init__(self)
    self.halos_ds = halos_ds
    self.data_ds = data_ds
    self.output_dir = ensure_dir(output_dir)
    if os.path.basename(self.output_dir) != ".":
        self.output_prefix = os.path.basename(self.output_dir)
    else:
        self.output_prefix = "catalog"

    if halos_ds is None:
        if data_ds is None:
            raise RuntimeError(
                "Must specify a halos_ds, data_ds, or both.")
        if finder_method is None:
            raise RuntimeError(
                "Must specify a halos_ds or a finder_method.")

    if data_source is None:
        if halos_ds is not None:
            halos_ds.index
            data_source = halos_ds.all_data()
        else:
            data_source = data_ds.all_data()
    self.data_source = data_source

    self.finder_method_name = finder_method
    if finder_kwargs is None:
        finder_kwargs = {}
    if finder_method is not None:
        finder_method = finding_method_registry.find(
            finder_method, **finder_kwargs)
    self.finder_method = finder_method

    # all of the analysis actions to be performed:
    # callbacks, filters, and quantities
    self.actions = []
    # fields to be written to the halo catalog
    self.quantities = []
    if self.halos_ds is not None:
        self.add_default_quantities()

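# --- Usage sketch (not from the original source) ---
# A minimal sketch of building a halo catalog with the constructor above and
# running the FoF finder on a simulation dataset. The dataset path is
# illustrative, and hc.create() is assumed to be the method that executes the
# configured actions and writes the catalog into output_dir.
import yt

data_ds = yt.load("Enzo_64/RD0006/RedshiftOutput0006")  # hypothetical dataset path
hc = HaloCatalog(data_ds=data_ds, finder_method="fof",
                 output_dir="halo_catalogs/catalog")
hc.create()
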
def _save(self, ds=None, data=None, extra_attrs=None, field_types=None):
    "Save new halo catalog."

    if ds is None:
        ds = self.source_ds
    else:
        self._source_ds = ds

    data_dir = ensure_dir(self.output_dir)
    filename = os.path.join(
        data_dir, f"{self.output_basename}.{self.comm.rank}.h5")

    if data is None:
        n_halos = len(self.catalog)
        data = {}
        if n_halos > 0:
            for key in self.quantities:
                if hasattr(self.catalog[0][key], "units"):
                    registry = self.catalog[0][key].units.registry
                    my_arr = functools.partial(unyt_array, registry=registry)
                else:
                    my_arr = np.array

                data[key] = my_arr([halo[key] for halo in self.catalog])
    else:
        n_halos = data[self._id_field].size

    mylog.info("Saving %d halos: %s.", n_halos, filename)

    if field_types is None:
        field_types = {key: "." for key in self.quantities}

    if extra_attrs is None:
        extra_attrs = {}
    extra_attrs_d = {"data_type": "halo_catalog",
                     "num_halos": n_halos}
    extra_attrs_d.update(extra_attrs)

    with quiet():
        save_as_dataset(ds, filename, data,
                        field_types=field_types,
                        extra_attrs=extra_attrs_d)

def __init__(self, halos_ds=None, data_ds=None,
             data_source=None, finder_method=None,
             finder_kwargs=None,
             output_dir="halo_catalogs/catalog"):
    ParallelAnalysisInterface.__init__(self)
    self.halos_ds = halos_ds
    self.data_ds = data_ds
    self.output_dir = ensure_dir(output_dir)
    if os.path.basename(self.output_dir) != ".":
        self.output_prefix = os.path.basename(self.output_dir)
    else:
        self.output_prefix = "catalog"

    if halos_ds is None:
        if data_ds is None:
            raise RuntimeError("Must specify a halos_ds, data_ds, or both.")
        if finder_method is None:
            raise RuntimeError("Must specify a halos_ds or a finder_method.")

    if data_source is None:
        if halos_ds is not None:
            halos_ds.index
            data_source = halos_ds.all_data()
        else:
            data_source = data_ds.all_data()
    self.data_source = data_source

    if finder_kwargs is None:
        finder_kwargs = {}
    if finder_method is not None:
        finder_method = finding_method_registry.find(finder_method,
                                                     **finder_kwargs)
    self.finder_method = finder_method

    # all of the analysis actions to be performed:
    # callbacks, filters, and quantities
    self.actions = []
    # fields to be written to the halo catalog
    self.quantities = []
    if self.halos_ds is not None:
        self.add_default_quantities()

def save(self, name=None, suffix=None, mpl_kwargs=None):
    """saves the plot to disk.

    Parameters
    ----------
    name : string or tuple
        The base of the filename. If name is a directory or if name is not
        set, the filename of the dataset is used. For a tuple, the resulting
        path will be given by joining the elements of the tuple
    suffix : string
        Specify the image type by its suffix. If not specified, the output
        type will be inferred from the filename. Defaults to PNG.
    mpl_kwargs : dict
        A dict of keyword arguments to be passed to matplotlib.

    >>> slc.save(mpl_kwargs={'bbox_inches':'tight'})

    """
    names = []
    if mpl_kwargs is None:
        mpl_kwargs = {}
    if isinstance(name, (tuple, list)):
        name = os.path.join(*name)
    if name is None:
        name = str(self.ds)
    name = os.path.expanduser(name)
    if name[-1] == os.sep and not os.path.isdir(name):
        ensure_dir(name)
    if os.path.isdir(name) and name != str(self.ds):
        name = name + (os.sep if name[-1] != os.sep else "") + str(self.ds)
    if suffix is None:
        suffix = get_image_suffix(name)
        if suffix != "":
            for v in self.plots.values():
                names.append(v.save(name, mpl_kwargs))
            return names
    if hasattr(self.data_source, "axis"):
        axis = self.ds.coordinates.axis_name.get(self.data_source.axis, "")
    else:
        axis = None
    weight = None
    type = self._plot_type
    if type in ["Projection", "OffAxisProjection"]:
        weight = self.data_source.weight_field
        if weight is not None:
            weight = weight[1].replace(" ", "_")
    if "Cutting" in self.data_source.__class__.__name__:
        type = "OffAxisSlice"
    for k, v in self.plots.items():
        if isinstance(k, tuple):
            k = k[1]
        if axis:
            n = f"{name}_{type}_{axis}_{k.replace(' ', '_')}"
        else:
            # for cutting planes
            n = f"{name}_{type}_{k.replace(' ', '_')}"
        if weight:
            n += f"_{weight}"
        if suffix != "":
            n = ".".join([n, suffix])
        names.append(v.save(n, mpl_kwargs))
    return names

from numpy.testing import \
    assert_array_equal

import os
import shutil
import tempfile

from unittest import \
    TestCase

from yt.funcs import \
    ensure_dir
from yt.testing import \
    assert_rel_equal

from trident.config import \
    parse_config

# If TRIDENT_GENERATE_TEST_RESULTS=1, just generate test results.
generate_results = int(os.environ.get("TRIDENT_GENERATE_TEST_RESULTS", 0)) == 1

answer_test_data_dir = ensure_dir(
    os.path.abspath(os.path.expanduser(parse_config('answer_test_data_dir'))))
test_results_dir = os.path.join(answer_test_data_dir, "test_results")
if generate_results:
    ensure_dir(test_results_dir)

class TempDirTest(TestCase):
    """
    A test class that runs in a temporary directory and
    removes it afterward.
    """

    def setUp(self):
        self.curdir = os.getcwd()
        self.tmpdir = tempfile.mkdtemp()
        os.chdir(self.tmpdir)

    def tearDown(self):
        os.chdir(self.curdir)
        shutil.rmtree(self.tmpdir)

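# --- Usage sketch (not from the original source) ---
# A hypothetical test case built on TempDirTest: every file it writes lands in
# a throwaway directory created in setUp and removed again in tearDown.
class TestScratchFiles(TempDirTest):

    def test_writes_file(self):
        # The current working directory here is the temporary directory.
        with open("output.txt", "w") as f:
            f.write("data")
        assert os.path.exists("output.txt")
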
    fn = t.save_tree(filename=os.path.join(data_dir, "tree_%d" % uid))
    return fn

if __name__ == "__main__":
    ### Get the target tree here.
    # a_fn = "halo_2170858/tree_2170858/tree_2170858.h5"
    # if not os.path.exists(a_fn):
    #     get_my_tree()
    a_fn = get_my_tree_rare()

    a = ytree.load(a_fn)
    uid = a["uid"]
    data_dir = "halo_%d" % uid
    ensure_dir(data_dir)

    es = yt.load(sys.argv[1])
    exi, fns = get_existing_datasets(es, "%s")

    last_snap_idx = int(a[0]["Snap_idx"])
    idx_offset = fns.size - last_snap_idx - 1

    fields = ["n_black_holes", "BH_to_Eddington"]
    field_units = ["", "1/Msun"]
    for f, u in zip(fields, field_units):
        if f not in a.field_list:
            yt.mylog.info("Initializing analysis field: %s." % f)
            a.add_analysis_field(f, u)
            for t in a:
                for halo in t["tree"]:
                    halo[f] = -1

def trace_descendents(self, halo_type, fields=None, filename=None):
    """
    Trace the descendents of all halos.

    A merger-tree for all halos will be created, starting
    with the first halo catalog and moving forward.

    Parameters
    ----------
    halo_type : string
        The type of halo, typically "FOF" for FoF groups or
        "Subfind" for subhalos.
    fields : optional, list of strings
        List of additional fields to be saved to halo catalogs.
    filename : optional, string
        Directory in which merger-tree catalogs will be saved.
    """

    output_dir = os.path.dirname(filename)
    if self.comm.rank == 0 and len(output_dir) > 0:
        ensure_dir(output_dir)

    all_outputs = self.ts.outputs[:]
    ds1 = ds2 = None

    for i, fn2 in enumerate(all_outputs[1:]):
        fn1 = all_outputs[i]
        target_filename = get_output_filename(
            filename, "%s.%d" % (_get_tree_basename(fn1), 0), ".h5")
        catalog_filename = get_output_filename(
            filename, "%s.%d" % (_get_tree_basename(fn2), 0), ".h5")
        if os.path.exists(target_filename):
            continue

        if ds1 is None:
            ds1 = self._load_ds(fn1, index_ptype=halo_type)
        ds2 = self._load_ds(fn2, index_ptype=halo_type)

        if self.comm.rank == 0:
            _print_link_info(ds1, ds2)

        target_halos = []
        if ds1.index.particle_count[halo_type] == 0:
            self._save_catalog(filename, ds1, target_halos, fields)
            ds1 = ds2
            continue

        target_ids = \
            ds1.r[halo_type, "particle_identifier"].d.astype(np.int64)

        njobs = min(self.comm.size, target_ids.size)
        pbar = get_pbar("Linking halos", target_ids.size, parallel=True)
        my_i = 0
        for halo_id in parallel_objects(target_ids, njobs=njobs):
            my_halo = ds1.halo(halo_type, halo_id)

            target_halos.append(my_halo)
            self._find_descendent(my_halo, ds2)
            my_i += njobs
            pbar.update(my_i)
        pbar.finish()

        self._save_catalog(filename, ds1, target_halos, fields)
        ds1 = ds2
        clear_id_cache()

    if os.path.exists(catalog_filename):
        return

    if ds2 is None:
        ds2 = self._load_ds(fn2, index_ptype=halo_type)
    if self.comm.rank == 0:
        self._save_catalog(filename, ds2, halo_type, fields)

cfield_min = float(sys.argv[3])
data_dir = os.path.join(
    a.filename.split("/")[-3], "clumps_%.0e" % cfield_min)

last_snap_idx = int(a[0]["Snap_idx"])
idx_offset = fns.size - last_snap_idx - 1

for i, fn in yt.parallel_objects(enumerate(fns)):
    if not os.path.exists(fn):
        continue

    ds = yt.load(fn)
    c_min = ds.quan(cfield_min, "1/Msun")

    output_dir = os.path.join(data_dir, ds.directory)
    ensure_dir(output_dir)
    setup_ds = True

    search_idx = i - idx_offset
    sel = 'tree["tree", "Snap_idx"] == %d' % search_idx
    my_halos = a.select_halos(sel, fields=["Snap_idx"])
    yt.mylog.info("Finding clumps for %d halos." % my_halos.size)

    for halo in my_halos:
        if halo["phantom"] > 0:
            continue
        if halo["n_black_holes"] <= 0:
            continue

        output_filename = os.path.join(
            output_dir, "halo_%d_clumps.h5" % halo.uid)
        if os.path.exists(output_filename):

def trace_ancestors(self, halo_type, root_ids,
                    fields=None, filename=None):
    """
    Trace the ancestry of a given set of halos.

    A merger-tree for a specific set of halos will be created,
    starting with the last halo catalog and moving backward.

    Parameters
    ----------
    halo_type : string
        The type of halo, typically "FOF" for FoF groups or
        "Subfind" for subhalos.
    root_ids : integer or array of integers
        The halo IDs from the last halo catalog for the
        targeted halos.
    fields : optional, list of strings
        List of additional fields to be saved to halo catalogs.
    filename : optional, string
        Directory in which merger-tree catalogs will be saved.
    """

    output_dir = os.path.dirname(filename)
    if self.comm.rank == 0 and len(output_dir) > 0:
        ensure_dir(output_dir)

    all_outputs = self.ts.outputs[::-1]
    ds1 = None

    for i, fn2 in enumerate(all_outputs[1:]):
        fn1 = all_outputs[i]
        target_filename = get_output_filename(
            filename, "%s.%d" % (_get_tree_basename(fn1), 0), ".h5")
        catalog_filename = get_output_filename(
            filename, "%s.%d" % (_get_tree_basename(fn2), 0), ".h5")
        if os.path.exists(catalog_filename):
            continue

        if ds1 is None:
            ds1 = self._load_ds(fn1, index_ptype=halo_type)
        ds2 = self._load_ds(fn2, index_ptype=halo_type)

        if self.comm.rank == 0:
            _print_link_info(ds1, ds2)

        if ds2.index.particle_count[halo_type] == 0:
            mylog.info("%s has no halos of type %s, ending." %
                       (ds2, halo_type))
            break

        if i == 0:
            target_ids = root_ids
            if not iterable(target_ids):
                target_ids = np.array([target_ids])
            if isinstance(target_ids, YTArray):
                target_ids = target_ids.d
            if target_ids.dtype != np.int64:
                target_ids = target_ids.astype(np.int64)
        else:
            mylog.info("Loading target ids from %s.", target_filename)
            ds_target = yt_load(target_filename)
            target_ids = \
                ds_target.r["halos", "particle_identifier"].d.astype(np.int64)
            del ds_target

        id_store = []
        target_halos = []
        ancestor_halos = []

        njobs = min(self.comm.size, target_ids.size)
        pbar = get_pbar("Linking halos", target_ids.size, parallel=True)
        my_i = 0
        for halo_id in parallel_objects(target_ids, njobs=njobs):
            my_halo = ds1.halo(halo_type, halo_id)

            target_halos.append(my_halo)
            my_ancestors = self._find_ancestors(my_halo, ds2,
                                                id_store=id_store)
            ancestor_halos.extend(my_ancestors)
            my_i += njobs
            pbar.update(my_i)
        pbar.finish()

        if i == 0:
            for halo in target_halos:
                halo.descendent_identifier = -1
            self._save_catalog(filename, ds1, target_halos, fields)
        self._save_catalog(filename, ds2, ancestor_halos, fields)

        if len(ancestor_halos) == 0:
            break

        ds1 = ds2
        clear_id_cache()

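# --- Usage sketch (not from the original source) ---
# A hedged sketch of how the two tracing methods above might be driven. The
# containing class is assumed to be constructed from a yt time series (it uses
# self.ts and self.comm); "MergerTreeTracer", the halo type, and the paths are
# illustrative stand-ins, not names confirmed by the source.
#
# ts = yt.DatasetSeries("groups_*/fof_subhalo_tab_*.0.hdf5")
# mt = MergerTreeTracer(ts)  # hypothetical class name
# # Walk backward from a few halos in the final catalog:
# mt.trace_ancestors("Subfind", [0, 1, 2], filename="merger_trees/")
# # Or walk forward from the first catalog for all halos:
# mt.trace_descendents("Subfind", filename="merger_trees/")
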
import os

import yt

from numpy.testing import assert_array_equal
from unittest import TestCase

from yt.funcs import ensure_dir
from yt.testing import assert_rel_equal

_base_file = os.path.basename(__file__)

# If GENERATE_TEST_RESULTS="true", just generate test results.
generate_results = os.environ.get("GENERATE_TEST_RESULTS", "false").lower() == "true"
yt.mylog.info(f"{_base_file}: {generate_results=}")

_results_dir = os.environ.get("TEST_RESULTS_DIR", "~/enzoe_test_results")
test_results_dir = os.path.abspath(os.path.expanduser(_results_dir))
yt.mylog.info(f"{_base_file}: {test_results_dir=}")
if generate_results:
    ensure_dir(test_results_dir)
else:
    if not os.path.exists(test_results_dir):
        raise RuntimeError(
            f"Test results directory not found: {test_results_dir}.")

# Set the path to charmrun
_charm_path = os.environ.get("CHARM_PATH", "")
if not _charm_path:
    raise RuntimeError(
        "Specify path to charm with CHARM_PATH environment variable.")
charmrun_path = os.path.join(_charm_path, "charmrun")
yt.mylog.info(f"{_base_file}: {charmrun_path=}")
if not os.path.exists(charmrun_path):
    raise RuntimeError(
        f"No charmrun executable found in {_charm_path}.")

def save(self, name=None, suffix=".png", mpl_kwargs=None):
    """saves the plot to disk.

    Parameters
    ----------
    name : string or tuple
        The base of the filename. If name is a directory or if name is not
        set, the filename of the dataset is used. For a tuple, the resulting
        path will be given by joining the elements of the tuple
    suffix : string
        Specify the image type by its suffix. If not specified, the output
        type will be inferred from the filename. Defaults to PNG.
    mpl_kwargs : dict
        A dict of keyword arguments to be passed to matplotlib.

    >>> slc.save(mpl_kwargs={"bbox_inches": "tight"})

    """
    names = []
    if mpl_kwargs is None:
        mpl_kwargs = {}
    if isinstance(name, (tuple, list)):
        name = os.path.join(*name)
    if name is None:
        name = str(self.ds)
    name = os.path.expanduser(name)
    if name[-1] == os.sep and not os.path.isdir(name):
        ensure_dir(name)
    if os.path.isdir(name) and name != str(self.ds):
        name = name + (os.sep if name[-1] != os.sep else "") + str(self.ds)
    new_name = validate_image_name(name, suffix)
    if new_name == name:
        for v in self.plots.values():
            out_name = v.save(name, mpl_kwargs)
            names.append(out_name)
        return names

    name = new_name
    prefix, suffix = os.path.splitext(name)

    if hasattr(self.data_source, "axis"):
        axis = self.ds.coordinates.axis_name.get(self.data_source.axis, "")
    else:
        axis = None
    weight = None
    plot_type = self._plot_type
    if plot_type in ["Projection", "OffAxisProjection"]:
        weight = self.data_source.weight_field
        if weight is not None:
            weight = weight[1].replace(" ", "_")
    if "Cutting" in self.data_source.__class__.__name__:
        plot_type = "OffAxisSlice"
    for k, v in self.plots.items():
        if isinstance(k, tuple):
            k = k[1]
        name_elements = [prefix, plot_type]
        if axis:
            name_elements.append(axis)
        name_elements.append(k.replace(" ", "_"))
        if weight:
            name_elements.append(weight)
        name = "_".join(name_elements) + suffix
        names.append(v.save(name, mpl_kwargs))
    return names

def save(
    self,
    name: Optional[Union[str, List[str], Tuple[str, ...]]] = None,
    suffix: Optional[str] = None,
    mpl_kwargs: Optional[Dict[str, Any]] = None,
):
    """saves the plot to disk.

    Parameters
    ----------
    name : string or tuple, optional
        The base of the filename. If name is a directory or if name is not
        set, the filename of the dataset is used. For a tuple, the resulting
        path will be given by joining the elements of the tuple
    suffix : string, optional
        Specify the image type by its suffix. If not specified, the output
        type will be inferred from the filename. Defaults to '.png'.
    mpl_kwargs : dict, optional
        A dict of keyword arguments to be passed to matplotlib.

    >>> slc.save(mpl_kwargs={"bbox_inches": "tight"})

    """
    names = []
    if mpl_kwargs is None:
        mpl_kwargs = {}
    elif "format" in mpl_kwargs:
        new_suffix = mpl_kwargs.pop("format")
        if new_suffix != suffix:
            warnings.warn(
                f"Overriding suffix {suffix!r} with mpl_kwargs['format'] = {new_suffix!r}. "
                "Use the `suffix` argument directly to suppress this warning."
            )
        suffix = new_suffix

    if name is None:
        name = str(self.ds)
    elif isinstance(name, (list, tuple)):
        if not all(isinstance(_, str) for _ in name):
            raise TypeError(
                f"Expected a single str or an iterable of str, got {name!r}"
            )
        name = os.path.join(*name)

    name = os.path.expanduser(name)

    parent_dir, _, prefix1 = name.replace(os.sep, "/").rpartition("/")
    parent_dir = parent_dir.replace("/", os.sep)
    if parent_dir and not os.path.isdir(parent_dir):
        ensure_dir(parent_dir)

    if name.endswith(("/", os.path.sep)):
        name = os.path.join(name, str(self.ds))

    new_name = validate_image_name(name, suffix)

    if new_name == name:
        # somehow mypy thinks we may not have a plots attr yet,
        # hence we turn it off here
        for v in self.plots.values():  # type: ignore
            out_name = v.save(name, mpl_kwargs)
            names.append(out_name)
        return names

    name = new_name
    prefix, suffix = os.path.splitext(name)

    if hasattr(self.data_source, "axis"):
        axis = self.ds.coordinates.axis_name.get(self.data_source.axis, "")
    else:
        axis = None
    weight = None
    plot_type = self._plot_type
    if plot_type in ["Projection", "OffAxisProjection"]:
        weight = self.data_source.weight_field
        if weight is not None:
            weight = weight[1].replace(" ", "_")
    if "Cutting" in self.data_source.__class__.__name__:
        plot_type = "OffAxisSlice"

    # somehow mypy thinks we may not have a plots attr yet,
    # hence we turn it off here
    for k, v in self.plots.items():  # type: ignore
        if isinstance(k, tuple):
            k = k[1]
        if plot_type is None:
            # implemented this check to make mypy happy, because we can't
            # use str.join with PlotContainer._plot_type = None
            raise TypeError(
                f"{self.__class__} is missing a _plot_type value (str)")
        name_elements = [prefix, plot_type]
        if axis:
            name_elements.append(axis)
        name_elements.append(k.replace(" ", "_"))
        if weight:
            name_elements.append(weight)
        name = "_".join(name_elements) + suffix
        names.append(v.save(name, mpl_kwargs))
    return names

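# --- Usage sketch (not from the original source) ---
# A short sketch of how the save() method above is typically called from a
# plot object; the dataset path and output names are illustrative.
import yt

ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")  # hypothetical dataset path
slc = yt.SlicePlot(ds, "z", ("gas", "density"))
slc.save("frames/")                       # trailing separator: ensure_dir creates it,
                                          # and the dataset name becomes the prefix
slc.save("density_slice", suffix=".pdf")  # explicit prefix and image type
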
def save_arbor(self, filename="arbor", fields=None, trees=None,
               max_file_size=524288):
    r"""
    Save the arbor to a file.

    The saved arbor can be re-loaded as an arbor.

    Parameters
    ----------
    filename : optional, string
        Output file keyword.  Main header file will be named
        <filename>/<filename>.h5.
        Default: "arbor".
    fields : optional, list of strings
        The fields to be saved.  If not given, all
        fields will be saved.

    Returns
    -------
    filename : string
        The filename of the saved arbor.

    Examples
    --------

    >>> import ytree
    >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
    >>> fn = a.save_arbor()
    >>> # reload it
    >>> a2 = ytree.load(fn)

    """

    if trees is None:
        all_trees = True
        trees = self.trees
        roots = trees
    else:
        all_trees = False
        # assemble unique tree roots for getting fields
        trees = np.asarray(trees)
        roots = []
        root_uids = []
        for tree in trees:
            if tree.root == -1:
                my_root = tree
            else:
                my_root = tree.root
            if my_root.uid not in root_uids:
                roots.append(my_root)
                root_uids.append(my_root.uid)
        roots = np.array(roots)
        del root_uids

    if fields in [None, "all"]:
        # If a field has an alias, get that instead.
        fields = []
        for field in self.field_list + self.analysis_field_list:
            fields.extend(
                self.field_info[field].get("aliases", [field]))
    else:
        fields.extend([f for f in ["uid", "desc_uid"]
                       if f not in fields])

    ds = {}
    for attr in ["hubble_constant",
                 "omega_matter",
                 "omega_lambda"]:
        if hasattr(self, attr):
            ds[attr] = getattr(self, attr)
    extra_attrs = {
        "box_size": self.box_size,
        "arbor_type": "YTreeArbor",
        "unit_registry_json": self.unit_registry.to_json()
    }

    self._node_io_loop(self._setup_tree, root_nodes=roots,
                       pbar="Setting up trees")
    self._root_io.get_fields(self, fields=fields)

    # determine file layout
    nn = 0  # node count
    nt = 0  # tree count
    nnodes = []
    ntrees = []
    tree_size = np.array([tree.tree_size for tree in trees])
    for ts in tree_size:
        nn += ts
        nt += 1
        if nn > max_file_size:
            nnodes.append(nn - ts)
            ntrees.append(nt - 1)
            nn = ts
            nt = 1
    if nn > 0:
        nnodes.append(nn)
        ntrees.append(nt)
    nfiles = len(nnodes)
    nnodes = np.array(nnodes)
    ntrees = np.array(ntrees)
    tree_end_index = ntrees.cumsum()
    tree_start_index = tree_end_index - ntrees

    # write header file
    fieldnames = [field.replace("/", "_") for field in fields]
    myfi = {}
    rdata = {}
    rtypes = {}
    for field, fieldname in zip(fields, fieldnames):
        fi = self.field_info[field]
        myfi[fieldname] = \
            dict((key, fi[key])
                 for key in ["units", "description"]
                 if key in fi)
        if all_trees:
            rdata[fieldname] = self._root_field_data[field]
        else:
            rdata[fieldname] = self.arr([t[field] for t in trees])
        rtypes[fieldname] = "data"
    # all saved trees will be roots
    if not all_trees:
        rdata["desc_uid"][:] = -1
    extra_attrs["field_info"] = json.dumps(myfi)
    extra_attrs["total_files"] = nfiles
    extra_attrs["total_trees"] = trees.size
    extra_attrs["total_nodes"] = tree_size.sum()
    hdata = {"tree_start_index": tree_start_index,
             "tree_end_index": tree_end_index,
             "tree_size": ntrees}
    hdata.update(rdata)
    htypes = dict((f, "index") for f in hdata)
    htypes.update(rtypes)

    ensure_dir(filename)
    header_filename = os.path.join(filename, "%s.h5" % filename)
    save_as_dataset(ds, header_filename, hdata,
                    field_types=htypes,
                    extra_attrs=extra_attrs)

    # write data files
    ftypes = dict((f, "data") for f in fieldnames)
    for i in range(nfiles):
        my_nodes = trees[tree_start_index[i]:tree_end_index[i]]
        self._node_io_loop(
            self._node_io.get_fields,
            pbar="Getting fields [%d/%d]" % (i + 1, nfiles),
            root_nodes=my_nodes, fields=fields, root_only=False)
        fdata = dict((field, np.empty(nnodes[i])) for field in fieldnames)
        my_tree_size = tree_size[tree_start_index[i]:tree_end_index[i]]
        my_tree_end = my_tree_size.cumsum()
        my_tree_start = my_tree_end - my_tree_size
        pbar = get_pbar("Creating field arrays [%d/%d]" %
                        (i + 1, nfiles), len(fields) * nnodes[i])
        c = 0
        for field, fieldname in zip(fields, fieldnames):
            for di, node in enumerate(my_nodes):
                if node.is_root:
                    ndata = node._tree_field_data[field]
                else:
                    ndata = node["tree", field]
                    if field == "desc_uid":
                        # make sure it's a root when loaded
                        ndata[0] = -1
                fdata[fieldname][my_tree_start[di]:my_tree_end[di]] = ndata
                c += my_tree_size[di]
                pbar.update(c)
        pbar.finish()
        fdata["tree_start_index"] = my_tree_start
        fdata["tree_end_index"] = my_tree_end
        fdata["tree_size"] = my_tree_size
        for ft in ["tree_start_index",
                   "tree_end_index",
                   "tree_size"]:
            ftypes[ft] = "index"
        my_filename = os.path.join(filename, "%s_%04d.h5" % (filename, i))
        save_as_dataset({}, my_filename, fdata, field_types=ftypes)

    return header_filename

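# --- Usage sketch (not from the original source) ---
# A hedged sketch of saving only a subset of trees with save_arbor(); the mass
# cut and field names are illustrative and depend on the loaded catalog.
import ytree

a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
big = list(a[a["mass"] > a.quan(1e13, "Msun")])  # boolean selection of tree roots
fn = a.save_arbor(filename="big_halos", trees=big,
                  fields=["mass", "redshift"])   # "uid"/"desc_uid" are appended automatically
a_big = ytree.load(fn)                           # header is written to big_halos/big_halos.h5
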